Skip to content

Commit

Permalink
Implement public client application global cache (#770)
Browse files Browse the repository at this point in the history
  • Loading branch information
cheenamalhotra committed Oct 23, 2020
1 parent 2134081 commit d65173b
Showing 1 changed file with 140 additions and 38 deletions.
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Security;
using System.Threading;
Expand All @@ -15,6 +16,14 @@ namespace Microsoft.Data.SqlClient
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/ActiveDirectoryAuthenticationProvider/*'/>
public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider
{
/// <summary>
/// This is a static cache instance meant to hold instances of "PublicClientApplication" mapping to information available in PublicClientAppKey.
/// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode
/// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache.
/// </summary>
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
private readonly SqlClientLogger _logger = new SqlClientLogger();
Expand Down Expand Up @@ -67,10 +76,10 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
}

#if NETSTANDARD
private Func<object> parentActivityOrWindowFunc = null;
private Func<object> _parentActivityOrWindowFunc = null;

/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/SetParentActivityOrWindowFunc/*'/>
public void SetParentActivityOrWindowFunc(Func<object> parentActivityOrWindowFunc) => this.parentActivityOrWindowFunc = parentActivityOrWindowFunc;
public void SetParentActivityOrWindowFunc(Func<object> parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc;
#endif

#if NETFRAMEWORK
Expand Down Expand Up @@ -108,51 +117,24 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
*
* https://docs.microsoft.com/en-us/azure/active-directory/develop/scenario-desktop-app-registration#redirect-uris
*/
string redirectURI = "https://login.microsoftonline.com/common/oauth2/nativeclient";
string redirectUri = s_nativeClientRedirectUri;
#if NETCOREAPP
if (parameters.AuthenticationMethod != SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
redirectURI = "http://localhost";
}
#endif
IPublicClientApplication app;
#if NETSTANDARD
if (parentActivityOrWindowFunc != null)
{
app = PublicClientApplicationBuilder.Create(_applicationClientId)
.WithAuthority(parameters.Authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(redirectURI)
.WithParentActivityOrWindow(parentActivityOrWindowFunc)
.Build();
redirectUri = "http://localhost";
}
#endif
PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId
#if NETFRAMEWORK
if (_iWin32WindowFunc != null)
{
app = PublicClientApplicationBuilder.Create(_applicationClientId)
.WithAuthority(parameters.Authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(redirectURI)
.WithParentActivityOrWindow(_iWin32WindowFunc)
.Build();
}
, _iWin32WindowFunc
#endif
#if !NETCOREAPP
else
#if NETSTANDARD
, _parentActivityOrWindowFunc
#endif
{
app = PublicClientApplicationBuilder.Create(_applicationClientId)
.WithAuthority(parameters.Authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(redirectURI)
.Build();
}
);
IPublicClientApplication app = GetPublicClientAppInstance(pcaKey);
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
Expand Down Expand Up @@ -185,6 +167,7 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerable<IAccount> accounts = await app.GetAccountsAsync();
IAccount account;
if (!string.IsNullOrEmpty(parameters.UserId))
Expand All @@ -200,17 +183,23 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
{
try
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync();
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
catch (MsalUiRequiredException)
{
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
}
else
{
// If no existing 'account' is found, we request user to sign in interactively.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
Expand Down Expand Up @@ -320,5 +309,118 @@ private class CustomWebUi : ICustomWebUi
public Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken)
=> _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken);
}

private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey)
{
if (!s_pcaMap.TryGetValue(publicClientAppKey, out IPublicClientApplication clientApplicationInstance))
{
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
}
return clientApplicationInstance;
}

private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
{
IPublicClientApplication publicClientApplication;

#if NETSTANDARD
if (_parentActivityOrWindowFunc != null)
{
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
.WithAuthority(publicClientAppKey._authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(publicClientAppKey._redirectUri)
.WithParentActivityOrWindow(_parentActivityOrWindowFunc)
.Build();
}
#endif
#if NETFRAMEWORK
if (_iWin32WindowFunc != null)
{
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
.WithAuthority(publicClientAppKey._authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(publicClientAppKey._redirectUri)
.WithParentActivityOrWindow(_iWin32WindowFunc)
.Build();
}
#endif
#if !NETCOREAPP
else
#endif
{
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
.WithAuthority(publicClientAppKey._authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(publicClientAppKey._redirectUri)
.Build();
}

return publicClientApplication;
}

internal class PublicClientAppKey
{
public readonly string _authority;
public readonly string _redirectUri;
public readonly string _applicationClientId;
#if NETFRAMEWORK
public readonly Func<System.Windows.Forms.IWin32Window> _iWin32WindowFunc;
#endif
#if NETSTANDARD
public readonly Func<object> _parentActivityOrWindowFunc;
#endif

public PublicClientAppKey(string authority, string redirectUri, string applicationClientId
#if NETFRAMEWORK
, Func<System.Windows.Forms.IWin32Window> iWin32WindowFunc
#endif
#if NETSTANDARD
, Func<object> parentActivityOrWindowFunc
#endif
)
{
_authority = authority;
_redirectUri = redirectUri;
_applicationClientId = applicationClientId;
#if NETFRAMEWORK
_iWin32WindowFunc = iWin32WindowFunc;
#endif
#if NETSTANDARD
_parentActivityOrWindowFunc = parentActivityOrWindowFunc;
#endif
}

public override bool Equals(object obj)
{
if (obj != null && obj is PublicClientAppKey pcaKey)
{
return (string.CompareOrdinal(_authority, pcaKey._authority) == 0
&& string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0
&& string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0
#if NETFRAMEWORK
&& pcaKey._iWin32WindowFunc == _iWin32WindowFunc
#endif
#if NETSTANDARD
&& pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc
#endif
);
}
return false;
}

public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _applicationClientId
#if NETFRAMEWORK
, _iWin32WindowFunc
#endif
#if NETSTANDARD
, _parentActivityOrWindowFunc
#endif
).GetHashCode();
}
}
}

0 comments on commit d65173b

Please sign in to comment.