Skip to content

Commit

Permalink
Tests for Managed Identity - All pipelines updated
Browse files Browse the repository at this point in the history
  • Loading branch information
cheenamalhotra committed Sep 22, 2020
1 parent 2e815e4 commit 569470a
Show file tree
Hide file tree
Showing 6 changed files with 336 additions and 3 deletions.
Expand Up @@ -168,6 +168,7 @@ public override bool IsSupported(SqlAuthenticationMethod authentication)
}
}

#region IMDS Retry Helper
internal static class SqlManagedIdentityRetryHelper
{
internal const int DeltaBackOffInSeconds = 2;
Expand Down Expand Up @@ -246,4 +247,5 @@ internal static async Task<HttpResponseMessage> SendAsyncWithRetry(this HttpClie
}
}
}
#endregion
}
Expand Up @@ -3,6 +3,10 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Diagnostics;
using System.Net.Http;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.IdentityModel.Clients.ActiveDirectory;

Expand All @@ -17,10 +21,186 @@ public static async Task<string> AzureActiveDirectoryAuthenticationCallback(stri
AuthenticationResult result = await authContext.AcquireTokenAsync(resource, clientCred);
if (result == null)
{
throw new InvalidOperationException($"Failed to retrieve an access token for {resource}");
throw new Exception($"Failed to retrieve an access token for {resource}");
}

return result.AccessToken;
}

public static async Task<string> GetManagedIdentityToken(string objectId) =>
await new MockManagedIdentityTokenProvider().AcquireTokenAsync(objectId).ConfigureAwait(false);

}

#region Mock Managed Identity Token Provider
internal class MockManagedIdentityTokenProvider
{
// HttpClient is intended to be instantiated once and re-used throughout the life of an application.
#if NETFRAMEWORK
private static readonly HttpClient s_defaultHttpClient = new HttpClient();
#else
private static readonly HttpClient s_defaultHttpClient = new HttpClient(new HttpClientHandler() { CheckCertificateRevocationList = true });
#endif

private const string AzureVmImdsApiVersion = "&api-version=2018-02-01";
private const string AccessToken = "access_token";
private const string Resource = "https://database.windows.net";


private const int DefaultRetryTimeout = 0;
private const int DefaultMaxRetryCount = 5;

// Azure Instance Metadata Service (IMDS) endpoint
private const string AzureVmImdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token";

// Timeout for Azure IMDS probe request
internal const int AzureVmImdsProbeTimeoutInSeconds = 2;
internal readonly TimeSpan _azureVmImdsProbeTimeout = TimeSpan.FromSeconds(AzureVmImdsProbeTimeoutInSeconds);

// Configurable timeout for MSI retry logic
internal readonly int _retryTimeoutInSeconds = DefaultRetryTimeout;
internal readonly int _maxRetryCount = DefaultMaxRetryCount;

public async Task<string> AcquireTokenAsync(string objectId = null)
{
// Use the httpClient specified in the constructor. If it was not specified in the constructor, use the default httpClient.
HttpClient httpClient = s_defaultHttpClient;

try
{
// If user assigned managed identity is specified, include object ID parameter in request
string objectIdParameter = objectId != default
? $"&object_id={objectId}"
: string.Empty;

// Craft request as per the MSI protocol
var requestUrl = $"{AzureVmImdsEndpoint}?resource={Resource}{objectIdParameter}{AzureVmImdsApiVersion}";

HttpResponseMessage response = null;

try
{
response = await httpClient.SendAsyncWithRetry(getRequestMessage, _retryTimeoutInSeconds, _maxRetryCount, default).ConfigureAwait(false);
HttpRequestMessage getRequestMessage()
{
HttpRequestMessage request = new HttpRequestMessage(HttpMethod.Get, requestUrl);
request.Headers.Add("Metadata", "true");
return request;
}
}
catch (HttpRequestException)
{
// Not throwing exception if Access Token cannot be fetched. Tests will be disabled.
return null;
}

// If the response is successful, it should have JSON response with an access_token field
if (response.IsSuccessStatusCode)
{
string jsonResponse = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
int accessTokenStartIndex = jsonResponse.IndexOf(AccessToken) + AccessToken.Length + 3;
return jsonResponse.Substring(accessTokenStartIndex, jsonResponse.IndexOf('"', accessTokenStartIndex) - accessTokenStartIndex);
}

// RetryFailure : Failed after 5 retries.
// NonRetryableError : Received a non-retryable error.
string errorStatusDetail = response.IsRetryableStatusCode()
? "Failed after 5 retries"
: "Received a non-retryable error.";

string errorText = await response.Content.ReadAsStringAsync().ConfigureAwait(false);

// Not throwing exception if Access Token cannot be fetched. Tests will be disabled.
return null;
}
catch (Exception)
{
// Not throwing exception if Access Token cannot be fetched. Tests will be disabled.
return null;
}
}
}

#region IMDS Retry Helper
internal static class SqlManagedIdentityRetryHelper
{
internal const int DeltaBackOffInSeconds = 2;
internal const string RetryTimeoutError = "Reached retry timeout limit set by MsiRetryTimeout parameter in connection string.";

// for unit test purposes
internal static bool s_waitBeforeRetry = true;

internal static bool IsRetryableStatusCode(this HttpResponseMessage response)
{
// 404 NotFound, 429 TooManyRequests, and 5XX server error status codes are retryable
return Regex.IsMatch(((int)response.StatusCode).ToString(), @"404|429|5\d{2}");
}

/// <summary>
/// Implements recommended retry guidance here: https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance
/// </summary>
internal static async Task<HttpResponseMessage> SendAsyncWithRetry(this HttpClient httpClient, Func<HttpRequestMessage> getRequest, int retryTimeoutInSeconds, int maxRetryCount, CancellationToken cancellationToken)
{
using (var timeoutTokenSource = new CancellationTokenSource())
using (var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutTokenSource.Token, cancellationToken))
{
try
{
// if retry timeout is configured, configure cancellation after timeout period elapses
if (retryTimeoutInSeconds > 0)
{
timeoutTokenSource.CancelAfter(TimeSpan.FromSeconds(retryTimeoutInSeconds));
}

var attempts = 0;
var backoffTimeInSecs = 0;
HttpResponseMessage response;

while (true)
{
attempts++;

try
{
response = await httpClient.SendAsync(getRequest(), linkedTokenSource.Token).ConfigureAwait(false);

if (response.IsSuccessStatusCode || !response.IsRetryableStatusCode() || attempts == maxRetryCount)
{
break;
}
}
catch (HttpRequestException)
{
if (attempts == maxRetryCount)
{
throw;
}
}

if (s_waitBeforeRetry)
{
// use recommended exponential backoff strategy, and use linked token wait handle so caller or retry timeout is still able to cancel
backoffTimeInSecs += (int)Math.Pow(DeltaBackOffInSeconds, attempts);
linkedTokenSource.Token.WaitHandle.WaitOne(TimeSpan.FromSeconds(backoffTimeInSecs));
linkedTokenSource.Token.ThrowIfCancellationRequested();
}
}

return response;
}
catch (OperationCanceledException)
{
if (timeoutTokenSource.IsCancellationRequested)
{
throw new TimeoutException(RetryTimeoutError);
}

throw;
}
}
}
}
#endregion
#endregion
}

Expand Up @@ -47,12 +47,16 @@ public static class DataTestUtility
public static readonly string DNSCachingServerTR = null; // this is for the tenant ring
public static readonly bool IsDNSCachingSupportedCR = false; // this is for the control ring
public static readonly bool IsDNSCachingSupportedTR = false; // this is for the tenant ring
public static readonly string UserManagedIdentityObjectId = null;

public static readonly string EnclaveAzureDatabaseConnString = null;

public static bool ManagedIdentity = true;
public static string AADAccessToken = null;
public static string AADSystemIdentityAccessToken = null;
public static string AADUserIdentityAccessToken = null;
public const string UdtTestDbName = "UdtTestDb";
public const string AKVKeyName = "TestSqlClientAzureKeyVaultProvider";

private const string ManagedNetworkingAppContextSwitch = "Switch.Microsoft.Data.SqlClient.UseManagedNetworkingOnWindows";

private static Dictionary<string, bool> AvailableDatabases;
Expand Down Expand Up @@ -83,6 +87,7 @@ static DataTestUtility()
IsDNSCachingSupportedCR = c.IsDNSCachingSupportedCR;
IsDNSCachingSupportedTR = c.IsDNSCachingSupportedTR;
EnclaveAzureDatabaseConnString = c.EnclaveAzureDatabaseConnString;
UserManagedIdentityObjectId = c.UserManagedIdentityObjectId;

System.Net.ServicePointManager.SecurityProtocol |= System.Net.SecurityProtocolType.Tls12;

Expand Down Expand Up @@ -403,8 +408,39 @@ public static string GetAccessToken()
return (null != AADAccessToken) ? new string(AADAccessToken.ToCharArray()) : null;
}

public static string GetSystemIdentityAccessToken()
{
if (true == ManagedIdentity && null == AADSystemIdentityAccessToken && IsAADPasswordConnStrSetup())
{
AADSystemIdentityAccessToken = AADUtility.GetManagedIdentityToken(null).GetAwaiter().GetResult();
if (AADSystemIdentityAccessToken == null)
{
ManagedIdentity = false;
}
}
return (null != AADSystemIdentityAccessToken) ? new string(AADSystemIdentityAccessToken.ToCharArray()) : null;
}

public static string GetUserIdentityAccessToken()
{
if (true == ManagedIdentity && null == AADUserIdentityAccessToken && IsAADPasswordConnStrSetup())
{
// Pass User Assigned Managed Identity Object Id here.
AADUserIdentityAccessToken = AADUtility.GetManagedIdentityToken(UserManagedIdentityObjectId).GetAwaiter().GetResult();
if (AADSystemIdentityAccessToken == null)
{
ManagedIdentity = false;
}
}
return (null != AADUserIdentityAccessToken) ? new string(AADUserIdentityAccessToken.ToCharArray()) : null;
}

public static bool IsAccessTokenSetup() => !string.IsNullOrEmpty(GetAccessToken());

public static bool IsSystemIdentityTokenSetup() => !string.IsNullOrEmpty(GetSystemIdentityAccessToken());

public static bool IsUserIdentityTokenSetup() => !string.IsNullOrEmpty(GetUserIdentityAccessToken());

public static bool IsFileStreamSetup() => SupportsFileStream;

private static bool CheckException<TException>(Exception ex, string exceptionMessage, bool innerExceptionMustBeNull) where TException : Exception
Expand Down

0 comments on commit 569470a

Please sign in to comment.