Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 145 additions & 78 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/url"
"os"
"strings"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
Expand All @@ -31,8 +32,30 @@ func (e *TokenExpiredError) Error() string {
return e.Message
}

type credentialType string

const (
credTypeClientSecret credentialType = "client_secret"
credTypeClientCertificate credentialType = "client_certificate"
credTypeCLI credentialType = "cli"
credTypeDevCLI credentialType = "dev_cli"
credTypeAzDOPipelines credentialType = "azdo_pipelines"
credTypeOIDC credentialType = "oidc"
credTypeUserManagedIdentity credentialType = "user_managed_identity"
credTypeSystemManagedIdentity credentialType = "system_managed_identity"
)

type credentialHolder struct {
credential azcore.TokenCredential
once sync.Once
err error
}

type Auth struct {
config *config.ProviderConfig

credentials map[credentialType]*credentialHolder
mu sync.RWMutex
}

type OidcCredential struct {
Expand All @@ -55,48 +78,83 @@ type OidcCredentialOptions struct {

func NewAuthBase(configValue *config.ProviderConfig) *Auth {
return &Auth{
config: configValue,
config: configValue,
credentials: make(map[credentialType]*credentialHolder),
}
}

func (client *Auth) AuthenticateClientCertificate(ctx context.Context, scopes []string) (string, time.Time, error) {
cert, key, err := helpers.ConvertBase64ToCert(client.config.ClientCertificateRaw, client.config.ClientCertificatePassword)
if err != nil {
return "", time.Time{}, err
func (client *Auth) getOrCreateCredential(ctx context.Context, credType credentialType, factory func() (azcore.TokenCredential, error)) (azcore.TokenCredential, error) {
client.mu.RLock()
holder, exists := client.credentials[credType]
client.mu.RUnlock()

if !exists {
client.mu.Lock()
holder, exists = client.credentials[credType]
if !exists {
holder = &credentialHolder{}
client.credentials[credType] = holder
tflog.Debug(ctx, fmt.Sprintf("Created credential holder for type: %s", credType))
}
client.mu.Unlock()
}

azureCertCredentials, err := azidentity.NewClientCertificateCredential(
client.config.TenantId,
client.config.ClientId,
cert,
key,
&azidentity.ClientCertificateCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
holder.once.Do(func() {
tflog.Debug(ctx, fmt.Sprintf("Initializing credential for type: %s", credType))
holder.credential, holder.err = factory()
if holder.err != nil {
tflog.Error(ctx, fmt.Sprintf("Failed to create credential for type %s: %s", credType, holder.err.Error()))
} else {
tflog.Debug(ctx, fmt.Sprintf("Successfully created credential for type: %s", credType))
}
})

return holder.credential, holder.err
}

func (client *Auth) AuthenticateClientCertificate(ctx context.Context, scopes []string) (string, time.Time, error) {
cred, err := client.getOrCreateCredential(ctx, credTypeClientCertificate, func() (azcore.TokenCredential, error) {
cert, key, certErr := helpers.ConvertBase64ToCert(client.config.ClientCertificateRaw, client.config.ClientCertificatePassword)
if certErr != nil {
return nil, certErr
}

return azidentity.NewClientCertificateCredential(
client.config.TenantId,
client.config.ClientId,
cert,
key,
&azidentity.ClientCertificateCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
},
},
)
)
})
if err != nil {
return "", time.Time{}, err
}
accessToken, err := azureCertCredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))

accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
return accessToken.Token, accessToken.ExpiresOn, nil
}

func (client *Auth) AuthenticateUsingCli(ctx context.Context, scopes []string) (string, time.Time, error) {
azureCLICredentials, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
cred, err := client.getOrCreateCredential(ctx, credTypeCLI, func() (azcore.TokenCredential, error) {
return azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := azureCLICredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -105,15 +163,17 @@ func (client *Auth) AuthenticateUsingCli(ctx context.Context, scopes []string) (
}

func (client *Auth) AuthenticateUsingAzureDeveloperCli(ctx context.Context, scopes []string) (string, time.Time, error) {
azureDeveloperCLICredentials, err := azidentity.NewAzureDeveloperCLICredential(&azidentity.AzureDeveloperCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
cred, err := client.getOrCreateCredential(ctx, credTypeDevCLI, func() (azcore.TokenCredential, error) {
return azidentity.NewAzureDeveloperCLICredential(&azidentity.AzureDeveloperCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := azureDeveloperCLICredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -122,21 +182,24 @@ func (client *Auth) AuthenticateUsingAzureDeveloperCli(ctx context.Context, scop
}

func (client *Auth) AuthenticateClientSecret(ctx context.Context, scopes []string) (string, time.Time, error) {
clientSecretCredential, err := azidentity.NewClientSecretCredential(
client.config.TenantId,
client.config.ClientId,
client.config.ClientSecret, &azidentity.ClientSecretCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
cred, err := client.getOrCreateCredential(ctx, credTypeClientSecret, func() (azcore.TokenCredential, error) {
return azidentity.NewClientSecretCredential(
client.config.TenantId,
client.config.ClientId,
client.config.ClientSecret,
&azidentity.ClientSecretCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
},
})
)
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := clientSecretCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))

accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand Down Expand Up @@ -182,32 +245,30 @@ func (w *OidcCredential) GetToken(ctx context.Context, opts policy.TokenRequestO
}

func (client *Auth) AuthenticateOIDC(ctx context.Context, scopes []string) (string, time.Time, error) {
var creds []azcore.TokenCredential

oidcCred, err := client.NewOidcCredential(&OidcCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
TenantID: client.config.TenantId,
ClientID: client.config.ClientId,
RequestToken: client.config.OidcRequestToken,
RequestUrl: client.config.OidcRequestUrl,
Token: client.config.OidcToken,
TokenFilePath: client.config.OidcTokenFilePath,
})
cred, err := client.getOrCreateCredential(ctx, credTypeOIDC, func() (azcore.TokenCredential, error) {
oidcCred, oidcErr := client.NewOidcCredential(&OidcCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
TenantID: client.config.TenantId,
ClientID: client.config.ClientId,
RequestToken: client.config.OidcRequestToken,
RequestUrl: client.config.OidcRequestUrl,
Token: client.config.OidcToken,
TokenFilePath: client.config.OidcTokenFilePath,
})
if oidcErr != nil {
return nil, oidcErr
}

return azidentity.NewChainedTokenCredential([]azcore.TokenCredential{oidcCred}, nil)
})
if err != nil {
tflog.Error(ctx, fmt.Sprintf("newDefaultAzureCredential failed to initialize oidc credential:\n\t%s", err.Error()))
return "", time.Time{}, err
}
creds = append(creds, oidcCred)

chain, err := azidentity.NewChainedTokenCredential(creds, nil)
if err != nil {
return "", time.Time{}, err
}

accessToken, err := chain.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -216,17 +277,19 @@ func (client *Auth) AuthenticateOIDC(ctx context.Context, scopes []string) (stri
}

func (client *Auth) AuthenticateUserManagedIdentity(ctx context.Context, scopes []string) (string, time.Time, error) {
userManagedIdentityCredential, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(client.config.ClientId),
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
cred, err := client.getOrCreateCredential(ctx, credTypeUserManagedIdentity, func() (azcore.TokenCredential, error) {
return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(client.config.ClientId),
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := userManagedIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -235,16 +298,18 @@ func (client *Auth) AuthenticateUserManagedIdentity(ctx context.Context, scopes
}

func (client *Auth) AuthenticateSystemManagedIdentity(ctx context.Context, scopes []string) (string, time.Time, error) {
systemManagedIdentityCredential, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
cred, err := client.getOrCreateCredential(ctx, credTypeSystemManagedIdentity, func() (azcore.TokenCredential, error) {
return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := systemManagedIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -266,23 +331,25 @@ func (client *Auth) AuthenticateAzDOWorkloadIdentityFederation(ctx context.Conte
return "", time.Time{}, errors.New("could not obtain an OIDC request token for Azure DevOps Workload Identity Federation")
}

azdoWorkloadIdentityCredential, err := azidentity.NewAzurePipelinesCredential(
client.config.TenantId,
client.config.ClientId,
client.config.AzDOServiceConnectionID,
client.config.OidcRequestToken,
&azidentity.AzurePipelinesCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
cred, err := client.getOrCreateCredential(ctx, credTypeAzDOPipelines, func() (azcore.TokenCredential, error) {
return azidentity.NewAzurePipelinesCredential(
client.config.TenantId,
client.config.ClientId,
client.config.AzDOServiceConnectionID,
client.config.OidcRequestToken,
&azidentity.AzurePipelinesCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
},
},
)
)
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := azdoWorkloadIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand Down
Loading
Loading