diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index 715716d8..a6044af2 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -2,10 +2,13 @@ import datetime import time import jwt +from urllib.parse import urlparse from skyflow.error import SkyflowError from skyflow.service_account.client.auth_client import AuthClient from skyflow.utils.logger import log_info, log_error_log from skyflow.utils import get_base_url, format_scope, SkyflowMessages +from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError +from skyflow.utils import is_valid_url invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value @@ -78,7 +81,14 @@ def get_service_account_token(credentials, options, logger): except: log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code) - + + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + + if options and "token_uri" in options: + token_uri = options["token_uri"] + signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger) base_url = get_base_url(token_uri) auth_client = AuthClient(base_url) @@ -88,10 +98,17 @@ def get_service_account_token(credentials, options, logger): if options and "role_ids" in options: formatted_scope = format_scope(options.get("role_ids")) - response = auth_api.authentication_service_get_auth_token(assertion = signed_token, + try: + response = auth_api.authentication_service_get_auth_token(assertion = signed_token, grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", scope=formatted_scope) - log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) + log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) + except UnauthorizedError: + log_error_log(SkyflowMessages.ErrorLogs.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, invalid_input_error_code) + except Exception: + log_error_log(SkyflowMessages.ErrorLogs.FAILED_TO_GET_BEARER_TOKEN.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value, invalid_input_error_code) return response.access_token, response.token_type def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): @@ -112,32 +129,41 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): def get_signed_tokens(credentials_obj, options): - try: - expiry_time = int(time.time()) + options.get("time_to_live", 60) - prefix = "signed_token_" - - if options and options.get("data_tokens"): - for token in options["data_tokens"]: - claims = { - "iss": "sdk", - "key": credentials_obj.get("keyID"), - "exp": expiry_time, - "sub": credentials_obj.get("clientID"), - "tok": token, - "iat": int(time.time()), - } - - if "ctx" in options: - claims["ctx"] = options["ctx"] - - private_key = credentials_obj.get("privateKey") + expiry_time = int(time.time()) + options.get("time_to_live", 60) + prefix = "signed_token_" + + token_uri = credentials_obj.get("tokenURI") + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + + if options and "token_uri" in options: + token_uri = options["token_uri"] + + + if options and options.get("data_tokens"): + for token in options["data_tokens"]: + claims = { + "iss": "sdk", + "key": credentials_obj.get("keyID"), + "exp": expiry_time, + "sub": credentials_obj.get("clientID"), + "tok": token, + "iat": int(time.time()), + } + + if "ctx" in options: + claims["ctx"] = options["ctx"] + + private_key = credentials_obj.get("privateKey") + try: signed_jwt = jwt.encode(claims, private_key, algorithm="RS256") - response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) - log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) - return response_object + except Exception: + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) - except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) + log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) + return response_object def generate_signed_data_tokens(credentials_file_path, options): diff --git a/skyflow/utils/__init__.py b/skyflow/utils/__init__.py index f2788b11..664cf65d 100644 --- a/skyflow/utils/__init__.py +++ b/skyflow/utils/__init__.py @@ -1,5 +1,5 @@ from ..utils.enums import LogLevel, Env, TokenType from ._skyflow_messages import SkyflowMessages from ._version import SDK_VERSION -from ._helpers import get_base_url, format_scope +from ._helpers import get_base_url, format_scope, is_valid_url from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values, parse_deidentify_text_response, parse_reidentify_text_response, convert_detected_entity_to_entity_info diff --git a/skyflow/utils/_helpers.py b/skyflow/utils/_helpers.py index 97eecabc..090f3a2b 100644 --- a/skyflow/utils/_helpers.py +++ b/skyflow/utils/_helpers.py @@ -8,4 +8,11 @@ def get_base_url(url): def format_scope(scopes): if not scopes: return None - return " ".join([f"role:{scope}" for scope in scopes]) \ No newline at end of file + return " ".join([f"role:{scope}" for scope in scopes]) + +def is_valid_url(url): + try: + result = urlparse(url) + return all([result.scheme in ("http", "https"), result.netloc]) + except Exception: + return False \ No newline at end of file diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 3672cfa8..99329978 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -153,10 +153,13 @@ class Error(Enum): MISSING_CLIENT_ID = f"{error_prefix} Initialization failed. Unable to read client ID in credentials. Verify your client ID." MISSING_KEY_ID = f"{error_prefix} Initialization failed. Unable to read key ID in credentials. Verify your key ID." MISSING_TOKEN_URI = f"{error_prefix} Initialization failed. Unable to read token URI in credentials. Verify your token URI." + INVALID_TOKEN_URI = f"{error_prefix} Initialization failed. Invalid Skyflow credentials. The token URI must be a string and a valid URL." JWT_INVALID_FORMAT = f"{error_prefix} Initialization failed. Invalid private key format. Verify your credentials." JWT_DECODE_ERROR = f"{error_prefix} Validation error. Invalid access token. Verify your credentials." FILE_INVALID_JSON = f"{error_prefix} Initialization failed. File at {{}} is not in valid JSON format. Verify the file contents." INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV = f"{error_prefix} Validation error. Invalid JSON format in SKYFLOW_CREDENTIALS environment variable." + FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token." + UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token." INVALID_TEXT_IN_DEIDENTIFY= f"{error_prefix} Validation error. The text field is required and must be a non-empty string. Specify a valid text." INVALID_ENTITIES_IN_DEIDENTIFY= f"{error_prefix} Validation error. The entities field must be an array of DetectEntities enums. Specify a valid entities." @@ -332,6 +335,8 @@ class ErrorLogs(Enum): KEY_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Key ID is required." TOKEN_URI_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Token URI is required." INVALID_TOKEN_URI = f"{ERROR}: [{error_prefix}] Invalid value for token URI in credentials." + FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token." + UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token." TABLE_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Table is required." diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index f3428f45..611efdae 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -10,6 +10,7 @@ from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest from skyflow.vault.detect._file_input import FileInput +from skyflow.utils._helpers import is_valid_url valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"] valid_connection_config_keys = ["connection_id", "connection_url", "credentials"] @@ -138,6 +139,15 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non raise SkyflowError(SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value, invalid_input_error_code) + + if "token_uri" in credentials: + token_uri = credentials.get("token_uri") + if ( + token_uri is None + or not isinstance(token_uri, str) + or not is_valid_url(token_uri) + ): + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) def validate_log_level(logger, log_level): if not isinstance(log_level, LogLevel): @@ -202,10 +212,8 @@ def validate_update_vault_config(logger, config): if "env" in config and config.get("env") not in Env: raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code) - - validate_credentials(logger, config.get("credentials"), "vault", vault_id) + if "credentials" in config and config.get("credentials"): + validate_credentials(logger, config.get("credentials"), "vault", vault_id) return True @@ -413,9 +421,6 @@ def validate_insert_request(logger, request): if key is None or key == "": log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format("INSERT"), logger = logger) - if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format("INSERT", key), logger = logger) - if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value("INSERT"), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index f47a525c..689fb6e9 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -1,3 +1,4 @@ +from skyflow.error import SkyflowError from skyflow.generated.rest.client import Skyflow from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages @@ -62,6 +63,8 @@ def get_bearer_token(self, credentials): "role_ids": self.__config.get("roles"), "ctx": self.__config.get("ctx") } + if "token_uri" in credentials and credentials.get("token_uri"): + options["token_uri"] = credentials.get("token_uri") if self.__bearer_token is None or self.__is_config_updated: if 'path' in credentials: @@ -85,7 +88,7 @@ def get_bearer_token(self, credentials): if is_expired(self.__bearer_token): self.__is_config_updated = True - raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + raise SkyflowError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) return self.__bearer_token diff --git a/tests/service_account/test__utils.py b/tests/service_account/test__utils.py index 7ffb36df..ca82527a 100644 --- a/tests/service_account/test__utils.py +++ b/tests/service_account/test__utils.py @@ -143,4 +143,73 @@ def test_generate_signed_data_tokens_from_creds_with_invalid_string(self): credentials_string = '{' with self.assertRaises(SkyflowError) as context: result = generate_signed_data_tokens_from_creds(credentials_string, options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) \ No newline at end of file + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com" + } + options = {"role_ids": ["role1", "role2"]} + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), {"access_token": "token", + "token_type": "bearer"}) + access_token, token_type = get_service_account_token(creds, options, None) + self.assertEqual(access_token, "token") + self.assertEqual(token_type, "bearer") + args, kwargs = mock_auth_api.authentication_service_get_auth_token.call_args + self.assertIn("scope", kwargs) + self.assertEqual(kwargs["scope"], "role:role1 role:role2") + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com" + } + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError + mock_auth_api.authentication_service_get_auth_token.side_effect = UnauthorizedError("unauthorized") + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com" + } + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.side_effect = Exception("some error") + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value) + + @patch("jwt.encode", side_effect=Exception("jwt error")) + def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com" + } + options = {"data_tokens": ["token1"]} + with self.assertRaises(SkyflowError) as context: + from skyflow.service_account._utils import get_signed_tokens + get_signed_tokens(creds, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) \ No newline at end of file diff --git a/tests/utils/test__helpers.py b/tests/utils/test__helpers.py index 8b55abf3..6758b62e 100644 --- a/tests/utils/test__helpers.py +++ b/tests/utils/test__helpers.py @@ -1,5 +1,5 @@ import unittest -from skyflow.utils import get_base_url, format_scope +from skyflow.utils import get_base_url, format_scope, is_valid_url VALID_URL = "https://example.com/path?query=1" BASE_URL = "https://example.com" @@ -35,4 +35,27 @@ def test_format_scope_single_scope(self): def test_format_scope_special_characters(self): scopes_with_special_chars = ["admin", "user:write", "read-only"] expected_result = "role:admin role:user:write role:read-only" - self.assertEqual(format_scope(scopes_with_special_chars), expected_result) \ No newline at end of file + self.assertEqual(format_scope(scopes_with_special_chars), expected_result) + + def test_is_valid_url_valid(self): + self.assertTrue(is_valid_url("https://example.com")) + self.assertTrue(is_valid_url("http://example.com/path")) + + def test_is_valid_url_invalid(self): + self.assertFalse(is_valid_url("ftp://example.com")) + self.assertFalse(is_valid_url("example.com")) + self.assertFalse(is_valid_url("invalid-url")) + self.assertFalse(is_valid_url("")) + + def test_is_valid_url_none(self): + self.assertFalse(is_valid_url(None)) + + def test_is_valid_url_no_scheme(self): + self.assertFalse(is_valid_url("www.example.com")) + + def test_is_valid_url_exception(self): + class BadStr: + def __str__(self): + raise Exception("bad str") + + self.assertFalse(is_valid_url(BadStr())) \ No newline at end of file diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 6eaacf47..55d8c00e 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -1,13 +1,15 @@ import unittest from unittest.mock import patch, Mock import os -import json from unittest.mock import MagicMock from urllib.parse import quote +import tempfile, json from requests import PreparedRequest from requests.models import HTTPError from skyflow.error import SkyflowError from skyflow.generated.rest import ErrorResponse +from skyflow.service_account import generate_bearer_token, generate_signed_data_tokens, \ + generate_signed_data_tokens_from_creds, generate_bearer_token_from_creds from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \ parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \ parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \ @@ -597,3 +599,189 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): self.assertEqual(result.text_index.end, 0) self.assertEqual(result.processed_index.start, 0) self.assertEqual(result.processed_index.end, 0) + + def test_generate_bearer_token_invalid_token_uri_type(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 12345 # invalid type + } + + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_bearer_token(tmp.name) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_invalid_token_uri_url(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'not_a_url' + } + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_bearer_token(tmp.name) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_options_override_token_uri(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'https://valid-url.com' + } + options = {"token_uri": "https://another-valid-url.com"} + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + # Patch AuthClient and jwt.encode to avoid real HTTP and signing + with patch("skyflow.service_account._utils.get_signed_jwt") as mock_get_signed_jwt: + mock_get_signed_jwt.return_value = "signed" + with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), + {"access_token": "token", + "token_type": "bearer"}) + generate_bearer_token(tmp.name, options) + args, kwargs = mock_get_signed_jwt.call_args + self.assertEqual(args[3], options["token_uri"]) + + def test_generate_bearer_token_from_creds_invalid_token_uri_type(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 12345 + } + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_bearer_token_from_creds(creds_str) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_from_creds_invalid_token_uri_url(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'not_a_url' + } + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_bearer_token_from_creds(creds_str) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_from_creds_options_override_token_uri(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'https://valid-url.com' + } + options = {"token_uri": "https://another-valid-url.com"} + creds_str = json.dumps(creds) + with patch("skyflow.service_account._utils.get_signed_jwt") as mock_get_signed_jwt: + mock_get_signed_jwt.return_value = "signed" + with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), + {"access_token": "token", + "token_type": "bearer"}) + generate_bearer_token_from_creds(creds_str, options) + args, kwargs = mock_get_signed_jwt.call_args + self.assertEqual(args[3], options["token_uri"]) + + def test_generate_signed_data_tokens_invalid_token_uri_type(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 12345 + } + options = {"data_tokens": ["token1"]} + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(tmp.name, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_invalid_token_uri_url(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'not_a_url' + } + options = {"data_tokens": ["token1"]} + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(tmp.name, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_from_creds_invalid_token_uri_type(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 12345 + } + options = {"data_tokens": ["token1"]} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens_from_creds(creds_str, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_from_creds_invalid_token_uri_url(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'not_a_url' + } + options = {"data_tokens": ["token1"]} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens_from_creds(creds_str, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_options_override_token_uri(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'https://valid-url.com' + } + options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with patch("jwt.encode") as mock_jwt_encode: + mock_jwt_encode.return_value = "signed" + result = generate_signed_data_tokens(tmp.name, options) + self.assertIsInstance(result, tuple) + self.assertEqual(result[0], "token1") + self.assertEqual(result[1], "signed_token_signed") + + def test_generate_signed_data_tokens_from_creds_options_override_token_uri(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'https://valid-url.com' + } + options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} + creds_str = json.dumps(creds) + with patch("jwt.encode") as mock_jwt_encode: + mock_jwt_encode.return_value = "signed" + result = generate_signed_data_tokens_from_creds(creds_str, options) + self.assertIsInstance(result, tuple) + self.assertEqual(result[0], "token1") + self.assertEqual(result[1], "signed_token_signed") diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index 48332a55..5c3bb450 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -205,15 +205,6 @@ def test_validate_update_vault_config_valid(self): } self.assertTrue(validate_update_vault_config(self.logger, config)) - def test_validate_update_vault_config_missing_credentials(self): - config = { - "vault_id": "vault123", - "cluster_id": "cluster123" - } - with self.assertRaises(SkyflowError) as context: - validate_update_vault_config(self.logger, config) - self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", "vault123")) - def test_validate_update_vault_config_invalid_cluster_id(self): config = { "vault_id": "vault123", @@ -1044,3 +1035,69 @@ def test_validate_detokenize_request_invalid_redaction_type(self): with self.assertRaises(SkyflowError) as context: validate_detokenize_request(self.logger, request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(str(type("invalid")))) + + def test_validate_credentials_with_valid_token_uri(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "https://valid-url.com" + } + # Should not raise + validate_credentials(self.logger, credentials) + + def test_validate_credentials_with_invalid_token_uri_type(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": 12345 # Not a string + } + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_credentials_with_invalid_token_uri_url(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "not_a_url" + } + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_update_vault_config_with_valid_token_uri(self): + from skyflow.utils.enums import Env + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "https://valid-url.com" + }, + "env": Env.DEV + } + # Should not raise + self.assertTrue(validate_update_vault_config(self.logger, config)) + + def test_validate_update_vault_config_with_invalid_token_uri_type(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": 12345 + } + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_update_vault_config_with_invalid_token_uri_url(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "not_a_url" + } + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 565b1e6f..b4d6ec42 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -97,4 +97,24 @@ def test_get_log_level(self): def test_get_logger(self): mock_logger = MagicMock() self.vault_client.set_logger("INFO", mock_logger) - self.assertEqual(self.vault_client.get_logger(), mock_logger) \ No newline at end of file + self.assertEqual(self.vault_client.get_logger(), mock_logger) + + def test_get_bearer_token_with_token(self): + credentials = {"token": "dummy_token"} + token = self.vault_client.get_bearer_token(credentials) + self.assertEqual(token, "dummy_token") + + def test_get_bearer_token_with_token_uri_in_credentials(self): + credentials = { + "path": "dummy_path", + "token_uri": "https://valid-url.com" + } + with patch("skyflow.vault.client.client.generate_bearer_token") as mock_generate_bearer_token, \ + patch("skyflow.vault.client.client.is_expired", return_value=False): + mock_generate_bearer_token.return_value = ("bearer_token", "bearer") + token = self.vault_client.get_bearer_token(credentials) + mock_generate_bearer_token.assert_called_once() + args, kwargs = mock_generate_bearer_token.call_args + self.assertIn("token_uri", args[1]) + self.assertEqual(args[1]["token_uri"], "https://valid-url.com") + self.assertEqual(token, "bearer_token") \ No newline at end of file