Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 53 additions & 27 deletions skyflow/service_account/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import datetime
import time
import jwt
from urllib.parse import urlparse
from skyflow.error import SkyflowError
from skyflow.service_account.client.auth_client import AuthClient
from skyflow.utils.logger import log_info, log_error_log
from skyflow.utils import get_base_url, format_scope, SkyflowMessages
from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError
from skyflow.utils import is_valid_url


invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value
Expand Down Expand Up @@ -78,7 +81,14 @@ def get_service_account_token(credentials, options, logger):
except:
log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger)
raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code)


if not isinstance(token_uri, str) or not is_valid_url(token_uri):
log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger)
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code)

if options and "token_uri" in options:
token_uri = options["token_uri"]

signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger)
base_url = get_base_url(token_uri)
auth_client = AuthClient(base_url)
Expand All @@ -88,10 +98,17 @@ def get_service_account_token(credentials, options, logger):
if options and "role_ids" in options:
formatted_scope = format_scope(options.get("role_ids"))

response = auth_api.authentication_service_get_auth_token(assertion = signed_token,
try:
response = auth_api.authentication_service_get_auth_token(assertion = signed_token,
grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer",
scope=formatted_scope)
log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger)
log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger)
except UnauthorizedError:
log_error_log(SkyflowMessages.ErrorLogs.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, logger=logger)
raise SkyflowError(SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, invalid_input_error_code)
except Exception:
log_error_log(SkyflowMessages.ErrorLogs.FAILED_TO_GET_BEARER_TOKEN.value, logger=logger)
raise SkyflowError(SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value, invalid_input_error_code)
return response.access_token, response.token_type

def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger):
Expand All @@ -112,32 +129,41 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger):


# Builds RS256-signed JWTs for the entries of options["data_tokens"] and wraps each one via
# get_signed_data_token_response_object, prefixing the JWT with "signed_token_".
# NOTE(review): this span is rendered-diff text. Two versions of the body are interleaved and the
# original indentation is lost, so nesting (in particular whether `return response_object` sits
# inside the `for` loop) cannot be read off from here - confirm against the real file.
def get_signed_tokens(credentials_obj, options):
# --- First copy: appears to be the removed side of the diff (whole body inside one try/except). ---
try:
# NOTE(review): options.get() is called before any `if options` guard, so options=None would
# raise AttributeError on this line - confirm callers always pass a dict.
expiry_time = int(time.time()) + options.get("time_to_live", 60)
prefix = "signed_token_"

if options and options.get("data_tokens"):
for token in options["data_tokens"]:
claims = {
"iss": "sdk",
"key": credentials_obj.get("keyID"),
"exp": expiry_time,
"sub": credentials_obj.get("clientID"),
"tok": token,
"iat": int(time.time()),
}

if "ctx" in options:
claims["ctx"] = options["ctx"]

private_key = credentials_obj.get("privateKey")
# --- Second copy: appears to be the added side of the diff (no outer try). ---
# Expiry defaults to 60 seconds from now when "time_to_live" is absent from options.
expiry_time = int(time.time()) + options.get("time_to_live", 60)
prefix = "signed_token_"

# The credentials' tokenURI must be a str that passes is_valid_url; otherwise the error is
# logged and a SkyflowError with INVALID_TOKEN_URI is raised.
token_uri = credentials_obj.get("tokenURI")
if not isinstance(token_uri, str) or not is_valid_url(token_uri):
log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value)
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code)

# An explicit options["token_uri"] replaces the credentials value without being re-validated.
# NOTE(review): token_uri is not referenced again in the visible lines below.
if options and "token_uri" in options:
token_uri = options["token_uri"]


if options and options.get("data_tokens"):
for token in options["data_tokens"]:
claims = {
"iss": "sdk",
"key": credentials_obj.get("keyID"),
"exp": expiry_time,
"sub": credentials_obj.get("clientID"),
"tok": token,
"iat": int(time.time()),
}

# The caller-supplied context is copied into the claims only when the key is present.
if "ctx" in options:
claims["ctx"] = options["ctx"]

private_key = credentials_obj.get("privateKey")
# Any failure while encoding is reported as INVALID_CREDENTIALS.
try:
signed_jwt = jwt.encode(claims, private_key, algorithm="RS256")
response_object = get_signed_data_token_response_object(prefix + signed_jwt, token)
log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value)
return response_object
except Exception:
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)

except Exception:
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
response_object = get_signed_data_token_response_object(prefix + signed_jwt, token)
log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value)
return response_object


def generate_signed_data_tokens(credentials_file_path, options):
Expand Down
2 changes: 1 addition & 1 deletion skyflow/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ..utils.enums import LogLevel, Env, TokenType
from ._skyflow_messages import SkyflowMessages
from ._version import SDK_VERSION
from ._helpers import get_base_url, format_scope
from ._helpers import get_base_url, format_scope, is_valid_url
from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values, parse_deidentify_text_response, parse_reidentify_text_response, convert_detected_entity_to_entity_info
9 changes: 8 additions & 1 deletion skyflow/utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,11 @@ def get_base_url(url):
def format_scope(scopes):
    """Build the space-separated scope string for a set of role IDs.

    Args:
        scopes: An iterable of role IDs, or a single role ID given as a
            string.

    Returns:
        ``None`` when *scopes* is empty or ``None``; otherwise every role ID
        prefixed with ``role:`` and joined by single spaces.
    """
    if not scopes:
        return None
    if isinstance(scopes, str):
        # Iterating a bare string would emit one "role:<char>" entry per
        # character, so treat it as a single role ID instead.
        scopes = [scopes]
    return " ".join(f"role:{scope}" for scope in scopes)

def is_valid_url(url):
    """Return True when *url* is an absolute http(s) URL that names a host.

    Args:
        url: Candidate value. Anything that is not a ``str`` is rejected.

    Returns:
        bool: ``True`` only if the scheme is ``http`` or ``https`` and the
        authority component contains a hostname; ``False`` otherwise,
        including when the value cannot be parsed.
    """
    if not isinstance(url, str):
        return False
    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse rejects some malformed authorities, e.g. an unbalanced
        # IPv6 bracket such as "https://[::1".
        return False
    # Check hostname rather than netloc: netloc is non-empty for host-less
    # authorities such as "https://:443" or "https://user@".
    return parsed.scheme in ("http", "https") and bool(parsed.hostname)
5 changes: 5 additions & 0 deletions skyflow/utils/_skyflow_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,13 @@ class Error(Enum):
MISSING_CLIENT_ID = f"{error_prefix} Initialization failed. Unable to read client ID in credentials. Verify your client ID."
MISSING_KEY_ID = f"{error_prefix} Initialization failed. Unable to read key ID in credentials. Verify your key ID."
MISSING_TOKEN_URI = f"{error_prefix} Initialization failed. Unable to read token URI in credentials. Verify your token URI."
INVALID_TOKEN_URI = f"{error_prefix} Initialization failed. Invalid Skyflow credentials. The token URI must be a string and a valid URL."
JWT_INVALID_FORMAT = f"{error_prefix} Initialization failed. Invalid private key format. Verify your credentials."
JWT_DECODE_ERROR = f"{error_prefix} Validation error. Invalid access token. Verify your credentials."
FILE_INVALID_JSON = f"{error_prefix} Initialization failed. File at {{}} is not in valid JSON format. Verify the file contents."
INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV = f"{error_prefix} Validation error. Invalid JSON format in SKYFLOW_CREDENTIALS environment variable."
FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token."
UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token."

INVALID_TEXT_IN_DEIDENTIFY= f"{error_prefix} Validation error. The text field is required and must be a non-empty string. Specify a valid text."
INVALID_ENTITIES_IN_DEIDENTIFY= f"{error_prefix} Validation error. The entities field must be an array of DetectEntities enums. Specify a valid entities."
Expand Down Expand Up @@ -332,6 +335,8 @@ class ErrorLogs(Enum):
KEY_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Key ID is required."
TOKEN_URI_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Token URI is required."
INVALID_TOKEN_URI = f"{ERROR}: [{error_prefix}] Invalid value for token URI in credentials."
FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token."
UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token."


TABLE_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Table is required."
Expand Down
19 changes: 12 additions & 7 deletions skyflow/utils/validations/_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \
GetDetectRunRequest, Bleep, DeidentifyFileRequest
from skyflow.vault.detect._file_input import FileInput
from skyflow.utils._helpers import is_valid_url

valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"]
valid_connection_config_keys = ["connection_id", "connection_url", "credentials"]
Expand Down Expand Up @@ -138,6 +139,15 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non
raise SkyflowError(SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id)
if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value,
invalid_input_error_code)

if "token_uri" in credentials:
token_uri = credentials.get("token_uri")
if (
token_uri is None
or not isinstance(token_uri, str)
or not is_valid_url(token_uri)
):
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code)

def validate_log_level(logger, log_level):
if not isinstance(log_level, LogLevel):
Expand Down Expand Up @@ -202,10 +212,8 @@ def validate_update_vault_config(logger, config):
if "env" in config and config.get("env") not in Env:
raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code)

if "credentials" not in config:
raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code)

validate_credentials(logger, config.get("credentials"), "vault", vault_id)
if "credentials" in config and config.get("credentials"):
validate_credentials(logger, config.get("credentials"), "vault", vault_id)

return True

Expand Down Expand Up @@ -413,9 +421,6 @@ def validate_insert_request(logger, request):
if key is None or key == "":
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format("INSERT"), logger = logger)

if value is None or value == "":
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format("INSERT", key), logger = logger)

if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()):
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value("INSERT"), logger = logger)
raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code)
Expand Down
5 changes: 4 additions & 1 deletion skyflow/vault/client/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from skyflow.error import SkyflowError
from skyflow.generated.rest.client import Skyflow
from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired
from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages
Expand Down Expand Up @@ -62,6 +63,8 @@ def get_bearer_token(self, credentials):
"role_ids": self.__config.get("roles"),
"ctx": self.__config.get("ctx")
}
if "token_uri" in credentials and credentials.get("token_uri"):
options["token_uri"] = credentials.get("token_uri")

if self.__bearer_token is None or self.__is_config_updated:
if 'path' in credentials:
Expand All @@ -85,7 +88,7 @@ def get_bearer_token(self, credentials):

if is_expired(self.__bearer_token):
self.__is_config_updated = True
raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
raise SkyflowError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)

return self.__bearer_token

Expand Down
71 changes: 70 additions & 1 deletion tests/service_account/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,73 @@ def test_generate_signed_data_tokens_from_creds_with_invalid_string(self):
credentials_string = '{'
with self.assertRaises(SkyflowError) as context:
result = generate_signed_data_tokens_from_creds(credentials_string, options)
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value)
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value)

@patch("skyflow.service_account._utils.AuthClient")
@patch("skyflow.service_account._utils.get_signed_jwt")
def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_signed_jwt, mock_auth_client):
    # role_ids passed in options must reach the auth API as a
    # space-separated "role:<id>" scope string.
    mock_get_signed_jwt.return_value = "signed"
    token_response = type("obj", (), {"access_token": "token",
                                      "token_type": "bearer"})
    auth_api = mock_auth_client.return_value.get_auth_api.return_value
    auth_api.authentication_service_get_auth_token.return_value = token_response

    credentials = {
        "privateKey": "private_key",
        "clientID": "client_id",
        "keyID": "key_id",
        "tokenURI": "https://valid-url.com",
    }
    request_options = {"role_ids": ["role1", "role2"]}

    access_token, token_type = get_service_account_token(credentials, request_options, None)

    self.assertEqual(access_token, "token")
    self.assertEqual(token_type, "bearer")
    _, call_kwargs = auth_api.authentication_service_get_auth_token.call_args
    self.assertIn("scope", call_kwargs)
    self.assertEqual(call_kwargs["scope"], "role:role1 role:role2")

@patch("skyflow.service_account._utils.AuthClient")
@patch("skyflow.service_account._utils.get_signed_jwt")
def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt, mock_auth_client):
    # An UnauthorizedError raised by the auth API must surface as a
    # SkyflowError carrying the dedicated "unauthorized" message.
    from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError

    mock_get_signed_jwt.return_value = "signed"
    auth_api = mock_auth_client.return_value.get_auth_api.return_value
    auth_api.authentication_service_get_auth_token.side_effect = UnauthorizedError("unauthorized")

    credentials = {
        "privateKey": "private_key",
        "clientID": "client_id",
        "keyID": "key_id",
        "tokenURI": "https://valid-url.com",
    }

    with self.assertRaises(SkyflowError) as raised:
        get_service_account_token(credentials, {}, None)

    self.assertEqual(raised.exception.message,
                     SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value)

@patch("skyflow.service_account._utils.AuthClient")
@patch("skyflow.service_account._utils.get_signed_jwt")
def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, mock_auth_client):
    # Any other exception from the auth API must be reported with the
    # generic "failed to get bearer token" message.
    mock_get_signed_jwt.return_value = "signed"
    auth_api = mock_auth_client.return_value.get_auth_api.return_value
    auth_api.authentication_service_get_auth_token.side_effect = Exception("some error")

    credentials = {
        "privateKey": "private_key",
        "clientID": "client_id",
        "keyID": "key_id",
        "tokenURI": "https://valid-url.com",
    }

    with self.assertRaises(SkyflowError) as raised:
        get_service_account_token(credentials, {}, None)

    self.assertEqual(raised.exception.message, SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value)

@patch("jwt.encode", side_effect=Exception("jwt error"))
def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode):
    # A failure inside jwt.encode must be reported as INVALID_CREDENTIALS.
    from skyflow.service_account._utils import get_signed_tokens

    credentials = {
        "privateKey": "private_key",
        "clientID": "client_id",
        "keyID": "key_id",
        "tokenURI": "https://valid-url.com",
    }
    signing_options = {"data_tokens": ["token1"]}

    with self.assertRaises(SkyflowError) as raised:
        get_signed_tokens(credentials, signing_options)

    self.assertEqual(raised.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value)
27 changes: 25 additions & 2 deletions tests/utils/test__helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from skyflow.utils import get_base_url, format_scope
from skyflow.utils import get_base_url, format_scope, is_valid_url

VALID_URL = "https://example.com/path?query=1"
BASE_URL = "https://example.com"
Expand Down Expand Up @@ -35,4 +35,27 @@ def test_format_scope_single_scope(self):
def test_format_scope_special_characters(self):
scopes_with_special_chars = ["admin", "user:write", "read-only"]
expected_result = "role:admin role:user:write role:read-only"
self.assertEqual(format_scope(scopes_with_special_chars), expected_result)
self.assertEqual(format_scope(scopes_with_special_chars), expected_result)

def test_is_valid_url_valid(self):
    # Absolute https and http URLs are accepted.
    for candidate in ("https://example.com", "http://example.com/path"):
        self.assertTrue(is_valid_url(candidate))

def test_is_valid_url_invalid(self):
    # A non-http(s) scheme, a missing scheme, free text and the empty
    # string are all rejected.
    rejected = ("ftp://example.com", "example.com", "invalid-url", "")
    for candidate in rejected:
        self.assertFalse(is_valid_url(candidate))

def test_is_valid_url_none(self):
    # None is reported as invalid.
    outcome = is_valid_url(None)
    self.assertFalse(outcome)

def test_is_valid_url_no_scheme(self):
    # A bare host name without a scheme is rejected.
    outcome = is_valid_url("www.example.com")
    self.assertFalse(outcome)

def test_is_valid_url_exception(self):
    # An arbitrary non-string object whose __str__ raises must be reported
    # as invalid.
    class Unstringable:
        def __str__(self):
            raise Exception("bad str")

    probe = Unstringable()
    self.assertFalse(is_valid_url(probe))
Loading
Loading