diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index 9f0d9dbf..0bfde34e 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -3,6 +3,7 @@ from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages from skyflow.utils.logger import log_info, Logger +from skyflow.utils.constants import OptionField from skyflow.utils.validations import validate_vault_config, validate_connection_config, validate_update_vault_config, \ validate_update_connection_config, validate_credentials, validate_log_level from skyflow.vault.client.client import VaultClient @@ -30,7 +31,7 @@ def update_vault_config(self,config): self.__builder.update_vault_config(config) def get_vault_config(self, vault_id): - return self.__builder.get_vault_config(vault_id).get("vault_client").get_config() + return self.__builder.get_vault_config(vault_id).get(OptionField.VAULT_CLIENT).get_config() def add_connection_config(self, config): self.__builder._Builder__add_connection_config(config) @@ -45,7 +46,7 @@ def update_connection_config(self, config): return self def get_connection_config(self, connection_id): - return self.__builder.get_connection_config(connection_id).get("vault_client").get_config() + return self.__builder.get_connection_config(connection_id).get(OptionField.VAULT_CLIENT).get_config() def add_skyflow_credentials(self, credentials): self.__builder._Builder__add_skyflow_credentials(credentials) @@ -66,15 +67,15 @@ def update_log_level(self, log_level): def vault(self, vault_id = None) -> Vault: vault_config = self.__builder.get_vault_config(vault_id) - return vault_config.get("vault_controller") + return vault_config.get(OptionField.VAULT_CONTROLLER) def connection(self, connection_id = None) -> Connection: connection_config = self.__builder.get_connection_config(connection_id) - return connection_config.get("controller") + return connection_config.get(OptionField.CONTROLLER) def detect(self, vault_id = None) -> Detect: vault_config = self.__builder.get_vault_config(vault_id) - return vault_config.get("detect_controller") + return vault_config.get(OptionField.DETECT_CONTROLLER) class Builder: def __init__(self): @@ -87,13 +88,13 @@ def __init__(self): self.__logger = Logger(LogLevel.ERROR) def add_vault_config(self, config): - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) if not isinstance(vault_id, str) or not vault_id: raise SkyflowError( SkyflowMessages.Error.INVALID_VAULT_ID.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value ) - if vault_id in [vault.get("vault_id") for vault in self.__vault_list]: + if vault_id in [vault.get(OptionField.VAULT_ID) for vault in self.__vault_list]: log_info(SkyflowMessages.Info.VAULT_CONFIG_EXISTS.value.format(vault_id), self.__logger) raise SkyflowError( SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(vault_id), @@ -112,9 +113,9 @@ def remove_vault_config(self, vault_id): def update_vault_config(self, config): validate_update_vault_config(self.__logger, config) - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) vault_config = self.__vault_configs[vault_id] - vault_config.get("vault_client").update_config(config) + vault_config.get(OptionField.VAULT_CLIENT).update_config(config) def get_vault_config(self, vault_id): if vault_id is None: @@ -129,13 +130,13 @@ def get_vault_config(self, vault_id): def add_connection_config(self, config): - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) if not isinstance(connection_id, str) or not connection_id: raise SkyflowError( SkyflowMessages.Error.INVALID_CONNECTION_ID.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value ) - if connection_id in [connection.get("connection_id") for connection in self.__connection_list]: + if connection_id in [connection.get(OptionField.CONNECTION_ID) for connection in self.__connection_list]: log_info(SkyflowMessages.Info.CONNECTION_CONFIG_EXISTS.value.format(connection_id), self.__logger) raise SkyflowError( SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id), @@ -153,9 +154,9 @@ def remove_connection_config(self, connection_id): def update_connection_config(self, config): validate_update_connection_config(self.__logger, config) - connection_id = config['connection_id'] + connection_id = config[OptionField.CONNECTION_ID] connection_config = self.__connection_configs[connection_id] - connection_config.get("vault_client").update_config(config) + connection_config.get(OptionField.VAULT_CLIENT).update_config(config) def get_connection_config(self, connection_id): if connection_id is None: @@ -183,32 +184,32 @@ def get_logger(self): def __add_vault_config(self, config): validate_vault_config(self.__logger, config) - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) vault_client = VaultClient(config) self.__vault_configs[vault_id] = { - "vault_client": vault_client, - "vault_controller": Vault(vault_client), - "detect_controller": Detect(vault_client) + OptionField.VAULT_CLIENT: vault_client, + OptionField.VAULT_CONTROLLER: Vault(vault_client), + OptionField.DETECT_CONTROLLER: Detect(vault_client) } - log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger) - log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger) + log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger) + log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger) def __add_connection_config(self, config): validate_connection_config(self.__logger, config) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) vault_client = VaultClient(config) self.__connection_configs[connection_id] = { - "vault_client": vault_client, - "controller": Connection(vault_client) + OptionField.VAULT_CLIENT: vault_client, + OptionField.CONTROLLER: Connection(vault_client) } - log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get("connection_id")), self.__logger) + log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.CONNECTION_ID)), self.__logger) def __update_vault_client_logger(self, log_level, logger): for vault_id, vault_config in self.__vault_configs.items(): - vault_config.get("vault_client").set_logger(log_level,logger) + vault_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger) for connection_id, connection_config in self.__connection_configs.items(): - connection_config.get("vault_client").set_logger(log_level,logger) + connection_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger) def __set_log_level(self, log_level): validate_log_level(self.__logger, log_level) @@ -223,10 +224,10 @@ def __add_skyflow_credentials(self, credentials): self.__skyflow_credentials = credentials validate_credentials(self.__logger, credentials) for vault_id, vault_config in self.__vault_configs.items(): - vault_config.get("vault_client").set_common_skyflow_credentials(credentials) + vault_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(credentials) for connection_id, connection_config in self.__connection_configs.items(): - connection_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials) + connection_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(self.__skyflow_credentials) def build(self): validate_log_level(self.__logger, self.__log_level) self.__logger.set_log_level(self.__log_level) diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index 715716d8..3f21ba21 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -6,6 +6,7 @@ from skyflow.service_account.client.auth_client import AuthClient from skyflow.utils.logger import log_info, log_error_log from skyflow.utils import get_base_url, format_scope, SkyflowMessages +from skyflow.utils.constants import JWT, CredentialField, JwtField, OptionField, ResponseField invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value @@ -17,8 +18,8 @@ def is_expired(token, logger = None): try: decoded = jwt.decode( - token, options={"verify_signature": False, "verify_aud": False}) - if time.time() >= decoded['exp']: + token, options={OptionField.VERIFY_SIGNATURE: False, OptionField.VERIFY_AUD: False}) + if time.time() >= decoded[JwtField.EXP]: log_info(SkyflowMessages.Info.BEARER_TOKEN_EXPIRED.value, logger) log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value) return True @@ -59,22 +60,22 @@ def generate_bearer_token_from_creds(credentials, options = None, logger = None) def get_service_account_token(credentials, options, logger): try: - private_key = credentials["privateKey"] + private_key = credentials[CredentialField.PRIVATE_KEY] except: log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger = logger) raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code) try: - client_id = credentials["clientID"] + client_id = credentials[CredentialField.CLIENT_ID] except: log_error_log(SkyflowMessages.ErrorLogs.CLIENT_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code) try: - key_id = credentials["keyID"] + key_id = credentials[CredentialField.KEY_ID] except: log_error_log(SkyflowMessages.ErrorLogs.KEY_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code) try: - token_uri = credentials["tokenURI"] + token_uri = credentials[CredentialField.TOKEN_URI] except: log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code) @@ -85,27 +86,27 @@ def get_service_account_token(credentials, options, logger): auth_api = auth_client.get_auth_api() formatted_scope = None - if options and "role_ids" in options: - formatted_scope = format_scope(options.get("role_ids")) + if options and OptionField.ROLE_IDS in options: + formatted_scope = format_scope(options.get(OptionField.ROLE_IDS)) response = auth_api.authentication_service_get_auth_token(assertion = signed_token, - grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", + grant_type=JWT.GRANT_TYPE_JWT_BEARER, scope=formatted_scope) log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) return response.access_token, response.token_type def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): payload = { - "iss": client_id, - "key": key_id, - "aud": token_uri, - "sub": client_id, - "exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=60) + JwtField.ISS: client_id, + JwtField.KEY: key_id, + JwtField.AUD: token_uri, + JwtField.SUB: client_id, + JwtField.EXP: datetime.datetime.utcnow() + datetime.timedelta(minutes=60) } - if options and "ctx" in options: - payload["ctx"] = options.get("ctx") + if options and JwtField.CTX in options: + payload[JwtField.CTX] = options.get(JwtField.CTX) try: - return jwt.encode(payload=payload, key=private_key, algorithm="RS256") + return jwt.encode(payload=payload, key=private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code) @@ -113,25 +114,25 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): def get_signed_tokens(credentials_obj, options): try: - expiry_time = int(time.time()) + options.get("time_to_live", 60) - prefix = "signed_token_" + expiry_time = int(time.time()) + options.get(OptionField.TIME_TO_LIVE, 60) + prefix = JWT.SIGNED_TOKEN_PREFIX - if options and options.get("data_tokens"): - for token in options["data_tokens"]: + if options and options.get(OptionField.DATA_TOKENS): + for token in options[OptionField.DATA_TOKENS]: claims = { - "iss": "sdk", - "key": credentials_obj.get("keyID"), - "exp": expiry_time, - "sub": credentials_obj.get("clientID"), - "tok": token, - "iat": int(time.time()), + JwtField.ISS: JWT.ISSUER_SDK, + JwtField.KEY: credentials_obj.get(CredentialField.KEY_ID), + JwtField.EXP: expiry_time, + JwtField.SUB: credentials_obj.get(CredentialField.CLIENT_ID), + JwtField.TOK: token, + JwtField.IAT: int(time.time()), } - if "ctx" in options: - claims["ctx"] = options["ctx"] + if JwtField.CTX in options: + claims[JwtField.CTX] = options[JwtField.CTX] - private_key = credentials_obj.get("privateKey") - signed_jwt = jwt.encode(claims, private_key, algorithm="RS256") + private_key = credentials_obj.get(CredentialField.PRIVATE_KEY) + signed_jwt = jwt.encode(claims, private_key, algorithm=JWT.ALGORITHM_RS256) response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) return response_object @@ -170,7 +171,7 @@ def generate_signed_data_tokens_from_creds(credentials, options): def get_signed_data_token_response_object(signed_token, actual_token): response_object = { - "token": actual_token, - "signed_token": signed_token + ResponseField.TOKEN: actual_token, + ResponseField.SIGNED_TOKEN: signed_token } - return response_object.get("token"), response_object.get("signed_token") + return response_object.get(ResponseField.TOKEN), response_object.get(ResponseField.SIGNED_TOKEN) diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 3672cfa8..1954ed4d 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -71,6 +71,9 @@ class Error(Enum): RESPONSE_NOT_JSON = f"{error_prefix} Response {{}} is not valid JSON." API_ERROR = f"{error_prefix} Server returned status code {{}}" + INVALID_JSON_RESPONSE = f"{error_prefix} Invalid JSON response received." + UNKNOWN_ERROR_DEFAULT_MESSAGE = f"{error_prefix} An unknown error occurred." + INVALID_FILE_INPUT = f"{error_prefix} Validation error. Invalid file input. Specify a valid file input." INVALID_DETECT_ENTITIES_TYPE = f"{error_prefix} Validation error. Invalid type of detect entities. Specify detect entities as list of DetectEntities enum." INVALID_TYPE_FOR_DEFAULT_TOKEN_TYPE = f"{error_prefix} Validation error. Invalid type of default token type. Specify default token type as TokenType enum." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 4278357e..c6f294cd 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -20,7 +20,8 @@ from skyflow.vault.detect import DeidentifyTextResponse, ReidentifyTextResponse from skyflow.vault.detect import EntityInfo, TextIndex from . import SkyflowMessages, SDK_VERSION -from .constants import PROTOCOL +from .constants import (PROTOCOL, HttpHeader, ApiKey, ContentType as ContentTypeConstants, + EncodingType, BooleanString, ResponseField, CredentialField) from .enums import Env, ContentType, EnvUrls from skyflow.vault.data import InsertResponse, UpdateResponse, DeleteResponse, QueryResponse, GetResponse from .validations import validate_invoke_connection_params @@ -44,7 +45,7 @@ def get_credentials(config_level_creds = None, common_skyflow_creds = None, logg try: env_creds = env_skyflow_credentials.replace('\n', '\\n') return { - 'credentials_string': env_creds + CredentialField.CREDENTIALS_STRING: env_creds } except json.JSONDecodeError: raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code) @@ -52,7 +53,7 @@ def get_credentials(config_level_creds = None, common_skyflow_creds = None, logg raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: - if len(api_key) != 42: + if len(api_key) != ApiKey.LENGTH: log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger = logger) return False api_key_pattern = re.compile(r'^sky-[a-zA-Z0-9]{5}-[a-fA-F0-9]{32}$') @@ -113,13 +114,13 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - if not 'Content-Type'.lower() in header: - header['content-type'] = ContentType.JSON.value + if not HttpHeader.CONTENT_TYPE.lower() in header: + header[HttpHeader.CONTENT_TYPE_LOWERCASE] = ContentType.JSON.value try: if isinstance(request.body, dict): json_data, files = get_data_from_content_type( - request.body, header["content-type"] + request.body, header[HttpHeader.CONTENT_TYPE_LOWERCASE] ) else: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) @@ -216,30 +217,30 @@ def parse_insert_response(api_response, continue_on_error): api_response_headers = api_response.headers api_response_data = api_response.data # Retrieve the request ID from the headers - request_id = api_response_headers.get('x-request-id') + request_id = api_response_headers.get(HttpHeader.X_REQUEST_ID) inserted_fields = [] errors = [] insert_response = InsertResponse() if continue_on_error: for idx, response in enumerate(api_response_data.responses): - if response['Status'] == 200: - body = response['Body'] - if 'records' in body: - for record in body['records']: + if response[ResponseField.STATUS] == 200: + body = response[ResponseField.BODY] + if ResponseField.RECORDS in body: + for record in body[ResponseField.RECORDS]: inserted_field = { - 'skyflow_id': record['skyflow_id'], - 'request_index': idx + ResponseField.SKYFLOW_ID: record[ResponseField.SKYFLOW_ID], + ResponseField.REQUEST_INDEX: idx } - if 'tokens' in record: - inserted_field.update(record['tokens']) + if ResponseField.TOKENS in record: + inserted_field.update(record[ResponseField.TOKENS]) inserted_fields.append(inserted_field) - elif response['Status'] == 400: + elif response[ResponseField.STATUS] == 400: error = { - 'request_index': idx, - 'request_id': request_id, - 'error': response['Body']['error'], - 'http_code': response['Status'], + ResponseField.REQUEST_INDEX: idx, + ResponseField.REQUEST_ID: request_id, + ResponseField.ERROR: response[ResponseField.BODY][ResponseField.ERROR], + ResponseField.HTTP_CODE: response[ResponseField.STATUS], } errors.append(error) @@ -248,7 +249,7 @@ def parse_insert_response(api_response, continue_on_error): else: for record in api_response_data.records: field_data = { - 'skyflow_id': record.skyflow_id + ResponseField.SKYFLOW_ID: record.skyflow_id } if record.tokens: @@ -263,7 +264,7 @@ def parse_insert_response(api_response, continue_on_error): def parse_update_record_response(api_response: V1UpdateRecordResponse): update_response = UpdateResponse() updated_field = dict() - updated_field['skyflow_id'] = api_response.skyflow_id + updated_field[ResponseField.SKYFLOW_ID] = api_response.skyflow_id if api_response.tokens is not None: updated_field.update(api_response.tokens) @@ -293,23 +294,23 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): api_response_headers = api_response.headers api_response_data = api_response.data # Retrieve the request ID from the headers - request_id = api_response_headers.get('x-request-id') + request_id = api_response_headers.get(HttpHeader.X_REQUEST_ID) detokenized_fields = [] errors = [] for record in api_response_data.records: if record.error: errors.append({ - "token": record.token, - "error": record.error, - "request_id": request_id + ResponseField.TOKEN: record.token, + ResponseField.ERROR: record.error, + ResponseField.REQUEST_ID: request_id }) else: value_type = record.value_type if record.value_type else None detokenized_fields.append({ - "token": record.token, - "value": record.value, - "type": value_type + ResponseField.TOKEN: record.token, + ResponseField.VALUE: record.value, + ResponseField.TYPE: value_type }) detokenized_fields = detokenized_fields @@ -322,7 +323,7 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): def parse_tokenize_response(api_response: V1TokenizeResponse): tokenize_response = TokenizeResponse() - tokenized_fields = [{"token": record.token} for record in api_response.records] + tokenized_fields = [{ResponseField.TOKEN: record.token} for record in api_response.records] tokenize_response.tokenized_fields = tokenized_fields @@ -334,7 +335,7 @@ def parse_query_response(api_response: V1GetQueryResponse): for record in api_response.records: field_object = { **record.fields, - "tokenized_data": {} + ResponseField.TOKENIZED_DATA: {} } fields.append(field_object) query_response.fields = fields @@ -344,14 +345,14 @@ def parse_invoke_connection_response(api_response: requests.Response): status_code = api_response.status_code content = api_response.content if isinstance(content, bytes): - content = content.decode('utf-8') + content = content.decode(EncodingType.UTF_8) try: api_response.raise_for_status() try: data = json.loads(content) metadata = {} - if 'x-request-id' in api_response.headers: - metadata['request_id'] = api_response.headers['x-request-id'] + if HttpHeader.X_REQUEST_ID in api_response.headers: + metadata['request_id'] = api_response.headers[HttpHeader.X_REQUEST_ID] return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) except Exception as e: @@ -360,19 +361,19 @@ def parse_invoke_connection_response(api_response: requests.Response): message = SkyflowMessages.Error.API_ERROR.value.format(status_code) try: error_response = json.loads(content) - request_id = api_response.headers['x-request-id'] - error_from_client = api_response.headers.get('error-from-client') - - status_code = error_response.get('error', {}).get('http_code', 500) # Default to 500 if not found - http_status = error_response.get('error', {}).get('http_status') - grpc_code = error_response.get('error', {}).get('grpc_code') - details = error_response.get('error', {}).get('details') - message = error_response.get('error', {}).get('message', "An unknown error occurred.") + request_id = api_response.headers[HttpHeader.X_REQUEST_ID] + error_from_client = api_response.headers.get(HttpHeader.ERROR_FROM_CLIENT) + + status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found + http_status = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) + grpc_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) + details = error_response.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS) + message = error_response.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) if error_from_client is not None: if details is None: details = [] - error_from_client_bool = error_from_client.lower() == 'true' - details.append({'error_from_client': error_from_client_bool}) + error_from_client_bool = error_from_client.lower() == BooleanString.TRUE + details.append({ResponseField.ERROR_FROM_CLIENT: error_from_client_bool}) raise SkyflowError(message, status_code, request_id, grpc_code, http_status, details) except json.JSONDecodeError: @@ -399,14 +400,14 @@ def handle_exception(error, logger): if (isinstance(error, httpx.ConnectError)): handle_generic_error(error, None, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger) - request_id = error.headers.get('x-request-id', 'unknown-request-id') - content_type = error.headers.get('content-type') + request_id = error.headers.get(HttpHeader.X_REQUEST_ID, 'unknown-request-id') + content_type = error.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE) data = error.body if content_type: - if 'application/json' in content_type: + if ContentTypeConstants.APPLICATION_JSON in content_type: handle_json_error(error, data, request_id, logger) - elif 'text/plain' in content_type: + elif ContentTypeConstants.TEXT_PLAIN in content_type: handle_text_error(error, data, request_id, logger) else: handle_generic_error(error, request_id, logger) @@ -421,15 +422,15 @@ def handle_json_error(err, data, request_id, logger): description = data.dict() else: description = json.loads(data) - status_code = description.get('error', {}).get('http_code', 500) # Default to 500 if not found - http_status = description.get('error', {}).get('http_status') - grpc_code = description.get('error', {}).get('grpc_code') - details = description.get('error', {}).get('details', []) + status_code = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found + http_status = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) + grpc_code = description.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) + details = description.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS, []) - description_message = description.get('error', {}).get('message', "An unknown error occurred.") + description_message = description.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger = logger) except json.JSONDecodeError: - log_and_reject_error("Invalid JSON response received.", err, request_id, logger = logger) + log_and_reject_error(SkyflowMessages.Error.INVALID_JSON_RESPONSE.value, err, request_id, logger = logger) def handle_text_error(err, data, request_id, logger): log_and_reject_error(data, err.status, request_id, logger = logger) diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index ef20faf8..30cb124d 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -2,3 +2,166 @@ PROTOCOL='https' SKY_META_DATA_HEADER='sky-metadata' +class SKYFLOW: + SKYFLOW_ID = 'skyflowId' + X_SKYFLOW_AUTHORIZATION = 'x-skyflow-authorization' + + +class HttpHeader: + CONTENT_TYPE = 'Content-Type' + CONTENT_TYPE_LOWERCASE = 'content-type' + X_REQUEST_ID = 'x-request-id' + ERROR_FROM_CLIENT = 'error-from-client' + AUTHORIZATION = 'Authorization' + + +class HttpStatusCode: + OK = 200 + BAD_REQUEST = 400 + INTERNAL_SERVER_ERROR = 500 + + +class ContentType: + APPLICATION_JSON = 'application/json' + APPLICATION_X_WWW_FORM_URLENCODED = 'application/x-www-form-urlencoded' + TEXT_PLAIN = 'text/plain' + + +class DetectStatus: + IN_PROGRESS = 'IN_PROGRESS' + SUCCESS = 'SUCCESS' + FAILED = 'FAILED' + UNKNOWN = 'UNKNOWN' + + +class FileExtension: + JSON = 'json' + MP3 = 'mp3' + WAV = 'wav' + PDF = 'pdf' + TXT = 'txt' + DOC = 'doc' + DOCX = 'docx' + JPG = 'jpg' + JPEG = 'jpeg' + PNG = 'png' + BMP = 'bmp' + TIF = 'tif' + TIFF = 'tiff' + PPT = 'ppt' + PPTX = 'pptx' + CSV = 'csv' + XLS = 'xls' + XLSX = 'xlsx' + XML = 'xml' + + +class FileProcessing: + PROCESSED_PREFIX = 'processed-' + DEIDENTIFIED_PREFIX = 'deidentified.' + ENTITIES = 'entities' + + +class EncodingType: + UTF8 = 'utf8' + UTF_8 = 'utf-8' + BASE64 = 'base64' + BINARY = 'binary' + + +class JWT: + ALGORITHM_RS256 = 'RS256' + GRANT_TYPE_JWT_BEARER = 'urn:ietf:params:oauth:grant-type:jwt-bearer' + ISSUER_SDK = 'sdk' + SIGNED_TOKEN_PREFIX = 'signed_token_' + ROLE_PREFIX = 'role:' + + +class ApiKey: + SKY_PREFIX = 'sky-' + LENGTH = 42 + + +class UrlProtocol: + HTTPS = 'https' + HTTP = 'http' + + +class BooleanString: + TRUE = 'true' + FALSE = 'false' + + +class ResponseField: + STATUS = 'Status' + BODY = 'Body' + RECORDS = 'records' + TOKENS = 'tokens' + ERROR = 'error' + SKYFLOW_ID = 'skyflow_id' + REQUEST_INDEX = 'request_index' + REQUEST_ID = 'request_id' + HTTP_CODE = 'http_code' + HTTP_STATUS = 'http_status' + GRPC_CODE = 'grpc_code' + DETAILS = 'details' + MESSAGE = 'message' + ERROR_FROM_CLIENT = 'error_from_client' + TOKEN = 'token' + VALUE = 'value' + TYPE = 'type' + TOKENIZED_DATA = 'tokenized_data' + SIGNED_TOKEN = 'signed_token' + + +class CredentialField: + PRIVATE_KEY = 'privateKey' + CLIENT_ID = 'clientID' + KEY_ID = 'keyID' + TOKEN_URI = 'tokenURI' + CREDENTIALS_STRING = 'credentials_string' + API_KEY = 'api_key' + TOKEN = 'token' + PATH = 'path' + + +class JwtField: + ISS = 'iss' + KEY = 'key' + AUD = 'aud' + SUB = 'sub' + EXP = 'exp' + CTX = 'ctx' + TOK = 'tok' + IAT = 'iat' + + +class OptionField: + ROLE_IDS = 'role_ids' + DATA_TOKENS = 'data_tokens' + TIME_TO_LIVE = 'time_to_live' + ROLES = 'roles' + CTX = 'ctx' + VAULT_ID = 'vault_id' + CONNECTION_ID = 'connection_id' + CONNECTION_URL = 'connection_url' + VAULT_CLIENT = 'vault_client' + VAULT_CONTROLLER = 'vault_controller' + DETECT_CONTROLLER = 'detect_controller' + CONTROLLER = 'controller' + VERIFY_SIGNATURE = 'verify_signature' + VERIFY_AUD = 'verify_aud' + + +class ConfigField: + CREDENTIALS = 'credentials' + CLUSTER_ID = 'cluster_id' + ENV = 'env' + VAULT_ID = 'vault_id' + + +class RequestParameter: + VALUE = 'value' + COLUMN_GROUP = 'column_group' + REDACTION = 'redaction' + diff --git a/skyflow/utils/logger/_log_helpers.py b/skyflow/utils/logger/_log_helpers.py index fdb11ea9..3fff980b 100644 --- a/skyflow/utils/logger/_log_helpers.py +++ b/skyflow/utils/logger/_log_helpers.py @@ -1,5 +1,6 @@ from ..enums import LogLevel from . import Logger +from ..constants import ResponseField def log_info(message, logger = None): @@ -18,17 +19,17 @@ def log_error(message, http_code, request_id=None, grpc_code=None, http_status=N logger = Logger(LogLevel.ERROR) log_data = { - 'http_code': http_code, - 'message': message + ResponseField.HTTP_CODE: http_code, + ResponseField.MESSAGE: message } if grpc_code is not None: - log_data['grpc_code'] = grpc_code + log_data[ResponseField.GRPC_CODE] = grpc_code if http_status is not None: - log_data['http_status'] = http_status + log_data[ResponseField.HTTP_STATUS] = http_status if request_id is not None: - log_data['request_id'] = request_id + log_data[ResponseField.REQUEST_ID] = request_id if details is not None: - log_data['details'] = details + log_data[ResponseField.DETAILS] = details logger.error(log_data) \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index f3428f45..779fdfcc 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -6,6 +6,7 @@ MaskingMethod from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages +from skyflow.utils.constants import ApiKey, ResponseField from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest @@ -50,11 +51,11 @@ def validate_required_field(logger, config, field_name, expected_type, empty_err raise SkyflowError(empty_error, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: - if not api_key.startswith('sky-'): + if not api_key.startswith(ApiKey.SKY_PREFIX): log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger=logger) return False - if len(api_key) != 42: + if len(api_key) != ApiKey.LENGTH: log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger = logger) return False @@ -582,10 +583,10 @@ def validate_get_request(logger, request): def validate_update_request(logger, request): skyflow_id = "" - field = {key: value for key, value in request.data.items() if key != "skyflow_id"} + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} try: - skyflow_id = request.data.get("skyflow_id") + skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) except Exception: log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format("UPDATE"), logger=logger) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index f47a525c..2d77330e 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -2,6 +2,7 @@ from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages from skyflow.utils.logger import log_info +from skyflow.utils.constants import OptionField, CredentialField, ConfigField class VaultClient: @@ -23,11 +24,11 @@ def set_logger(self, log_level, logger): self.__logger = logger def initialize_client_configuration(self): - credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger) + credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger = self.__logger) token = self.get_bearer_token(credentials) - vault_url = get_vault_url(self.__config.get("cluster_id"), - self.__config.get("env"), - self.__config.get("vault_id"), + vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), + self.__config.get(ConfigField.ENV), + self.__config.get(ConfigField.VAULT_ID), logger = self.__logger) self.initialize_api_client(vault_url, token) @@ -50,29 +51,29 @@ def get_detect_file_api(self): return self.__api_client.files def get_vault_id(self): - return self.__config.get("vault_id") + return self.__config.get(ConfigField.VAULT_ID) def get_bearer_token(self, credentials): - if 'api_key' in credentials: - return credentials.get('api_key') - elif 'token' in credentials: - return credentials.get("token") + if CredentialField.API_KEY in credentials: + return credentials.get(CredentialField.API_KEY) + elif CredentialField.TOKEN in credentials: + return credentials.get(CredentialField.TOKEN) options = { - "role_ids": self.__config.get("roles"), - "ctx": self.__config.get("ctx") + OptionField.ROLE_IDS: self.__config.get(OptionField.ROLES), + OptionField.CTX: self.__config.get(OptionField.CTX) } if self.__bearer_token is None or self.__is_config_updated: - if 'path' in credentials: - path = credentials.get("path") + if CredentialField.PATH in credentials: + path = credentials.get(CredentialField.PATH) self.__bearer_token, _ = generate_bearer_token( path, options, self.__logger ) else: - credentials_string = credentials.get('credentials_string') + credentials_string = credentials.get(CredentialField.CREDENTIALS_STRING) log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value, self.__logger) self.__bearer_token, _ = generate_bearer_token_from_creds( credentials_string, diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 81c6ea10..83b0ffbd 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,6 +5,7 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW class Connection: @@ -23,9 +24,9 @@ def invoke(self, request: InvokeConnectionRequest): log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) if not 'X-Skyflow-Authorization'.lower() in invoke_connection_request.headers: - invoke_connection_request.headers['x-skyflow-authorization'] = bearer_token + invoke_connection_request.headers[SKYFLOW.X_SKYFLOW_AUTHORIZATION] = bearer_token - invoke_connection_request.headers['sky-metadata'] = json.dumps(get_metrics()) + invoke_connection_request.headers[SKY_META_DATA_HEADER] = json.dumps(get_metrics()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED.value, self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 44ef2540..4f2f50f2 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -8,7 +8,8 @@ FileDataDeidentifyImage, Format, FileDataDeidentifyAudio, WordCharacterCount, DetectRunsResponse from skyflow.utils._skyflow_messages import SkyflowMessages from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response -from skyflow.utils.constants import SKY_META_DATA_HEADER +from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, + FileProcessing, EncodingType) from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request from skyflow.utils.validations._validations import validate_deidentify_text_request, validate_reidentify_text_request @@ -64,7 +65,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): while True: response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data status = response.status - if status == 'IN_PROGRESS': + if status == DetectStatus.IN_PROGRESS: if current_wait_time >= max_wait_time: return DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS') else: @@ -76,7 +77,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): wait_time = next_wait_time current_wait_time = next_wait_time time.sleep(wait_time) - elif status == 'SUCCESS' or status == 'FAILED': + elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED: return response except Exception as e: raise e @@ -88,7 +89,7 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o if not os.path.exists(output_directory): return - deidentify_file_prefix = "processed-" + deidentify_file_prefix = FileProcessing.PROCESSED_PREFIX output_list = response.output base_original_filename = os.path.basename(original_file_name) @@ -159,7 +160,7 @@ def output_to_dict_list(output): output_list = output_to_dict_list(output) first_output = output_list[0] if output_list else {} - entities = [o for o in output_list if o.get("type") == "entities"] + entities = [o for o in output_list if o.get("type") == FileProcessing.ENTITIES] base64_string = first_output.get("file", None) extension = first_output.get("extension", None) @@ -167,14 +168,14 @@ def output_to_dict_list(output): if base64_string is not None: file_bytes = base64.b64decode(base64_string) file_obj = io.BytesIO(file_bytes) - file_obj.name = f"deidentified.{extension}" if extension else "processed_file" + file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else "processed_file" else: file_obj = None return DeidentifyFileResponse( file_base64=base64_string, file=file_obj, - type=first_output.get("type", "UNKNOWN"), + type=first_output.get("type", DetectStatus.UNKNOWN), extension=extension, word_count=word_count, char_count=char_count, @@ -282,11 +283,11 @@ def deidentify_file(self, request: DeidentifyFileRequest): file_name = getattr(file_obj, 'name', None) file_extension = self._get_file_extension(file_name) if file_name else None file_content = file_obj.read() - base64_string = base64.b64encode(file_content).decode('utf-8') + base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8) try: - if file_extension == 'txt': - req_file = FileDataDeidentifyText(base_64=base64_string, data_format="txt") + if file_extension == FileExtension.TXT: + req_file = FileDataDeidentifyText(base_64=base64_string, data_format=FileExtension.TXT) api_call = files_api.deidentify_text api_kwargs = { 'vault_id': self.__vault_client.get_vault_id(), @@ -299,7 +300,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['mp3', 'wav']: + elif file_extension in [FileExtension.MP3, FileExtension.WAV]: req_file = FileDataDeidentifyAudio(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_audio api_kwargs = { @@ -319,7 +320,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension == 'pdf': + elif file_extension == FileExtension.PDF: req_file = FileDataDeidentifyPdf(base_64=base64_string) api_call = files_api.deidentify_pdf api_kwargs = { @@ -334,7 +335,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['jpeg', 'jpg', 'png', 'bmp', 'tif', 'tiff']: + elif file_extension in [FileExtension.JPEG, FileExtension.JPG, FileExtension.PNG, FileExtension.BMP, FileExtension.TIF, FileExtension.TIFF]: req_file = FileDataDeidentifyImage(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_image api_kwargs = { @@ -350,7 +351,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['ppt', 'pptx']: + elif file_extension in [FileExtension.PPT, FileExtension.PPTX]: req_file = FileDataDeidentifyPresentation(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_presentation api_kwargs = { @@ -363,7 +364,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['csv', 'xls', 'xlsx']: + elif file_extension in [FileExtension.CSV, FileExtension.XLS, FileExtension.XLSX]: req_file = FileDataDeidentifySpreadsheet(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_spreadsheet api_kwargs = { @@ -376,7 +377,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['doc', 'docx']: + elif file_extension in [FileExtension.DOC, FileExtension.DOCX]: req_file = FileDataDeidentifyDocument(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_document api_kwargs = { @@ -389,7 +390,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['json', 'xml']: + elif file_extension in [FileExtension.JSON, FileExtension.XML]: req_file = FileDataDeidentifyStructuredText(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_structured_text api_kwargs = { @@ -423,7 +424,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): run_id = getattr(api_response.data, 'run_id', None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) - if request.output_directory and processed_response.status == 'SUCCESS': + if request.output_directory and processed_response.status == DetectStatus.SUCCESS: name_without_ext, _ = os.path.splitext(file_name) self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) @@ -450,7 +451,7 @@ def get_detect_run(self, request: GetDetectRunRequest): vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers() ) - if response.data.status == 'IN_PROGRESS': + if response.data.status == DetectStatus.IN_PROGRESS: parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 7cc9ec77..a5cd94fd 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -8,7 +8,7 @@ from skyflow.utils import SkyflowMessages, parse_insert_response, \ handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \ parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values, get_metrics -from skyflow.utils.constants import SKY_META_DATA_HEADER +from skyflow.utils.constants import SKY_META_DATA_HEADER, ResponseField, RequestParameter from skyflow.utils.enums import RequestMethod from skyflow.utils.enums.redaction_type import RedactionType from skyflow.utils.logger import log_info, log_error_log @@ -125,7 +125,7 @@ def update(self, request: UpdateRequest): validate_update_request(self.__vault_client.get_logger(), request) log_info(SkyflowMessages.Info.UPDATE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() - field = {key: value for key, value in request.data.items() if key != "skyflow_id"} + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} record = V1FieldRecords(fields=field, tokens = request.tokens) records_api = self.__vault_client.get_records_api() @@ -134,7 +134,7 @@ def update(self, request: UpdateRequest): api_response = records_api.record_service_update_record( self.__vault_client.get_vault_id(), request.table, - id=request.data.get("skyflow_id"), + id=request.data.get(ResponseField.SKYFLOW_ID), record=record, tokenization=request.return_tokens, byot=request.token_mode.value, @@ -225,8 +225,8 @@ def detokenize(self, request: DetokenizeRequest): self.__initialize() tokens_list = [ V1DetokenizeRecordRequest( - token=item.get('token'), - redaction=item.get('redaction', RedactionType.DEFAULT) + token=item.get(ResponseField.TOKEN), + redaction=item.get(RequestParameter.REDACTION, RedactionType.DEFAULT) ) for item in request.data ] @@ -253,7 +253,7 @@ def tokenize(self, request: TokenizeRequest): self.__initialize() records_list = [ - V1TokenizeRecordRequest(value=item["value"], column_group=item["column_group"]) + V1TokenizeRecordRequest(value=item[RequestParameter.VALUE], column_group=item[RequestParameter.COLUMN_GROUP]) for item in request.values ] tokens_api = self.__vault_client.get_tokens_api()