Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 194 additions & 78 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/url"
"os"
"strings"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
Expand All @@ -31,8 +32,38 @@ func (e *TokenExpiredError) Error() string {
return e.Message
}

type credentialType string

const (
credTypeClientSecret credentialType = "client_secret"
credTypeClientCertificate credentialType = "client_certificate"
credTypeCLI credentialType = "cli"
credTypeDevCLI credentialType = "dev_cli"
credTypeAzDOPipelines credentialType = "azdo_pipelines"
credTypeOIDC credentialType = "oidc"
credTypeUserManagedIdentity credentialType = "user_managed_identity"
credTypeSystemManagedIdentity credentialType = "system_managed_identity"
)

type credentialHolder struct {
credential azcore.TokenCredential
once sync.Once
err error
}

type cachedToken struct {
token string
expiresOn time.Time
}

type Auth struct {
config *config.ProviderConfig

credentials map[credentialType]*credentialHolder
credentialsMutex sync.RWMutex

cliTokens map[string]*cachedToken
cliTokensMutex sync.RWMutex
}

type OidcCredential struct {
Expand All @@ -55,88 +86,169 @@ type OidcCredentialOptions struct {

func NewAuthBase(configValue *config.ProviderConfig) *Auth {
return &Auth{
config: configValue,
config: configValue,
credentials: make(map[credentialType]*credentialHolder),
cliTokens: make(map[string]*cachedToken),
}
}

func (client *Auth) AuthenticateClientCertificate(ctx context.Context, scopes []string) (string, time.Time, error) {
cert, key, err := helpers.ConvertBase64ToCert(client.config.ClientCertificateRaw, client.config.ClientCertificatePassword)
if err != nil {
return "", time.Time{}, err
func (client *Auth) getOrCreateCredential(ctx context.Context, credType credentialType, factory func() (azcore.TokenCredential, error)) (azcore.TokenCredential, error) {
client.credentialsMutex.RLock()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about token expiry? When do we invalidate the cache?

holder, exists := client.credentials[credType]
client.credentialsMutex.RUnlock()

if !exists {
client.credentialsMutex.Lock()
holder, exists = client.credentials[credType]
if !exists {
holder = &credentialHolder{}
client.credentials[credType] = holder
tflog.Debug(ctx, fmt.Sprintf("Created credential holder for type: %s", credType))
}
client.credentialsMutex.Unlock()
}

azureCertCredentials, err := azidentity.NewClientCertificateCredential(
client.config.TenantId,
client.config.ClientId,
cert,
key,
&azidentity.ClientCertificateCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
holder.once.Do(func() {
tflog.Debug(ctx, fmt.Sprintf("Initializing credential for type: %s", credType))
holder.credential, holder.err = factory()
if holder.err != nil {
tflog.Error(ctx, fmt.Sprintf("Failed to create credential for type %s: %s", credType, holder.err.Error()))
} else {
tflog.Debug(ctx, fmt.Sprintf("Successfully created credential for type: %s", credType))
}
})

return holder.credential, holder.err
}

func (client *Auth) getCachedCliToken(cacheKey string) (string, time.Time, bool) {
client.cliTokensMutex.RLock()
defer client.cliTokensMutex.RUnlock()

cached, exists := client.cliTokens[cacheKey]
if !exists {
return "", time.Time{}, false
}

if time.Now().Add(5 * time.Minute).Before(cached.expiresOn) {
return cached.token, cached.expiresOn, true
}

return "", time.Time{}, false
}

func (client *Auth) setCachedCliToken(cacheKey string, token string, expiresOn time.Time) {
client.cliTokensMutex.Lock()
defer client.cliTokensMutex.Unlock()

client.cliTokens[cacheKey] = &cachedToken{
token: token,
expiresOn: expiresOn,
}
}

func (client *Auth) AuthenticateClientCertificate(ctx context.Context, scopes []string) (string, time.Time, error) {
cred, err := client.getOrCreateCredential(ctx, credTypeClientCertificate, func() (azcore.TokenCredential, error) {
cert, key, certErr := helpers.ConvertBase64ToCert(client.config.ClientCertificateRaw, client.config.ClientCertificatePassword)
if certErr != nil {
return nil, certErr
}

return azidentity.NewClientCertificateCredential(
client.config.TenantId,
client.config.ClientId,
cert,
key,
&azidentity.ClientCertificateCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
},
},
)
)
})
if err != nil {
return "", time.Time{}, err
}
accessToken, err := azureCertCredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))

accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
return accessToken.Token, accessToken.ExpiresOn, nil
}

func (client *Auth) AuthenticateUsingCli(ctx context.Context, scopes []string) (string, time.Time, error) {
azureCLICredentials, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
cacheKey := "cli:" + strings.Join(scopes, ",")
if token, expiresOn, found := client.getCachedCliToken(cacheKey); found {
tflog.Debug(ctx, "Using cached token for Azure CLI credential")
return token, expiresOn, nil
}

cred, err := client.getOrCreateCredential(ctx, credTypeCLI, func() (azcore.TokenCredential, error) {
return azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := azureCLICredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}

client.setCachedCliToken(cacheKey, accessToken.Token, accessToken.ExpiresOn)
return accessToken.Token, accessToken.ExpiresOn, nil
}

func (client *Auth) AuthenticateUsingAzureDeveloperCli(ctx context.Context, scopes []string) (string, time.Time, error) {
azureDeveloperCLICredentials, err := azidentity.NewAzureDeveloperCLICredential(&azidentity.AzureDeveloperCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
cacheKey := "devcli:" + strings.Join(scopes, ",")
if token, expiresOn, found := client.getCachedCliToken(cacheKey); found {
tflog.Debug(ctx, "Using cached token for Azure Developer CLI credential")
return token, expiresOn, nil
}

cred, err := client.getOrCreateCredential(ctx, credTypeDevCLI, func() (azcore.TokenCredential, error) {
return azidentity.NewAzureDeveloperCLICredential(&azidentity.AzureDeveloperCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := azureDeveloperCLICredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}

client.setCachedCliToken(cacheKey, accessToken.Token, accessToken.ExpiresOn)
return accessToken.Token, accessToken.ExpiresOn, nil
}

func (client *Auth) AuthenticateClientSecret(ctx context.Context, scopes []string) (string, time.Time, error) {
clientSecretCredential, err := azidentity.NewClientSecretCredential(
client.config.TenantId,
client.config.ClientId,
client.config.ClientSecret, &azidentity.ClientSecretCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
cred, err := client.getOrCreateCredential(ctx, credTypeClientSecret, func() (azcore.TokenCredential, error) {
return azidentity.NewClientSecretCredential(
client.config.TenantId,
client.config.ClientId,
client.config.ClientSecret,
&azidentity.ClientSecretCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
},
})
)
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := clientSecretCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))

accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand Down Expand Up @@ -182,32 +294,30 @@ func (w *OidcCredential) GetToken(ctx context.Context, opts policy.TokenRequestO
}

func (client *Auth) AuthenticateOIDC(ctx context.Context, scopes []string) (string, time.Time, error) {
var creds []azcore.TokenCredential

oidcCred, err := client.NewOidcCredential(&OidcCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
TenantID: client.config.TenantId,
ClientID: client.config.ClientId,
RequestToken: client.config.OidcRequestToken,
RequestUrl: client.config.OidcRequestUrl,
Token: client.config.OidcToken,
TokenFilePath: client.config.OidcTokenFilePath,
})
cred, err := client.getOrCreateCredential(ctx, credTypeOIDC, func() (azcore.TokenCredential, error) {
oidcCred, oidcErr := client.NewOidcCredential(&OidcCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
TenantID: client.config.TenantId,
ClientID: client.config.ClientId,
RequestToken: client.config.OidcRequestToken,
RequestUrl: client.config.OidcRequestUrl,
Token: client.config.OidcToken,
TokenFilePath: client.config.OidcTokenFilePath,
})
if oidcErr != nil {
return nil, oidcErr
}

return azidentity.NewChainedTokenCredential([]azcore.TokenCredential{oidcCred}, nil)
})
if err != nil {
tflog.Error(ctx, fmt.Sprintf("newDefaultAzureCredential failed to initialize oidc credential:\n\t%s", err.Error()))
return "", time.Time{}, err
}
creds = append(creds, oidcCred)

chain, err := azidentity.NewChainedTokenCredential(creds, nil)
if err != nil {
return "", time.Time{}, err
}

accessToken, err := chain.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -216,17 +326,19 @@ func (client *Auth) AuthenticateOIDC(ctx context.Context, scopes []string) (stri
}

func (client *Auth) AuthenticateUserManagedIdentity(ctx context.Context, scopes []string) (string, time.Time, error) {
userManagedIdentityCredential, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(client.config.ClientId),
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
cred, err := client.getOrCreateCredential(ctx, credTypeUserManagedIdentity, func() (azcore.TokenCredential, error) {
return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(client.config.ClientId),
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := userManagedIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -235,16 +347,18 @@ func (client *Auth) AuthenticateUserManagedIdentity(ctx context.Context, scopes
}

func (client *Auth) AuthenticateSystemManagedIdentity(ctx context.Context, scopes []string) (string, time.Time, error) {
systemManagedIdentityCredential, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
cred, err := client.getOrCreateCredential(ctx, credTypeSystemManagedIdentity, func() (azcore.TokenCredential, error) {
return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := systemManagedIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -266,23 +380,25 @@ func (client *Auth) AuthenticateAzDOWorkloadIdentityFederation(ctx context.Conte
return "", time.Time{}, errors.New("could not obtain an OIDC request token for Azure DevOps Workload Identity Federation")
}

azdoWorkloadIdentityCredential, err := azidentity.NewAzurePipelinesCredential(
client.config.TenantId,
client.config.ClientId,
client.config.AzDOServiceConnectionID,
client.config.OidcRequestToken,
&azidentity.AzurePipelinesCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
cred, err := client.getOrCreateCredential(ctx, credTypeAzDOPipelines, func() (azcore.TokenCredential, error) {
return azidentity.NewAzurePipelinesCredential(
client.config.TenantId,
client.config.ClientId,
client.config.AzDOServiceConnectionID,
client.config.OidcRequestToken,
&azidentity.AzurePipelinesCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
},
},
)
)
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := azdoWorkloadIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand Down
Loading
Loading