Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement public client application global cache #770

Merged
merged 6 commits into from Oct 23, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Security;
using System.Threading;
Expand All @@ -15,6 +16,14 @@ namespace Microsoft.Data.SqlClient
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/ActiveDirectoryAuthenticationProvider/*'/>
public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider
{
/// <summary>
/// This is a static cache instance meant to hold instances of "PublicClientApplication" mapping to information available in PublicClientAppKey.
/// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode
/// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache.
/// </summary>
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
private readonly SqlClientLogger _logger = new SqlClientLogger();
Expand Down Expand Up @@ -67,10 +76,10 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
}

#if NETSTANDARD
private Func<object> parentActivityOrWindowFunc = null;
private Func<object> _parentActivityOrWindowFunc = null;

/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/SetParentActivityOrWindowFunc/*'/>
public void SetParentActivityOrWindowFunc(Func<object> parentActivityOrWindowFunc) => this.parentActivityOrWindowFunc = parentActivityOrWindowFunc;
public void SetParentActivityOrWindowFunc(Func<object> parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc;
#endif

#if NETFRAMEWORK
Expand Down Expand Up @@ -108,51 +117,24 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
*
* https://docs.microsoft.com/en-us/azure/active-directory/develop/scenario-desktop-app-registration#redirect-uris
*/
string redirectURI = "https://login.microsoftonline.com/common/oauth2/nativeclient";
string redirectUri = s_nativeClientRedirectUri;

#if NETCOREAPP
if (parameters.AuthenticationMethod != SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
redirectURI = "http://localhost";
}
#endif
IPublicClientApplication app;

#if NETSTANDARD
if (parentActivityOrWindowFunc != null)
{
app = PublicClientApplicationBuilder.Create(_applicationClientId)
.WithAuthority(parameters.Authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(redirectURI)
.WithParentActivityOrWindow(parentActivityOrWindowFunc)
.Build();
redirectUri = "http://localhost";
}
#endif
PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId
#if NETFRAMEWORK
if (_iWin32WindowFunc != null)
{
app = PublicClientApplicationBuilder.Create(_applicationClientId)
.WithAuthority(parameters.Authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(redirectURI)
.WithParentActivityOrWindow(_iWin32WindowFunc)
.Build();
}
, _iWin32WindowFunc
#endif
#if !NETCOREAPP
else
#if NETSTANDARD
, _parentActivityOrWindowFunc
#endif
{
app = PublicClientApplicationBuilder.Create(_applicationClientId)
.WithAuthority(parameters.Authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(redirectURI)
.Build();
}
);

IPublicClientApplication app = GetPublicClientAppInstance(pcaKey);

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
Expand Down Expand Up @@ -185,6 +167,7 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerable<IAccount> accounts = await app.GetAccountsAsync();
IAccount account;
if (!string.IsNullOrEmpty(parameters.UserId))
Expand All @@ -200,17 +183,23 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
{
try
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync();
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
catch (MsalUiRequiredException)
{
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
}
else
{
// If no existing 'account' is found, we request user to sign in interactively.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
Expand Down Expand Up @@ -320,5 +309,118 @@ private class CustomWebUi : ICustomWebUi
public Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken)
=> _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken);
}

private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey)
{
if (!s_pcaMap.TryGetValue(publicClientAppKey, out IPublicClientApplication clientApplicationInstance))
{
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
}
return clientApplicationInstance;
}

private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
{
IPublicClientApplication publicClientApplication;

#if NETSTANDARD
if (_parentActivityOrWindowFunc != null)
{
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
.WithAuthority(publicClientAppKey._authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(publicClientAppKey._redirectUri)
.WithParentActivityOrWindow(_parentActivityOrWindowFunc)
.Build();
}
#endif
#if NETFRAMEWORK
if (_iWin32WindowFunc != null)
{
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
.WithAuthority(publicClientAppKey._authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(publicClientAppKey._redirectUri)
.WithParentActivityOrWindow(_iWin32WindowFunc)
.Build();
}
#endif
#if !NETCOREAPP
else
#endif
{
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
.WithAuthority(publicClientAppKey._authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(publicClientAppKey._redirectUri)
.Build();
}

return publicClientApplication;
}

internal class PublicClientAppKey
{
public readonly string _authority;
public readonly string _redirectUri;
public readonly string _applicationClientId;
#if NETFRAMEWORK
public readonly Func<System.Windows.Forms.IWin32Window> _iWin32WindowFunc;
#endif
#if NETSTANDARD
public readonly Func<object> _parentActivityOrWindowFunc;
#endif

public PublicClientAppKey(string authority, string redirectUri, string applicationClientId
#if NETFRAMEWORK
, Func<System.Windows.Forms.IWin32Window> iWin32WindowFunc
#endif
#if NETSTANDARD
, Func<object> parentActivityOrWindowFunc
#endif
)
{
_authority = authority;
_redirectUri = redirectUri;
_applicationClientId = applicationClientId;
#if NETFRAMEWORK
_iWin32WindowFunc = iWin32WindowFunc;
#endif
#if NETSTANDARD
_parentActivityOrWindowFunc = parentActivityOrWindowFunc;
#endif
}

public override bool Equals(object obj)
{
if (obj != null && obj is PublicClientAppKey pcaKey)
{
return (string.CompareOrdinal(_authority, pcaKey._authority) == 0
&& string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0
&& string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0
#if NETFRAMEWORK
&& pcaKey._iWin32WindowFunc == _iWin32WindowFunc
#endif
#if NETSTANDARD
&& pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc
#endif
);
}
return false;
}

public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _applicationClientId
#if NETFRAMEWORK
, _iWin32WindowFunc
#endif
#if NETSTANDARD
, _parentActivityOrWindowFunc
#endif
).GetHashCode();
}
}
}