diff --git a/src/VirtualClient/VirtualClient.Actions.FunctionalTests/GetAccessTokenProfileTests.cs b/src/VirtualClient/VirtualClient.Actions.FunctionalTests/GetAccessTokenProfileTests.cs new file mode 100644 index 0000000000..13b099c53a --- /dev/null +++ b/src/VirtualClient/VirtualClient.Actions.FunctionalTests/GetAccessTokenProfileTests.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace VirtualClient.Actions +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Threading; + using System.Threading.Tasks; + using Moq; + using NUnit.Framework; + using VirtualClient.Common; + using VirtualClient.Contracts; + + [TestFixture] + [Category("Functional")] + public class GetAccessTokenProfileTests + { + private DependencyFixture dependencyFixture; + + [OneTimeSetUp] + public void SetupFixture() + { + this.dependencyFixture = new DependencyFixture(); + ComponentTypeCache.Instance.LoadComponentTypes(TestDependencies.TestDirectory); + } + + [Test] + [TestCase("GET-ACCESS-TOKEN.json", PlatformID.Unix)] + [TestCase("GET-ACCESS-TOKEN.json", PlatformID.Win32NT)] + public void GetAccessTokenProfileParametersAreInlinedCorrectly(string profile, PlatformID platform) + { + this.dependencyFixture.Setup(platform); + using (ProfileExecutor executor = TestDependencies.CreateProfileExecutor(profile, this.dependencyFixture.Dependencies)) + { + WorkloadAssert.ParameterReferencesInlined(executor.Profile); + } + } + + [Test] + [TestCase("GET-ACCESS-TOKEN.json", PlatformID.Unix)] + [TestCase("GET-ACCESS-TOKEN.json", PlatformID.Win32NT)] + public void GetAccessTokenProfileParametersAreAvailable(string profile, PlatformID platform) + { + this.dependencyFixture.Setup(platform); + + var mandatoryParameters = new List { "KeyVaultUri", "TenantId" }; + using (ProfileExecutor executor = TestDependencies.CreateProfileExecutor(profile, this.dependencyFixture.Dependencies)) + { + Assert.IsEmpty(executor.Profile.Actions); + Assert.AreEqual(1, executor.Profile.Dependencies.Count); + + var dependencyBlock = executor.Profile.Dependencies.FirstOrDefault(); + + foreach (var parameters in mandatoryParameters) + { + Assert.IsTrue(dependencyBlock.Parameters.ContainsKey(parameters)); + } + } + } + } +} \ No newline at end of file diff --git a/src/VirtualClient/VirtualClient.Contracts/DependencyKeyVaultStore.cs b/src/VirtualClient/VirtualClient.Contracts/DependencyKeyVaultStore.cs index 3f20d96445..959431724d 100644 --- a/src/VirtualClient/VirtualClient.Contracts/DependencyKeyVaultStore.cs +++ b/src/VirtualClient/VirtualClient.Contracts/DependencyKeyVaultStore.cs @@ -53,4 +53,4 @@ public DependencyKeyVaultStore(string storeName, Uri endpointUri, TokenCredentia /// public TokenCredential Credentials { get; } } -} +} \ No newline at end of file diff --git a/src/VirtualClient/VirtualClient.Core.UnitTests/EndpointUtilityTests.cs b/src/VirtualClient/VirtualClient.Core.UnitTests/EndpointUtilityTests.cs index 960f157e7d..34fd05b29f 100644 --- a/src/VirtualClient/VirtualClient.Core.UnitTests/EndpointUtilityTests.cs +++ b/src/VirtualClient/VirtualClient.Core.UnitTests/EndpointUtilityTests.cs @@ -188,7 +188,7 @@ public void EndpointUtilityThrowsWhenCreatingBlobStoreReferenceForCDNUriIfUriIsV "https://anystorage.blob.core.windows.net/")] // [TestCase( - "https://anystorage.blob.core.windows.net?sv=2022-11-02&ss=b&srt=co&sp=rtf&se=2024-07-02T05:15:29Z&st=2024-07-01T21:15:29Z&spr=https", + "https://anystorage.blob.core.windows.net?sv=2022-11-02&ss=b&srt=co&sp=rtf&se=2024-07-02T05:15:29Z&st=2024-07-01T21:15:29Z&spr=https", "https://anystorage.blob.core.windows.net/?sv=2022-11-02&ss=b&srt=co&sp=rtf&se=2024-07-02T05:15:29Z&st=2024-07-01T21:15:29Z&spr=https")] // [TestCase( @@ -230,7 +230,7 @@ public void EndpointUtilityCreatesTheExpectedBlobStoreReferenceForConnectionStri [Test] [TestCase("https://any.service.azure.com?miid=307591a4-abb2-4559-af59-b47177d140cf", "https://any.service.azure.com")] - [TestCase("https://any.service.azure.com/?miid=307591a4-abb2-4559-af59-b47177d140cf","https://any.service.azure.com/")] + [TestCase("https://any.service.azure.com/?miid=307591a4-abb2-4559-af59-b47177d140cf", "https://any.service.azure.com/")] public void EndpointUtilityCreatesTheExpectedBlobStoreReferenceForUrisReferencingManagedIdentities(string uri, string expectedUri) { DependencyBlobStore store = EndpointUtility.CreateBlobStoreReference( @@ -338,7 +338,7 @@ public void EndpointUtilityCreatesTheExpectedBlobStoreReferenceForConnectionStri Assert.IsNotNull(store.Credentials); Assert.IsInstanceOf(store.Credentials); } - + [Test] [TestCase("https://any.service.azure.com/?cid=307591a4-abb2-4559-af59-b47177d140cf&tid=985bbc17-e3a5-4fec-b0cb-40dbb8bc5959&crti=ABC&crts=any.domain.com", "https://any.service.azure.com/")] [TestCase("https://any.service.azure.com/?cid=307591a4-abb2-4559-af59-b47177d140cf&tid=985bbc17-e3a5-4fec-b0cb-40dbb8bc5959&crti=ABC CA 01&crts=any.domain.com", "https://any.service.azure.com/")] @@ -854,5 +854,37 @@ public void CreateKeyVaultStoreReference_ConnectionString_ThrowsOnInvalid() "InvalidConnectionString", this.mockFixture.CertificateManager.Object)); } + + [Test] + [TestCase("https://anyvault.vault.azure.net/?cid=123456&tid=654321")] + [TestCase("https://anycontentstorage.blob.core.windows.net?cid=123456&tid=654321")] + [TestCase("https://anypackagestorage.blob.core.windows.net?tid=654321")] + [TestCase("https://anynamespace.servicebus.windows.net?cid=123456&tid=654321")] + [TestCase("https://my-keyvault.vault.azure.net/;tid=654321")] + public void TryParseMicrosoftEntraTenantIdReference_Uri_WorksAsExpected(string input) + { + // Arrange + Uri uri = new Uri(input); + bool result = EndpointUtility.TryParseMicrosoftEntraTenantIdReference(uri, out string actualTenantId); + + // Assert + Assert.True(result); + Assert.AreEqual("654321", actualTenantId); + } + + [Test] + [TestCase("https://anycontentstorage.blob.core.windows.net?cid=123456&tenantId=654321")] + [TestCase("https://anypackagestorage.blob.core.windows.net?miid=654321")] + [TestCase("https://my-keyvault.vault.azure.net/;cid=654321")] + public void TryParseMicrosoftEntraTenantIdReference_Uri_ReturnFalseWhenInvalid(string input) + { + // Arrange + Uri uri = new Uri(input); + bool result = EndpointUtility.TryParseMicrosoftEntraTenantIdReference(uri, out string actualTenantId); + + // Assert + Assert.IsFalse(result); + Assert.IsNull(actualTenantId); + } } } diff --git a/src/VirtualClient/VirtualClient.Core.UnitTests/KeyVaultManagerTests.cs b/src/VirtualClient/VirtualClient.Core.UnitTests/KeyVaultManagerTests.cs index c3c6c6ac40..91aaf29636 100644 --- a/src/VirtualClient/VirtualClient.Core.UnitTests/KeyVaultManagerTests.cs +++ b/src/VirtualClient/VirtualClient.Core.UnitTests/KeyVaultManagerTests.cs @@ -115,20 +115,13 @@ public async Task KeyVaultManagerReturnsExpectedKey() } [Test] - [TestCase(true)] - [TestCase(false)] - public async Task KeyVaultManagerReturnsExpectedCertificate(bool retrieveWithPrivateKey) + [TestCase(PlatformID.Unix)] + [TestCase(PlatformID.Win32NT)] + public async Task KeyVaultManagerReturnsExpectedCertificate(PlatformID platform) { - var result = await this.keyVaultManager.GetCertificateAsync("mycert", CancellationToken.None, "https://myvault.vault.azure.net/", retrieveWithPrivateKey); + var result = await this.keyVaultManager.GetCertificateAsync(platform, "mycert", CancellationToken.None, "https://myvault.vault.azure.net/"); Assert.IsNotNull(result); - if (retrieveWithPrivateKey) - { - Assert.IsTrue(result.HasPrivateKey); - } - else - { - Assert.IsFalse(result.HasPrivateKey); - } + Assert.IsTrue(result.HasPrivateKey); } [Test] diff --git a/src/VirtualClient/VirtualClient.Core/EndpointUtility.cs b/src/VirtualClient/VirtualClient.Core/EndpointUtility.cs index 9294acffa1..dff6394fb1 100644 --- a/src/VirtualClient/VirtualClient.Core/EndpointUtility.cs +++ b/src/VirtualClient/VirtualClient.Core/EndpointUtility.cs @@ -398,6 +398,26 @@ public static bool TryParseCertificateReference(Uri uri, out string issuer, out return TryGetCertificateReferenceForUri(queryParameters, out issuer, out subject); } + /// + /// Tries to parse the Microsoft Entra reference information from the provided uri. If the uri does not contain the correctly formatted client ID + /// and tenant ID information the method will return false, and keep the two out parameters as null. + /// Ex. https://anystore.blob.core.windows.net?cid={clientId};tid={tenantId} + /// + /// The uri to attempt to parse the values from. + /// The tenant ID from the Microsoft Entra reference. + /// True/False if the method was able to successfully parse both the client ID and the tenant ID from the Microsoft Entra reference. + public static bool TryParseMicrosoftEntraTenantIdReference(Uri uri, out string tenantId) + { + string queryString = Uri.UnescapeDataString(uri.Query).Trim('?').Replace("&", ",,,"); + + IDictionary queryParameters = TextParsingExtensions.ParseDelimitedValues(queryString)?.ToDictionary( + entry => entry.Key, + entry => entry.Value?.ToString(), + StringComparer.OrdinalIgnoreCase); + + return TryGetMicrosoftEntraTenantId(queryParameters, out tenantId); + } + /// /// Returns the endpoint by verifying package uri checks. /// if the endpoint is a package uri without http or https protocols then append the protocol else return the endpoint value. @@ -1292,5 +1312,23 @@ private static bool TryGetMicrosoftEntraReferenceForUri(IDictionary uriParameters, out string tenantId) + { + bool parametersDefined = false; + tenantId = null; + + if (uriParameters?.Any() == true) + { + if (uriParameters.TryGetValue(UriParameter.TenantId, out string microsoftEntraTenantId) + && !string.IsNullOrWhiteSpace(microsoftEntraTenantId)) + { + tenantId = microsoftEntraTenantId; + parametersDefined = true; + } + } + + return parametersDefined; + } } } diff --git a/src/VirtualClient/VirtualClient.Core/IKeyVaultManager.cs b/src/VirtualClient/VirtualClient.Core/IKeyVaultManager.cs index 4a4d19329a..13d4a2b8c4 100644 --- a/src/VirtualClient/VirtualClient.Core/IKeyVaultManager.cs +++ b/src/VirtualClient/VirtualClient.Core/IKeyVaultManager.cs @@ -3,6 +3,7 @@ namespace VirtualClient { + using System; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -60,10 +61,10 @@ Task GetKeyAsync( /// /// Retrieves a certificate from the Azure Key Vault. /// + /// The operating system platform (e.g. Windows, Linux). /// The name of the certificate to be retrieved /// A token that can be used to cancel the operation. /// The URI of the Azure Key Vault. - /// flag to decode whether to retrieve certificate with private key /// A policy to use for handling retries when transient errors/failures happen. /// /// A containing the certificate. @@ -72,10 +73,10 @@ Task GetKeyAsync( /// Thrown if the certificate is not found, access is denied, or another error occurs. /// Task GetCertificateAsync( + PlatformID platform, string certName, CancellationToken cancellationToken, string keyVaultUri = null, - bool retrieveWithPrivateKey = false, IAsyncPolicy retryPolicy = null); } } diff --git a/src/VirtualClient/VirtualClient.Core/Identity/AccessTokenCredential.cs b/src/VirtualClient/VirtualClient.Core/Identity/AccessTokenCredential.cs new file mode 100644 index 0000000000..8dbf2e87d6 --- /dev/null +++ b/src/VirtualClient/VirtualClient.Core/Identity/AccessTokenCredential.cs @@ -0,0 +1,58 @@ +namespace VirtualClient.Identity +{ + using System; + using System.Threading; + using System.Threading.Tasks; + using Azure.Core; + using VirtualClient.Common.Extensions; + + /// + /// A implementation that uses a pre-acquired + /// access token. + /// + public class AccessTokenCredential : TokenCredential + { + /// + /// Creates a new instance of the class. + /// + /// + /// The credential provider that will be used to get access tokens. + /// + public AccessTokenCredential(string token) + { + token.ThrowIfNull(nameof(token)); + this.AccessToken = new AccessToken(token, DateTimeOffset.UtcNow.AddHours(1)); + } + + /// + /// The access token to use for authentication. + /// + public AccessToken AccessToken { get; } + + /// + /// Gets an access token using the underlying credentials. + /// + /// Context information used when getting the access token. + /// A token that can be used to cancel the operation. + /// + /// An access token that can be used to authenticate with Azure resources. + /// + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + return this.AccessToken; + } + + /// + /// Gets an access token using the underlying credentials. + /// + /// Context information used when getting the access token. + /// A token that can be used to cancel the operation. + /// + /// An access token that can be used to authenticate with Azure resources. + /// + public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + return new ValueTask(this.AccessToken); + } + } +} diff --git a/src/VirtualClient/VirtualClient.Core/KeyVaultManager.cs b/src/VirtualClient/VirtualClient.Core/KeyVaultManager.cs index de1b82a6a6..4f26375608 100644 --- a/src/VirtualClient/VirtualClient.Core/KeyVaultManager.cs +++ b/src/VirtualClient/VirtualClient.Core/KeyVaultManager.cs @@ -15,7 +15,6 @@ namespace VirtualClient using Azure.Security.KeyVault.Secrets; using Polly; using VirtualClient.Common.Extensions; - using VirtualClient.Contracts; /// /// Provides methods for retrieving secrets, keys, and certificates from an Azure Key Vault. @@ -190,10 +189,10 @@ public async Task GetKeyAsync( /// /// Retrieves a certificate from the Azure Key Vault. /// + /// The operating system platform. /// The name of the certificate to be retrieved /// A token that can be used to cancel the operation. /// The URI of the Azure Key Vault. - /// flag to decode whether to retrieve certificate with private key /// A policy to use for handling retries when transient errors/failures happen. /// /// A containing the certificate @@ -202,13 +201,12 @@ public async Task GetKeyAsync( /// Thrown if the certificate is not found, access is denied, or another error occurs. /// public async Task GetCertificateAsync( + PlatformID platform, string certName, CancellationToken cancellationToken, string keyVaultUri = null, - bool retrieveWithPrivateKey = false, IAsyncPolicy retryPolicy = null) { - this.ValidateKeyVaultStore(); this.StoreDescription.ThrowIfNull(nameof(this.StoreDescription)); certName.ThrowIfNullOrWhiteSpace(nameof(certName), "The certificate name cannot be null or empty."); @@ -217,37 +215,47 @@ public async Task GetCertificateAsync( ? new Uri(keyVaultUri) : ((DependencyKeyVaultStore)this.StoreDescription).EndpointUri; - CertificateClient client = this.CreateCertificateClient(vaultUri, ((DependencyKeyVaultStore)this.StoreDescription).Credentials); - try { - return await (retryPolicy ?? KeyVaultManager.DefaultRetryPolicy).ExecuteAsync(async () => + KeyVaultSecret keyVaultSecret = await (retryPolicy ?? KeyVaultManager.DefaultRetryPolicy).ExecuteAsync(async () => { - // Get the full certificate with private key (PFX) if requested - if (retrieveWithPrivateKey) - { - X509Certificate2 privateKeyCert = await client - .DownloadCertificateAsync(certName, cancellationToken: cancellationToken) - .ConfigureAwait(false); + SecretClient secretsClient = new SecretClient(vaultUri, ((DependencyKeyVaultStore)this.StoreDescription).Credentials); + Response response = await secretsClient.GetSecretAsync(certName, version: null, cancellationToken); - if (privateKeyCert is null || !privateKeyCert.HasPrivateKey) - { - throw new DependencyException("Failed to retrieve certificate content with private key."); - } + return response.Value; + }).ConfigureAwait(false); + + byte[] privateKeyBytes = Convert.FromBase64String(keyVaultSecret.Value); + X509Certificate2 certificate = null; + + var keyStorageFlags = X509KeyStorageFlags.MachineKeySet | X509KeyStorageFlags.PersistKeySet; - return privateKeyCert; - } - else - { - // If private key not needed, load cert from PublicBytes - KeyVaultCertificateWithPolicy cert = await client.GetCertificateAsync(certName, cancellationToken: cancellationToken); #if NET9_0_OR_GREATER - return X509CertificateLoader.LoadCertificate(cert.Cer); + if (platform == PlatformID.Unix) + { + certificate = X509CertificateLoader.LoadPkcs12(privateKeyBytes, null, X509KeyStorageFlags.PersistKeySet); + } + else if (platform == PlatformID.Win32NT) + { + certificate = X509CertificateLoader.LoadPkcs12(privateKeyBytes, null, keyStorageFlags); + } #elif NET8_0_OR_GREATER - return new X509Certificate2(cert.Cer); + if (platform == PlatformID.Unix) + { + certificate = new X509Certificate2(privateKeyBytes, (string)null, X509KeyStorageFlags.PersistKeySet); + } + else if (platform == PlatformID.Win32NT) + { + certificate = new X509Certificate2(privateKeyBytes, (string)null, keyStorageFlags); + } #endif - } - }).ConfigureAwait(false); + + if (certificate is null || !certificate.HasPrivateKey) + { + throw new DependencyException("Failed to retrieve certificate content with private key."); + } + + return certificate; } catch (RequestFailedException ex) when (ex.Status == (int)HttpStatusCode.Forbidden) { diff --git a/src/VirtualClient/VirtualClient.Dependencies.UnitTests/KeyVaultAccessTokenTests.cs b/src/VirtualClient/VirtualClient.Dependencies.UnitTests/KeyVaultAccessTokenTests.cs new file mode 100644 index 0000000000..0e4a1ded68 --- /dev/null +++ b/src/VirtualClient/VirtualClient.Dependencies.UnitTests/KeyVaultAccessTokenTests.cs @@ -0,0 +1,321 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace VirtualClient.Dependencies +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Text; + using System.Threading; + using System.Threading.Tasks; + using Azure.Core; + using Azure.Identity; + using Moq; + using NUnit.Framework; + using VirtualClient.Common.Telemetry; + + [TestFixture] + [Category("Unit")] + public class KeyVaultAccessTokenTests + { + private MockFixture mockFixture; + + [SetUp] + public void Setup() + { + this.mockFixture = new MockFixture(); + } + + [Test] + [TestCase(PlatformID.Unix)] + [TestCase(PlatformID.Win32NT)] + public async Task InitializeWillNotDoAnythingIfLogFileNameIsNotProvided(PlatformID platform) + { + this.mockFixture.Setup(platform); + + this.SetupWorkingDirectory(platform, out _); + + using (TestKeyVaultAccessToken component = new TestKeyVaultAccessToken(this.mockFixture.Dependencies, this.CreateDefaultParameters())) + { + await component.InitializeAsyncInternal(EventContext.None, CancellationToken.None).ConfigureAwait(false); + + Assert.IsNull(component.AccessTokenPathInternal); + } + } + + [Test] + [TestCase(PlatformID.Unix)] + [TestCase(PlatformID.Win32NT)] + public async Task InitializeWillEnsureAccessTokenPathIsReadyIfLogFileNameIsProvided(PlatformID platform) + { + this.mockFixture.Setup(platform); + + this.SetupWorkingDirectory(platform, out string workingDir); + + string expectedPath = this.Combine(workingDir, "AccessToken.txt"); + + // Setup: file does not exist initially + this.mockFixture.File.Setup(f => f.Exists(expectedPath)).Returns(false); + + using (TestKeyVaultAccessToken component = new TestKeyVaultAccessToken(this.mockFixture.Dependencies, this.CreateDefaultParameters())) + { + component.Parameters["LogFileName"] = "AccessToken.txt"; + + await component.InitializeAsyncInternal(EventContext.None, CancellationToken.None).ConfigureAwait(false); + + Assert.AreEqual(expectedPath, component.AccessTokenPathInternal); + Assert.IsFalse(this.mockFixture.File.Object.Exists(component.AccessTokenPathInternal), "File should not be created during Initialize."); + } + } + + [Test] + [TestCase(PlatformID.Unix)] + [TestCase(PlatformID.Win32NT)] + public async Task InitializeWillEnsureOldFileIsDeletedIfPresent(PlatformID platform) + { + this.mockFixture.Setup(platform); + + this.SetupWorkingDirectory(platform, out string workingDir); + + string tokenPath = this.Combine(workingDir, "AccessToken.txt"); + + // Setup: existing token file is present and should be deleted during Initialize. + this.mockFixture.File.Setup(f => f.Exists(tokenPath)).Returns(true); + + bool deleteCalled = false; + this.mockFixture.FileSystem + .Setup(f => f.File.Delete(It.IsAny())) + .Callback(() => deleteCalled = true); + + using (TestKeyVaultAccessToken component = new TestKeyVaultAccessToken(this.mockFixture.Dependencies, this.CreateDefaultParameters())) + { + component.Parameters["LogFileName"] = "AccessToken.txt"; + + await component.InitializeAsyncInternal(EventContext.None, CancellationToken.None).ConfigureAwait(false); + + Assert.IsTrue(deleteCalled, "Existing token file should be deleted."); + this.mockFixture.File.Verify(f => f.Delete(It.IsAny()), Times.Once); + } + } + + [Test] + [TestCase(PlatformID.Unix)] + [TestCase(PlatformID.Win32NT)] + public void ExecuteAsyncValidatesRequiredParameters(PlatformID platform) + { + this.mockFixture.Setup(platform); + + var parameters = new Dictionary(StringComparer.OrdinalIgnoreCase); + + using (TestKeyVaultAccessToken component = new TestKeyVaultAccessToken(this.mockFixture.Dependencies, parameters)) + { + Assert.ThrowsAsync(() => component.ExecuteAsync(CancellationToken.None)); + } + } + + [Test] + [TestCase(PlatformID.Unix)] + [TestCase(PlatformID.Win32NT)] + public async Task ExecuteAsyncWillWriteTokenToFileWhenLogFileNameIsProvided(PlatformID platform) + { + this.mockFixture.Setup(platform); + this.SetupWorkingDirectory(platform, out string workingDir); + + string tokenContent = Guid.NewGuid().ToString(); + string expectedPath = this.Combine(workingDir, "AccessToken.txt"); + + Mock mockFileStream = new Mock(); + this.mockFixture.FileStream.Setup(f => f.New(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(mockFileStream.Object) + .Callback((string path, FileMode mode, FileAccess access, FileShare share) => + { + Assert.AreEqual(expectedPath, path); + Assert.AreEqual(FileMode.Create, mode); + Assert.AreEqual(FileAccess.ReadWrite, access); + Assert.AreEqual(FileShare.ReadWrite, share); + }); + + mockFileStream + .Setup(x => x.Write(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((byte[] data, int offset, int count) => + { + byte[] byteData = Encoding.Default.GetBytes(tokenContent); + Assert.AreEqual(0, offset); + Assert.AreEqual(byteData.Length, count); + CollectionAssert.AreEqual(byteData, data); + }); + + using (TestKeyVaultAccessToken component = new TestKeyVaultAccessToken(this.mockFixture.Dependencies, this.CreateDefaultParameters())) + { + component.Parameters["LogFileName"] = "AccessToken.txt"; + component.InteractiveTokenToReturn = tokenContent; + + await component.InitializeAsyncInternal(EventContext.None, CancellationToken.None).ConfigureAwait(false); + await component.ExecuteAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + [TestCase(PlatformID.Unix)] + [TestCase(PlatformID.Win32NT)] + public void GetTokenRequestContextWillReturnCorrectValue(PlatformID platform) + { + this.mockFixture.Setup(platform); + + using (TestKeyVaultAccessToken component = new TestKeyVaultAccessToken(this.mockFixture.Dependencies, this.CreateDefaultParameters())) + { + TokenRequestContext ctx = component.GetTokenRequestContextInternal(); + + Assert.IsNotNull(ctx); + Assert.AreEqual(1, ctx.Scopes.Length); + Assert.AreEqual("https://myvault.vault.azure.net/.default", ctx.Scopes[0]); + } + } + + [Test] + [TestCase(PlatformID.Unix)] + [TestCase(PlatformID.Win32NT)] + public async Task ExecuteAsyncWillUseInteractiveTokenFirst(PlatformID platform) + { + this.mockFixture.Setup(platform); + + using (TestKeyVaultAccessToken component = new TestKeyVaultAccessToken(this.mockFixture.Dependencies, this.CreateDefaultParameters())) + { + component.InteractiveTokenToReturn = "interactive-ok"; + + await component.ExecuteAsync(CancellationToken.None).ConfigureAwait(false); + + Assert.AreEqual(1, component.InteractiveCalls); + Assert.AreEqual(0, component.DeviceCodeCalls); + } + } + + [Test] + [TestCase(PlatformID.Unix)] + [TestCase(PlatformID.Win32NT)] + public async Task ExecuteAsyncWillUseDeviceLoginIfInteractiveFailsWithExactError(PlatformID platform) + { + this.mockFixture.Setup(platform); + + using (TestKeyVaultAccessToken component = new TestKeyVaultAccessToken(this.mockFixture.Dependencies, this.CreateDefaultParameters())) + { + component.ThrowBrowserUnavailableAuthenticationFailedException = true; + component.DeviceCodeTokenToReturn = "device-code-ok"; + + await component.ExecuteAsync(CancellationToken.None).ConfigureAwait(false); + + Assert.AreEqual(1, component.InteractiveCalls); + Assert.AreEqual(1, component.DeviceCodeCalls); + } + } + + [Test] + [TestCase(PlatformID.Unix, null)] + [TestCase(PlatformID.Win32NT, null)] + [TestCase(PlatformID.Unix, "")] + [TestCase(PlatformID.Win32NT, "")] + [TestCase(PlatformID.Unix, " ")] + [TestCase(PlatformID.Win32NT, " ")] + [TestCase(PlatformID.Unix, "validToken")] + [TestCase(PlatformID.Win32NT, "validToken")] + public void ExecuteAsyncThrowsErrorIfTokenIsNullOrWhitespace(PlatformID platform, string token) + { + this.mockFixture.Setup(platform); + + using (TestKeyVaultAccessToken component = new TestKeyVaultAccessToken(this.mockFixture.Dependencies, this.CreateDefaultParameters())) + { + component.InteractiveTokenToReturn = token; + + if (string.IsNullOrWhiteSpace(token)) + { + Assert.ThrowsAsync(() => component.ExecuteAsync(CancellationToken.None)); + } + else + { + Assert.DoesNotThrowAsync(() => component.ExecuteAsync(CancellationToken.None)); + } + } + } + + private void SetupWorkingDirectory(PlatformID platform, out string workingDir) + { + workingDir = platform == PlatformID.Win32NT ? @"C:\home\user" : "/home/user"; + + // KeyVaultAccessToken uses ISystemManagement.FileSystem internally, which in unit tests is MockFixture.FileSystem + this.mockFixture.Directory.Setup(d => d.GetCurrentDirectory()).Returns(workingDir); + } + + private string Combine(string left, string right) + { + // Avoid relying on host OS behavior; use the path separator expected by the test platform. + char sep = this.mockFixture.Platform == PlatformID.Win32NT ? '\\' : '/'; + return $"{left.TrimEnd(sep)}{sep}{right.TrimStart(sep)}"; + } + + private IDictionary CreateDefaultParameters() + { + return new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "TenantId", "00000000-0000-0000-0000-000000000000" }, + { "KeyVaultUri", "https://myvault.vault.azure.net/" } + }; + } + + private sealed class TestKeyVaultAccessToken : KeyVaultAccessToken + { + public TestKeyVaultAccessToken(Microsoft.Extensions.DependencyInjection.IServiceCollection dependencies, IDictionary parameters) + : base(dependencies, parameters) + { + } + + public string InteractiveTokenToReturn { get; set; } = "interactive-token"; + + public string DeviceCodeTokenToReturn { get; set; } = "device-token"; + + public bool ThrowBrowserUnavailableAuthenticationFailedException { get; set; } + + public int InteractiveCalls { get; private set; } + + public int DeviceCodeCalls { get; private set; } + + public string AccessTokenPathInternal => this.AccessTokenPath; + + public Task InitializeAsyncInternal(EventContext context, CancellationToken token) + { + return this.InitializeAsync(context, token); + } + + public TokenRequestContext GetTokenRequestContextInternal() + { + return this.GetTokenRequestContext(); + } + + protected override async Task AcquireInteractiveTokenAsync( + TokenCredential credential, + TokenRequestContext requestContext, + CancellationToken cancellationToken) + { + this.InteractiveCalls++; + + if (this.ThrowBrowserUnavailableAuthenticationFailedException) + { + throw new AuthenticationFailedException("Unable to open a web page"); + } + + await Task.Yield(); + return this.InteractiveTokenToReturn; + } + + protected override async Task AcquireDeviceCodeTokenAsync( + TokenCredential credential, + TokenRequestContext requestContext, + CancellationToken cancellationToken) + { + this.DeviceCodeCalls++; + await Task.Yield(); + return this.DeviceCodeTokenToReturn; + } + } + } +} \ No newline at end of file diff --git a/src/VirtualClient/VirtualClient.Dependencies/CertificateInstallation.cs b/src/VirtualClient/VirtualClient.Dependencies/CertificateInstallation.cs new file mode 100644 index 0000000000..e9dd3e2c46 --- /dev/null +++ b/src/VirtualClient/VirtualClient.Dependencies/CertificateInstallation.cs @@ -0,0 +1,256 @@ +namespace VirtualClient.Dependencies +{ + using System; + using System.Collections.Generic; + using System.IO.Abstractions; + using System.Security.Cryptography.X509Certificates; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.CodeAnalysis; + using Microsoft.Extensions.DependencyInjection; + using VirtualClient.Common; + using VirtualClient.Common.Extensions; + using VirtualClient.Common.Telemetry; + using VirtualClient.Contracts; + using VirtualClient.Identity; + + /// + /// Virtual Client component that acquires an Azure access token for the specified Key Vault + /// using interactive browser authentication with a device-code fallback. + /// + public class CertificateInstallation : VirtualClientComponent + { + private ISystemManagement systemManagement; + private IFileSystem fileSystem; + private ProcessManager processManager; + + /// + /// Initializes a new instance of the class. + /// + /// Provides all of the required dependencies to the Virtual Client component. + /// Parameters to the Virtual Client component. + public CertificateInstallation(IServiceCollection dependencies, IDictionary parameters = null) + : base(dependencies, parameters) + { + this.systemManagement = dependencies.GetService(); + this.fileSystem = this.systemManagement.FileSystem; + this.processManager = this.systemManagement.ProcessManager; + } + + /// + /// Gets the Azure tenant ID used to acquire an access token. + /// + protected string TenantId + { + get + { + return this.Parameters.GetValue(nameof(this.TenantId)); + } + } + + /// + /// Gets the Azure Key Vault URI for which the access token will be requested. + /// Example: https://anyvault.vault.azure.net/ + /// + protected string KeyVaultUri + { + get + { + return this.Parameters.GetValue(nameof(this.KeyVaultUri)); + } + } + + /// + /// The name of the certificate to be retrieved + /// + protected string CertificateName + { + get + { + return this.Parameters.GetValue(nameof(this.CertificateName)); + } + } + + /// + /// Gets the access token used to authenticate with Azure services. + /// + protected string AccessToken { get; set; } + + /// + /// Gets the path to the file where the access token is saved. + /// + protected string AccessTokenPath + { + get + { + return this.Parameters.GetValue(nameof(this.AccessTokenPath)); + } + } + + /// + /// + /// + /// + /// + /// + protected override async Task InitializeAsync(EventContext telemetryContext, CancellationToken cancellationToken) + { + this.AccessToken = this.Parameters.GetValue(nameof(this.AccessToken), string.Empty); + + if (string.IsNullOrWhiteSpace(this.AccessToken) && !string.IsNullOrWhiteSpace(this.AccessTokenPath)) + { + this.AccessToken = await this.fileSystem.File.ReadAllTextAsync(this.AccessTokenPath); + } + } + + /// + /// Acquires an access token for the configured Key Vault URI using Azure Identity. + /// The component attempts interactive browser authentication first and falls back to + /// device-code authentication when a browser is not available (e.g. headless Linux). + /// The token is always written to standard output. Token is also written to a file if AccessTokenPath is resolved. + /// + protected override async Task ExecuteAsync(EventContext telemetryContext, CancellationToken cancellationToken) + { + this.CertificateName.ThrowIfNullOrWhiteSpace(nameof(this.CertificateName)); + + try + { + IKeyVaultManager keyVault = this.GetKeyVaultManager(); + X509Certificate2 certificate = await keyVault.GetCertificateAsync(this.Platform, this.CertificateName, cancellationToken); + + if (this.Platform == PlatformID.Win32NT) + { + await this.InstallCertificateOnWindowsAsync(certificate, cancellationToken); + } + else if (this.Platform == PlatformID.Unix) + { + await this.InstallCertificateOnUnixAsync(certificate, cancellationToken); + } + else + { + throw new PlatformNotSupportedException($"The '{nameof(CertificateInstallation)}' component is not supported on platform '{this.Platform}'."); + } + } + catch (Exception exc) + { + throw new DependencyException( + $"An error occurred installing the certificate '{this.CertificateName}' from Key Vault. See inner exception for details.", + exc); + } + } + + /// + /// Installs the certificate in the appropriate certificate store on a Windows system. + /// + protected virtual Task InstallCertificateOnWindowsAsync(X509Certificate2 certificate, CancellationToken cancellationToken) + { + return Task.Run(() => + { + Console.WriteLine($"Certificate Store = CurrentUser/Personal"); + using (X509Store store = new X509Store(StoreName.My, StoreLocation.CurrentUser, OpenFlags.ReadWrite)) + { + store.Open(OpenFlags.ReadWrite); + store.Add(certificate); + store.Close(); + } + }); + } + + /// + /// Installs the certificate in the appropriate certificate store on a Unix/Linux system. + /// + protected virtual async Task InstallCertificateOnUnixAsync(X509Certificate2 certificate, CancellationToken cancellationToken) + { + // On Unix/Linux systems, we install the certificate in the default location for the + // user as well as in a static location. In the future we will likely use the static location + // only. + string certificateDirectory = null; + + try + { + // When "sudo" is used to run the installer, we need to know the logged + // in user account. On Linux systems, there is an environment variable 'SUDO_USER' + // that defines the logged in user. + + string user = this.GetEnvironmentVariable(EnvironmentVariable.USER); + string sudoUser = this.GetEnvironmentVariable(EnvironmentVariable.SUDO_USER); + certificateDirectory = $"/home/{user}/.dotnet/corefx/cryptography/x509stores/my"; + + if (!string.IsNullOrWhiteSpace(sudoUser)) + { + // The installer is being executed with "sudo" privileges. We want to use the + // logged in user profile vs. "root". + certificateDirectory = $"/home/{sudoUser}/.dotnet/corefx/cryptography/x509stores/my"; + } + else if (user == "root") + { + // The installer is being executed from the "root" account on Linux. + certificateDirectory = $"/root/.dotnet/corefx/cryptography/x509stores/my"; + } + + Console.WriteLine($"Certificate Store = {certificateDirectory}"); + + if (!this.fileSystem.Directory.Exists(certificateDirectory)) + { + this.fileSystem.Directory.CreateDirectory(certificateDirectory); + } + + using (X509Store store = new X509Store(StoreName.My, StoreLocation.CurrentUser, OpenFlags.ReadWrite)) + { + store.Open(OpenFlags.ReadWrite); + store.Add(certificate); + store.Close(); + } + + await this.fileSystem.File.WriteAllBytesAsync( + this.Combine(certificateDirectory, $"{certificate.Thumbprint}.pfx"), + certificate.Export(X509ContentType.Pfx)); + + // Permissions 777 (-rwxrwxrwx) + // https://linuxhandbook.com/linux-file-permissions/ + // + // User = read, write, execute + // Group = read, write, execute + // Other = read, write, execute+ + using (IProcessProxy process = this.processManager.CreateProcess("chmod", $"-R 777 {certificateDirectory}")) + { + await process.StartAndWaitAsync(cancellationToken); + process.ThrowIfErrored(); + } + } + catch (UnauthorizedAccessException) + { + throw new UnauthorizedAccessException( + $"Access permissions denied for certificate directory '{certificateDirectory}'. Execute the installer with " + + $"sudo/root privileges to install SDK certificates in privileged locations."); + } + } + + /// + /// Gets the Key Vault manager to use to retrieve certificates from Key Vault. + /// + protected IKeyVaultManager GetKeyVaultManager() + { + IKeyVaultManager keyVaultManager = this.Dependencies.GetService(); + keyVaultManager.ThrowIfNull(nameof(keyVaultManager)); + + if (keyVaultManager.StoreDescription != null) + { + return keyVaultManager; + } + else if (!string.IsNullOrWhiteSpace(this.AccessToken)) + { + this.KeyVaultUri.ThrowIfNullOrWhiteSpace(nameof(this.KeyVaultUri)); + + AccessTokenCredential tokenCredential = new AccessTokenCredential(this.AccessToken); + + DependencyKeyVaultStore dependencyKeyVault = new DependencyKeyVaultStore(DependencyStore.KeyVault, new Uri(this.KeyVaultUri), tokenCredential); + return new KeyVaultManager(dependencyKeyVault); + } + else + { + throw new InvalidOperationException($"The Key Vault manager has not been properly initialized. The '{nameof(this.LogFileName)}' parameter must be provided to read the access token from file."); + } + } + } +} \ No newline at end of file diff --git a/src/VirtualClient/VirtualClient.Dependencies/KeyVaultAccessToken.cs b/src/VirtualClient/VirtualClient.Dependencies/KeyVaultAccessToken.cs new file mode 100644 index 0000000000..166026fafe --- /dev/null +++ b/src/VirtualClient/VirtualClient.Dependencies/KeyVaultAccessToken.cs @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace VirtualClient.Dependencies +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.IO.Abstractions; + using System.Text; + using System.Threading; + using System.Threading.Tasks; + using Azure.Core; + using Azure.Identity; + using Microsoft.Extensions.DependencyInjection; + using VirtualClient.Common.Extensions; + using VirtualClient.Common.Telemetry; + using VirtualClient.Contracts; + + /// + /// Virtual Client component that acquires an Azure access token for the specified Key Vault + /// using interactive browser authentication with a device-code fallback. + /// + public class KeyVaultAccessToken : VirtualClientComponent + { + private IFileSystem fileSystem; + + /// + /// Initializes a new instance of the class. + /// + /// Provides all of the required dependencies to the Virtual Client component. + /// Parameters to the Virtual Client component. + public KeyVaultAccessToken(IServiceCollection dependencies, IDictionary parameters = null) + : base(dependencies, parameters) + { + this.fileSystem = dependencies.GetService(); + this.fileSystem.ThrowIfNull(nameof(this.fileSystem)); + } + + /// + /// Gets the Azure Key Vault URI for which the access token will be requested. + /// Example: https://anyvault.vault.azure.net/ + /// + protected Uri KeyVaultUri { get; set; } + + /// + /// Gets the Azure tenant ID used to acquire an access token. + /// + protected string TenantId { get; set; } + + /// + /// Gets or sets the full file path where the acquired access token will be written when file logging is enabled. + /// This is resolved during when + /// is provided. + /// + protected string AccessTokenPath { get; set; } + + /// + /// Resolves the access token output file path + /// and removes any existing token file so the current run produces a fresh token output. + /// + protected override async Task InitializeAsync(EventContext telemetryContext, CancellationToken cancellationToken) + { + if (!string.IsNullOrWhiteSpace(this.LogFileName)) + { + string directory = !string.IsNullOrWhiteSpace(this.LogFolderName) + ? this.LogFolderName + : this.fileSystem.Directory.GetCurrentDirectory(); + + this.AccessTokenPath = this.Combine(directory, this.LogFileName); + + if (this.fileSystem.File.Exists(this.AccessTokenPath)) + { + await this.fileSystem.File.DeleteAsync(this.AccessTokenPath); + } + } + } + + /// + /// Acquires an access token for the configured Key Vault URI using Azure Identity. + /// The component attempts interactive browser authentication first and falls back to + /// device-code authentication when a browser is not available (e.g. headless Linux). + /// The token is always written to standard output. Token is also written to a file if AccessTokenPath is resolved. + /// + protected override async Task ExecuteAsync(EventContext telemetryContext, CancellationToken cancellationToken) + { + this.KeyVaultUri = new Uri(this.Parameters.GetValue(nameof(this.KeyVaultUri))); + this.KeyVaultUri.ThrowIfNull(nameof(this.KeyVaultUri)); + + this.TenantId = this.Parameters.GetValue(nameof(this.TenantId)); + if (string.IsNullOrWhiteSpace(this.TenantId)) + { + EndpointUtility.TryParseMicrosoftEntraTenantIdReference(this.KeyVaultUri, out string tenant); + this.TenantId = tenant; + } + + this.TenantId.ThrowIfNullOrWhiteSpace(nameof(this.TenantId)); + + string accessToken = null; + if (!cancellationToken.IsCancellationRequested) + { + TokenRequestContext requestContext = this.GetTokenRequestContext(); + try + { + // Attempt an interactive (browser-based) authentication first. On most Windows environments + // this will work and is the most convenient for the user. On many Linux systems, there may + // not be a GUI and thus no browser. In that case, we fall back to the device code credential + // option in the catch block below. + InteractiveBrowserCredential credential = new InteractiveBrowserCredential( + new InteractiveBrowserCredentialOptions + { + TenantId = this.TenantId + }); + + accessToken = await this.AcquireInteractiveTokenAsync(credential, requestContext, cancellationToken); + } + catch (AuthenticationFailedException exc) when (exc.Message.Contains("Unable to open a web page")) + { + // Browser-based authentication is unavailable; switch to device code flow and present + // the user with a code and URL to complete authentication from another device. + DeviceCodeCredential credential = new DeviceCodeCredential(new DeviceCodeCredentialOptions + { + TenantId = this.TenantId, + DeviceCodeCallback = (codeInfo, token) => + { + Console.WriteLine(string.Empty); + Console.WriteLine("Browser-based authentication unavailable (e.g. no GUI). Using device/code option."); + Console.WriteLine(string.Empty); + Console.WriteLine("********************** Azure Key Vault Authorization **********************"); + Console.WriteLine(string.Empty); + Console.WriteLine(codeInfo.Message); + Console.WriteLine(string.Empty); + Console.WriteLine("***************************************************************************"); + Console.WriteLine(string.Empty); + + return Task.CompletedTask; + } + }); + + accessToken = await this.AcquireDeviceCodeTokenAsync(credential, requestContext, cancellationToken); + } + + if (string.IsNullOrWhiteSpace(accessToken)) + { + throw new AuthenticationFailedException("Authentication failed. No access token could be obtained."); + } + + if (!string.IsNullOrEmpty(this.AccessTokenPath)) + { + using (FileSystemStream fileStream = this.fileSystem.FileStream.New( + this.AccessTokenPath, + FileMode.Create, + FileAccess.ReadWrite, + FileShare.ReadWrite)) + { + byte[] bytedata = Encoding.Default.GetBytes(accessToken); + fileStream.Write(bytedata, 0, bytedata.Length); + await fileStream.FlushAsync().ConfigureAwait(false); + this.Logger.LogTraceMessage($"Access token saved to file: {this.AccessTokenPath}"); + } + } + + Console.WriteLine("[Access Token]:"); + Console.WriteLine(accessToken); + } + } + + /// + /// Acquires an access token using interactive browser authentication. + /// + /// The interactive browser credential to use. + /// The request context containing the required scopes. + /// A token that can be used to cancel the operation. + /// The access token string. + protected virtual async Task AcquireInteractiveTokenAsync( + TokenCredential credential, + TokenRequestContext requestContext, + CancellationToken cancellationToken) + { + AccessToken response = await credential.GetTokenAsync(requestContext, cancellationToken); + return response.Token; + } + + /// + /// Acquires an access token using device-code authentication. + /// This is used as a fallback when interactive browser authentication is unavailable. + /// + /// The device code credential to use. + /// The request context containing the required scopes. + /// A token that can be used to cancel the operation. + /// The access token string. + protected virtual async Task AcquireDeviceCodeTokenAsync( + TokenCredential credential, + TokenRequestContext requestContext, + CancellationToken cancellationToken) + { + AccessToken response = await credential.GetTokenAsync(requestContext, cancellationToken); + return response.Token; + } + + /// + /// Creates the used to request an access token for the target Key Vault resource. + /// Uses the Key Vault resource scope: "{KeyVaultUri}/.default". + /// + /// The token request context containing the required scopes. + protected virtual TokenRequestContext GetTokenRequestContext() + { + string[] installerTenantResourceScopes = new string[] + { + new Uri(baseUri: this.KeyVaultUri, relativeUri: ".default").ToString(), + // Example of a specific scope: + // "api://56e7ee83-1cf6-4048-a664-c2a08955f825/user_impersonation" + }; + + return new TokenRequestContext(scopes: installerTenantResourceScopes); + } + } +} \ No newline at end of file diff --git a/src/VirtualClient/VirtualClient.IntegrationTests/KeyVaultManagerTests.cs b/src/VirtualClient/VirtualClient.IntegrationTests/KeyVaultManagerTests.cs deleted file mode 100644 index a9c1666faa..0000000000 --- a/src/VirtualClient/VirtualClient.IntegrationTests/KeyVaultManagerTests.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace VirtualClient -{ - internal class KeyVaultManagerTests - { - } -} diff --git a/src/VirtualClient/VirtualClient.Main/GetAccessTokenCommand.cs b/src/VirtualClient/VirtualClient.Main/GetAccessTokenCommand.cs new file mode 100644 index 0000000000..e984d0e7bd --- /dev/null +++ b/src/VirtualClient/VirtualClient.Main/GetAccessTokenCommand.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace VirtualClient +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Threading; + using System.Threading.Tasks; + using VirtualClient.Contracts; + + /// + /// Command that executes a profile to acquire an access token for an Azure Key Vault. + /// + internal class GetAccessTokenCommand : ExecuteProfileCommand + { + /// + /// Executes the access token acquisition operations using the configured profile. + /// + /// The arguments provided to the application on the command line. + /// Provides a token that can be used to cancel the command operations. + /// The exit code for the command operations. + public override Task ExecuteAsync(string[] args, CancellationTokenSource cancellationTokenSource) + { + this.Timeout = ProfileTiming.OneIteration(); + this.Profiles = new List + { + new DependencyProfileReference("GET-ACCESS-TOKEN.json") + }; + + if (this.Parameters == null) + { + this.Parameters = new Dictionary(StringComparer.OrdinalIgnoreCase); + } + + return base.ExecuteAsync(args, cancellationTokenSource); + } + } +} \ No newline at end of file diff --git a/src/VirtualClient/VirtualClient.Main/Program.cs b/src/VirtualClient/VirtualClient.Main/Program.cs index db50e303ed..43c452b112 100644 --- a/src/VirtualClient/VirtualClient.Main/Program.cs +++ b/src/VirtualClient/VirtualClient.Main/Program.cs @@ -325,6 +325,11 @@ internal static CommandLineBuilder SetupCommandLine(string[] args, CancellationT apiSubcommand.Handler = CommandHandler.Create(cmd => cmd.ExecuteAsync(args, cancellationTokenSource)); rootCommand.Add(apiSubcommand); + Command getAccessTokenSubcommand = Program.CreateGetTokenSubCommand(settings); + getAccessTokenSubcommand.TreatUnmatchedTokensAsErrors = true; + getAccessTokenSubcommand.Handler = CommandHandler.Create(cmd => cmd.ExecuteAsync(args, cancellationTokenSource)); + rootCommand.Add(getAccessTokenSubcommand); + Command bootstrapSubcommand = Program.CreateBootstrapSubcommand(settings); bootstrapSubcommand.TreatUnmatchedTokensAsErrors = true; bootstrapSubcommand.Handler = CommandHandler.Create(cmd => cmd.ExecuteAsync(args, cancellationTokenSource)); @@ -406,6 +411,36 @@ private static Command CreateApiSubcommand(DefaultSettings settings) return apiCommand; } + private static Command CreateGetTokenSubCommand(DefaultSettings settings) + { + Command getAccessTokenCommand = new Command( + "get-token", + "Get access token for current user to authenticate with Azure Key Vault.") + { + // OPTIONAL + // ------------------------------------------------------------------- + // --clean + OptionFactory.CreateCleanOption(required: false), + + // --client-id + OptionFactory.CreateClientIdOption(required: false, Guid.NewGuid().ToString()), + + // --experiment-id + OptionFactory.CreateExperimentIdOption(required: false, Guid.NewGuid().ToString()), + + // --key-vault + OptionFactory.CreateKeyVaultOption(required: false), + + // --parameters + OptionFactory.CreateParametersOption(required: false), + + // --verbose + OptionFactory.CreateVerboseFlag(required: false, false) + }; + + return getAccessTokenCommand; + } + private static Command CreateBootstrapSubcommand(DefaultSettings settings) { Command bootstrapCommand = new Command( diff --git a/src/VirtualClient/VirtualClient.Main/profiles/GET-ACCESS-TOKEN.json b/src/VirtualClient/VirtualClient.Main/profiles/GET-ACCESS-TOKEN.json new file mode 100644 index 0000000000..3e6d6f2c4e --- /dev/null +++ b/src/VirtualClient/VirtualClient.Main/profiles/GET-ACCESS-TOKEN.json @@ -0,0 +1,19 @@ +{ + "Description": "Get access token for the user that can be used to authenticate.", + "Parameters": { + "KeyVaultUri": null, + "TenantId": null, + "LogFileName": "AccessToken.txt" + }, + "Dependencies": [ + { + "Type": "KeyVaultAccessToken", + "Parameters": { + "Scenario": "GetKVAccessToken", + "TenantId": "$.Parameters.TenantId", + "KeyVaultUri": "$.Parameters.KeyVaultUri", + "LogFileName": "$.Parameters.LogFileName" + } + } + ] +} \ No newline at end of file diff --git a/src/VirtualClient/VirtualClient.UnitTests/CommandLineOptionTests.cs b/src/VirtualClient/VirtualClient.UnitTests/CommandLineOptionTests.cs index 8135448868..0cc22e5707 100644 --- a/src/VirtualClient/VirtualClient.UnitTests/CommandLineOptionTests.cs +++ b/src/VirtualClient/VirtualClient.UnitTests/CommandLineOptionTests.cs @@ -608,6 +608,45 @@ public void VirtualClientCommandLineSupportsResponseFiles() } } + [Test] + [TestCase("--agentId", "AgentID")] + [TestCase("--client-id", "AgentID")] + [TestCase("--c", "AgentID")] + [TestCase("--clean", null)] + [TestCase("--clean", "logs")] + [TestCase("--clean", "logs,packages,state,temp")] + [TestCase("--experimentId", "0B692DEB-411E-4AC1-80D5-AF539AE1D6B2")] + [TestCase("--experiment-id", "0B692DEB-411E-4AC1-80D5-AF539AE1D6B2")] + [TestCase("--e", "0B692DEB-411E-4AC1-80D5-AF539AE1D6B2")] + [TestCase("--kv", "https://anyvault.vault.azure.net/?cid=1...&tid=2")] + [TestCase("--key-vault", "testingKV")] + [TestCase("--parameters", "helloWorld=123,,,TenantId=789203498")] + [TestCase("--pm", "testing")] + [TestCase("--verbose", null)] + public void VirtualClientGetTokenCommandSupportsOnlyExpectedOptions(string option, string value) + { + using (CancellationTokenSource cancellationSource = new CancellationTokenSource()) + { + List arguments = new List() + { + "get-token" + }; + + arguments.Add(option); + if (value != null) + { + arguments.Add(value); + } + + Assert.DoesNotThrow(() => + { + ParseResult result = Program.SetupCommandLine(arguments.ToArray(), cancellationSource).Build().Parse(arguments); + Assert.IsFalse(result.Errors.Any()); + result.ThrowOnUsageError(); + }, $"Option '{option}' is not supported."); + } + } + private class TestExecuteCommand : ExecuteCommand { public Action OnExecuteCommand { get; set; }