Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions skyflow/client/skyflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from skyflow.error import SkyflowError
from skyflow.utils import SkyflowMessages
from skyflow.utils.logger import log_info, Logger
from skyflow.utils.constants import OptionField
from skyflow.utils.validations import validate_vault_config, validate_connection_config, validate_update_vault_config, \
validate_update_connection_config, validate_credentials, validate_log_level
from skyflow.vault.client.client import VaultClient
Expand Down Expand Up @@ -30,7 +31,7 @@ def update_vault_config(self,config):
self.__builder.update_vault_config(config)

def get_vault_config(self, vault_id):
return self.__builder.get_vault_config(vault_id).get("vault_client").get_config()
return self.__builder.get_vault_config(vault_id).get(OptionField.VAULT_CLIENT).get_config()

def add_connection_config(self, config):
self.__builder._Builder__add_connection_config(config)
Expand All @@ -45,7 +46,7 @@ def update_connection_config(self, config):
return self

def get_connection_config(self, connection_id):
return self.__builder.get_connection_config(connection_id).get("vault_client").get_config()
return self.__builder.get_connection_config(connection_id).get(OptionField.VAULT_CLIENT).get_config()

def add_skyflow_credentials(self, credentials):
self.__builder._Builder__add_skyflow_credentials(credentials)
Expand All @@ -66,15 +67,15 @@ def update_log_level(self, log_level):

def vault(self, vault_id = None) -> Vault:
vault_config = self.__builder.get_vault_config(vault_id)
return vault_config.get("vault_controller")
return vault_config.get(OptionField.VAULT_CONTROLLER)

def connection(self, connection_id = None) -> Connection:
connection_config = self.__builder.get_connection_config(connection_id)
return connection_config.get("controller")
return connection_config.get(OptionField.CONTROLLER)

def detect(self, vault_id = None) -> Detect:
vault_config = self.__builder.get_vault_config(vault_id)
return vault_config.get("detect_controller")
return vault_config.get(OptionField.DETECT_CONTROLLER)

class Builder:
def __init__(self):
Expand All @@ -87,13 +88,13 @@ def __init__(self):
self.__logger = Logger(LogLevel.ERROR)

def add_vault_config(self, config):
vault_id = config.get("vault_id")
vault_id = config.get(OptionField.VAULT_ID)
if not isinstance(vault_id, str) or not vault_id:
raise SkyflowError(
SkyflowMessages.Error.INVALID_VAULT_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value
)
if vault_id in [vault.get("vault_id") for vault in self.__vault_list]:
if vault_id in [vault.get(OptionField.VAULT_ID) for vault in self.__vault_list]:
log_info(SkyflowMessages.Info.VAULT_CONFIG_EXISTS.value.format(vault_id), self.__logger)
raise SkyflowError(
SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(vault_id),
Expand All @@ -112,9 +113,9 @@ def remove_vault_config(self, vault_id):

def update_vault_config(self, config):
validate_update_vault_config(self.__logger, config)
vault_id = config.get("vault_id")
vault_id = config.get(OptionField.VAULT_ID)
vault_config = self.__vault_configs[vault_id]
vault_config.get("vault_client").update_config(config)
vault_config.get(OptionField.VAULT_CLIENT).update_config(config)

def get_vault_config(self, vault_id):
if vault_id is None:
Expand All @@ -129,13 +130,13 @@ def get_vault_config(self, vault_id):


def add_connection_config(self, config):
connection_id = config.get("connection_id")
connection_id = config.get(OptionField.CONNECTION_ID)
if not isinstance(connection_id, str) or not connection_id:
raise SkyflowError(
SkyflowMessages.Error.INVALID_CONNECTION_ID.value,
SkyflowMessages.ErrorCodes.INVALID_INPUT.value
)
if connection_id in [connection.get("connection_id") for connection in self.__connection_list]:
if connection_id in [connection.get(OptionField.CONNECTION_ID) for connection in self.__connection_list]:
log_info(SkyflowMessages.Info.CONNECTION_CONFIG_EXISTS.value.format(connection_id), self.__logger)
raise SkyflowError(
SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id),
Expand All @@ -153,9 +154,9 @@ def remove_connection_config(self, connection_id):

def update_connection_config(self, config):
validate_update_connection_config(self.__logger, config)
connection_id = config['connection_id']
connection_id = config[OptionField.CONNECTION_ID]
connection_config = self.__connection_configs[connection_id]
connection_config.get("vault_client").update_config(config)
connection_config.get(OptionField.VAULT_CLIENT).update_config(config)

def get_connection_config(self, connection_id):
if connection_id is None:
Expand Down Expand Up @@ -183,32 +184,32 @@ def get_logger(self):

def __add_vault_config(self, config):
validate_vault_config(self.__logger, config)
vault_id = config.get("vault_id")
vault_id = config.get(OptionField.VAULT_ID)
vault_client = VaultClient(config)
self.__vault_configs[vault_id] = {
"vault_client": vault_client,
"vault_controller": Vault(vault_client),
"detect_controller": Detect(vault_client)
OptionField.VAULT_CLIENT: vault_client,
OptionField.VAULT_CONTROLLER: Vault(vault_client),
OptionField.DETECT_CONTROLLER: Detect(vault_client)
}
log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger)
log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger)
log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger)
log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger)

def __add_connection_config(self, config):
validate_connection_config(self.__logger, config)
connection_id = config.get("connection_id")
connection_id = config.get(OptionField.CONNECTION_ID)
vault_client = VaultClient(config)
self.__connection_configs[connection_id] = {
"vault_client": vault_client,
"controller": Connection(vault_client)
OptionField.VAULT_CLIENT: vault_client,
OptionField.CONTROLLER: Connection(vault_client)
}
log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get("connection_id")), self.__logger)
log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.CONNECTION_ID)), self.__logger)

def __update_vault_client_logger(self, log_level, logger):
for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_logger(log_level,logger)
vault_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger)

for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_logger(log_level,logger)
connection_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger)

def __set_log_level(self, log_level):
validate_log_level(self.__logger, log_level)
Expand All @@ -223,10 +224,10 @@ def __add_skyflow_credentials(self, credentials):
self.__skyflow_credentials = credentials
validate_credentials(self.__logger, credentials)
for vault_id, vault_config in self.__vault_configs.items():
vault_config.get("vault_client").set_common_skyflow_credentials(credentials)
vault_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(credentials)

for connection_id, connection_config in self.__connection_configs.items():
connection_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials)
connection_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(self.__skyflow_credentials)
def build(self):
validate_log_level(self.__logger, self.__log_level)
self.__logger.set_log_level(self.__log_level)
Expand Down
69 changes: 35 additions & 34 deletions skyflow/service_account/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from skyflow.service_account.client.auth_client import AuthClient
from skyflow.utils.logger import log_info, log_error_log
from skyflow.utils import get_base_url, format_scope, SkyflowMessages
from skyflow.utils.constants import JWT, CredentialField, JwtField, OptionField, ResponseField


invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value
Expand All @@ -17,8 +18,8 @@

try:
decoded = jwt.decode(
token, options={"verify_signature": False, "verify_aud": False})
if time.time() >= decoded['exp']:
token, options={OptionField.VERIFY_SIGNATURE: False, OptionField.VERIFY_AUD: False})
if time.time() >= decoded[JwtField.EXP]:
log_info(SkyflowMessages.Info.BEARER_TOKEN_EXPIRED.value, logger)
log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value)
return True
Expand Down Expand Up @@ -59,22 +60,22 @@

def get_service_account_token(credentials, options, logger):
try:
private_key = credentials["privateKey"]
private_key = credentials[CredentialField.PRIVATE_KEY]
except:
log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger = logger)
raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code)
try:
client_id = credentials["clientID"]
client_id = credentials[CredentialField.CLIENT_ID]
except:
log_error_log(SkyflowMessages.ErrorLogs.CLIENT_ID_IS_REQUIRED.value, logger=logger)
raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code)
try:
key_id = credentials["keyID"]
key_id = credentials[CredentialField.KEY_ID]
except:
log_error_log(SkyflowMessages.ErrorLogs.KEY_ID_IS_REQUIRED.value, logger=logger)
raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code)
try:
token_uri = credentials["tokenURI"]
token_uri = credentials[CredentialField.TOKEN_URI]
except:
log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger)
raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code)
Expand All @@ -85,53 +86,53 @@
auth_api = auth_client.get_auth_api()

formatted_scope = None
if options and "role_ids" in options:
formatted_scope = format_scope(options.get("role_ids"))
if options and OptionField.ROLE_IDS in options:
formatted_scope = format_scope(options.get(OptionField.ROLE_IDS))

response = auth_api.authentication_service_get_auth_token(assertion = signed_token,
grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer",
grant_type=JWT.GRANT_TYPE_JWT_BEARER,
scope=formatted_scope)
log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger)
return response.access_token, response.token_type

def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger):
payload = {
"iss": client_id,
"key": key_id,
"aud": token_uri,
"sub": client_id,
"exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=60)
JwtField.ISS: client_id,
JwtField.KEY: key_id,
JwtField.AUD: token_uri,
JwtField.SUB: client_id,
JwtField.EXP: datetime.datetime.utcnow() + datetime.timedelta(minutes=60)
}
if options and "ctx" in options:
payload["ctx"] = options.get("ctx")
if options and JwtField.CTX in options:
payload[JwtField.CTX] = options.get(JwtField.CTX)
try:
return jwt.encode(payload=payload, key=private_key, algorithm="RS256")
return jwt.encode(payload=payload, key=private_key, algorithm=JWT.ALGORITHM_RS256)
except Exception:
raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code)



def get_signed_tokens(credentials_obj, options):
try:
expiry_time = int(time.time()) + options.get("time_to_live", 60)
prefix = "signed_token_"
expiry_time = int(time.time()) + options.get(OptionField.TIME_TO_LIVE, 60)
prefix = JWT.SIGNED_TOKEN_PREFIX

if options and options.get("data_tokens"):
for token in options["data_tokens"]:
if options and options.get(OptionField.DATA_TOKENS):
for token in options[OptionField.DATA_TOKENS]:
claims = {
"iss": "sdk",
"key": credentials_obj.get("keyID"),
"exp": expiry_time,
"sub": credentials_obj.get("clientID"),
"tok": token,
"iat": int(time.time()),
JwtField.ISS: JWT.ISSUER_SDK,
JwtField.KEY: credentials_obj.get(CredentialField.KEY_ID),
JwtField.EXP: expiry_time,
JwtField.SUB: credentials_obj.get(CredentialField.CLIENT_ID),
JwtField.TOK: token,
JwtField.IAT: int(time.time()),
}

if "ctx" in options:
claims["ctx"] = options["ctx"]
if JwtField.CTX in options:
claims[JwtField.CTX] = options[JwtField.CTX]

private_key = credentials_obj.get("privateKey")
signed_jwt = jwt.encode(claims, private_key, algorithm="RS256")
private_key = credentials_obj.get(CredentialField.PRIVATE_KEY)
signed_jwt = jwt.encode(claims, private_key, algorithm=JWT.ALGORITHM_RS256)
response_object = get_signed_data_token_response_object(prefix + signed_jwt, token)
log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value)
return response_object
Expand Down Expand Up @@ -170,7 +171,7 @@

def get_signed_data_token_response_object(signed_token, actual_token):
response_object = {
"token": actual_token,
"signed_token": signed_token
ResponseField.TOKEN: actual_token,

Check failure

Code scanning / Semgrep OSS

Semgrep Finding: semgreprules.check-sensitive-info Error

Potential sensitive information found: TOKEN
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skyflow data token (non-sensitive, not an auth token)

ResponseField.SIGNED_TOKEN: signed_token
}
return response_object.get("token"), response_object.get("signed_token")
return response_object.get(ResponseField.TOKEN), response_object.get(ResponseField.SIGNED_TOKEN)
3 changes: 3 additions & 0 deletions skyflow/utils/_skyflow_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ class Error(Enum):
RESPONSE_NOT_JSON = f"{error_prefix} Response {{}} is not valid JSON."
API_ERROR = f"{error_prefix} Server returned status code {{}}"

INVALID_JSON_RESPONSE = f"{error_prefix} Invalid JSON response received."
UNKNOWN_ERROR_DEFAULT_MESSAGE = f"{error_prefix} An unknown error occurred."

INVALID_FILE_INPUT = f"{error_prefix} Validation error. Invalid file input. Specify a valid file input."
INVALID_DETECT_ENTITIES_TYPE = f"{error_prefix} Validation error. Invalid type of detect entities. Specify detect entities as list of DetectEntities enum."
INVALID_TYPE_FOR_DEFAULT_TOKEN_TYPE = f"{error_prefix} Validation error. Invalid type of default token type. Specify default token type as TokenType enum."
Expand Down
Loading
Loading