Skip to content

Commit

Permalink
Feature | Introduce "Active Directory Default" authentication mode (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
cheenamalhotra committed May 17, 2021
1 parent 5e067c4 commit 4316474
Show file tree
Hide file tree
Showing 38 changed files with 372 additions and 106 deletions.
Expand Up @@ -41,5 +41,9 @@
<summary>Alias for "Active Directory Managed Identity" authentication method. Use System Assigned or User Assigned Managed Identity to connect to SQL Database from Azure client environments that have enabled support for Managed Identity. For User Assigned Managed Identity, 'User Id' or 'UID' is required to be set to the "client ID" of the user identity.</summary>
<value>8</value>
</ActiveDirectoryMSI>
<ActiveDirectoryDefault>
<summary>The authentication method uses Active Directory Default. Use this mode to connect to a SQL Database using multiple non-interactive authentication methods tried sequentially to acquire an access token. This method does not fallback to the "Active Directory Interactive" authentication method.</summary>
<value>9</value>
</ActiveDirectoryDefault>
</members>
</docs>
Expand Up @@ -99,6 +99,8 @@ public enum SqlAuthenticationMethod
ActiveDirectoryManagedIdentity = 7,
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlAuthenticationMethod.xml' path='docs/members[@name="SqlAuthenticationMethod"]/ActiveDirectoryMSI/*'/>
ActiveDirectoryMSI = 8,
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlAuthenticationMethod.xml' path='docs/members[@name="SqlAuthenticationMethod"]/ActiveDirectoryDefault/*'/>
ActiveDirectoryDefault = 9,
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlAuthenticationMethod.xml' path='docs/members[@name="SqlAuthenticationMethod"]/NotSpecified/*'/>
NotSpecified = 0,
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlAuthenticationMethod.xml' path='docs/members[@name="SqlAuthenticationMethod"]/SqlPassword/*'/>
Expand Down
Expand Up @@ -99,6 +99,7 @@ internal static string ConvertToString(object value)

private const string ApplicationIntentReadWriteString = "ReadWrite";
private const string ApplicationIntentReadOnlyString = "ReadOnly";

const string SqlPasswordString = "Sql Password";
const string ActiveDirectoryPasswordString = "Active Directory Password";
const string ActiveDirectoryIntegratedString = "Active Directory Integrated";
Expand All @@ -107,13 +108,48 @@ internal static string ConvertToString(object value)
const string ActiveDirectoryDeviceCodeFlowString = "Active Directory Device Code Flow";
internal const string ActiveDirectoryManagedIdentityString = "Active Directory Managed Identity";
internal const string ActiveDirectoryMSIString = "Active Directory MSI";
internal const string ActiveDirectoryDefaultString = "Active Directory Default";

internal static bool TryConvertToAuthenticationType(string value, out SqlAuthenticationMethod result)
#if DEBUG
private static string[] s_supportedAuthenticationModes =
{
"NotSpecified",
"SqlPassword",
"ActiveDirectoryPassword",
"ActiveDirectoryIntegrated",
"ActiveDirectoryInteractive",
"ActiveDirectoryServicePrincipal",
"ActiveDirectoryDeviceCodeFlow",
"ActiveDirectoryManagedIdentity",
"ActiveDirectoryMSI",
"ActiveDirectoryDefault"
};

private static bool IsValidAuthenticationMethodEnum()
{
Debug.Assert(Enum.GetNames(typeof(SqlAuthenticationMethod)).Length == 9, "SqlAuthenticationMethod enum has changed, update needed");
string[] names = Enum.GetNames(typeof(SqlAuthenticationMethod));
int l = s_supportedAuthenticationModes.Length;
bool listValid;
if (listValid = names.Length == l)
{
for (int i = 0; i < l; i++)
{
if (s_supportedAuthenticationModes[i].CompareTo(names[i]) != 0)
{
listValid = false;
}
}
}
return listValid;
}
#endif

internal static bool TryConvertToAuthenticationType(string value, out SqlAuthenticationMethod result)
{
#if DEBUG
Debug.Assert(IsValidAuthenticationMethodEnum(), "SqlAuthenticationMethod enum has changed, update needed");
#endif
bool isSuccess = false;

if (StringComparer.InvariantCultureIgnoreCase.Equals(value, SqlPasswordString)
|| StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.SqlPassword, CultureInfo.InvariantCulture)))
{
Expand Down Expand Up @@ -162,6 +198,12 @@ internal static bool TryConvertToAuthenticationType(string value, out SqlAuthent
result = SqlAuthenticationMethod.ActiveDirectoryMSI;
isSuccess = true;
}
else if (StringComparer.InvariantCultureIgnoreCase.Equals(value, ActiveDirectoryDefaultString)
|| StringComparer.InvariantCultureIgnoreCase.Equals(value, Convert.ToString(SqlAuthenticationMethod.ActiveDirectoryDefault, CultureInfo.InvariantCulture)))
{
result = SqlAuthenticationMethod.ActiveDirectoryDefault;
isSuccess = true;
}
else
{
result = DbConnectionStringDefaults.Authentication;
Expand Down Expand Up @@ -606,7 +648,7 @@ internal static ApplicationIntent ConvertToApplicationIntent(string keyword, obj

internal static bool IsValidAuthenticationTypeValue(SqlAuthenticationMethod value)
{
Debug.Assert(Enum.GetNames(typeof(SqlAuthenticationMethod)).Length == 9, "SqlAuthenticationMethod enum has changed, update needed");
Debug.Assert(Enum.GetNames(typeof(SqlAuthenticationMethod)).Length == 10, "SqlAuthenticationMethod enum has changed, update needed");
return value == SqlAuthenticationMethod.SqlPassword
|| value == SqlAuthenticationMethod.ActiveDirectoryPassword
|| value == SqlAuthenticationMethod.ActiveDirectoryIntegrated
Expand All @@ -615,6 +657,7 @@ internal static bool IsValidAuthenticationTypeValue(SqlAuthenticationMethod valu
|| value == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow
|| value == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity
|| value == SqlAuthenticationMethod.ActiveDirectoryMSI
|| value == SqlAuthenticationMethod.ActiveDirectoryDefault
|| value == SqlAuthenticationMethod.NotSpecified;
}

Expand All @@ -640,6 +683,8 @@ internal static string AuthenticationTypeToString(SqlAuthenticationMethod value)
return ActiveDirectoryManagedIdentityString;
case SqlAuthenticationMethod.ActiveDirectoryMSI:
return ActiveDirectoryMSIString;
case SqlAuthenticationMethod.ActiveDirectoryDefault:
return ActiveDirectoryDefaultString;
default:
return null;
}
Expand Down
Expand Up @@ -152,6 +152,8 @@ private static SqlAuthenticationMethod AuthenticationEnumFromString(string authe
return SqlAuthenticationMethod.ActiveDirectoryManagedIdentity;
case ActiveDirectoryMSI:
return SqlAuthenticationMethod.ActiveDirectoryMSI;
case ActiveDirectoryDefault:
return SqlAuthenticationMethod.ActiveDirectoryDefault;
default:
throw SQL.UnsupportedAuthentication(authentication);
}
Expand Down
Expand Up @@ -21,6 +21,7 @@ internal partial class SqlAuthenticationProviderManager
private const string ActiveDirectoryDeviceCodeFlow = "active directory device code flow";
private const string ActiveDirectoryManagedIdentity = "active directory managed identity";
private const string ActiveDirectoryMSI = "active directory msi";
private const string ActiveDirectoryDefault = "active directory default";

private readonly string _typeName;
private readonly IReadOnlyCollection<SqlAuthenticationMethod> _authenticationsWithAppSpecifiedProvider;
Expand All @@ -46,6 +47,7 @@ private static void SetDefaultAuthProviders(SqlAuthenticationProviderManager ins
instance.SetProvider(SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow, activeDirectoryAuthProvider);
instance.SetProvider(SqlAuthenticationMethod.ActiveDirectoryManagedIdentity, activeDirectoryAuthProvider);
instance.SetProvider(SqlAuthenticationMethod.ActiveDirectoryMSI, activeDirectoryAuthProvider);
instance.SetProvider(SqlAuthenticationMethod.ActiveDirectoryDefault, activeDirectoryAuthProvider);
}
}
/// <summary>
Expand Down
Expand Up @@ -193,11 +193,15 @@ public SqlConnection(string connectionString, SqlCredential credential) : this()
}
else if (UsesActiveDirectoryManagedIdentity(connectionOptions))
{
throw SQL.SettingCredentialWithManagedIdentityArgument(DbConnectionStringBuilderUtil.ActiveDirectoryManagedIdentityString);
throw SQL.SettingCredentialWithNonInteractiveArgument(DbConnectionStringBuilderUtil.ActiveDirectoryManagedIdentityString);
}
else if (UsesActiveDirectoryMSI(connectionOptions))
{
throw SQL.SettingCredentialWithManagedIdentityArgument(DbConnectionStringBuilderUtil.ActiveDirectoryMSIString);
throw SQL.SettingCredentialWithNonInteractiveArgument(DbConnectionStringBuilderUtil.ActiveDirectoryMSIString);
}
else if (UsesActiveDirectoryDefault(connectionOptions))
{
throw SQL.SettingCredentialWithNonInteractiveArgument(DbConnectionStringBuilderUtil.ActiveDirectoryDefaultString);
}

Credential = credential;
Expand Down Expand Up @@ -508,6 +512,11 @@ private bool UsesActiveDirectoryMSI(SqlConnectionString opt)
return opt != null && opt.Authentication == SqlAuthenticationMethod.ActiveDirectoryMSI;
}

private bool UsesActiveDirectoryDefault(SqlConnectionString opt)
{
return opt != null && opt.Authentication == SqlAuthenticationMethod.ActiveDirectoryDefault;
}

private bool UsesAuthentication(SqlConnectionString opt)
{
return opt != null && opt.Authentication != SqlAuthenticationMethod.NotSpecified;
Expand Down Expand Up @@ -565,7 +574,7 @@ public override string ConnectionString
if (_credential != null)
{
// Check for Credential being used with Authentication=ActiveDirectoryIntegrated | ActiveDirectoryInteractive |
// ActiveDirectoryDeviceCodeFlow | ActiveDirectoryManagedIdentity/ActiveDirectoryMSI. Since a different error string is used
// ActiveDirectoryDeviceCodeFlow | ActiveDirectoryManagedIdentity/ActiveDirectoryMSI | ActiveDirectoryDefault. Since a different error string is used
// for this case in ConnectionString setter vs in Credential setter, check for this error case before calling
// CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential, which is common to both setters.
if (UsesActiveDirectoryIntegrated(connectionOptions))
Expand All @@ -582,11 +591,15 @@ public override string ConnectionString
}
else if (UsesActiveDirectoryManagedIdentity(connectionOptions))
{
throw SQL.SettingManagedIdentityWithCredential(DbConnectionStringBuilderUtil.ActiveDirectoryManagedIdentityString);
throw SQL.SettingNonInteractiveWithCredential(DbConnectionStringBuilderUtil.ActiveDirectoryManagedIdentityString);
}
else if (UsesActiveDirectoryMSI(connectionOptions))
{
throw SQL.SettingManagedIdentityWithCredential(DbConnectionStringBuilderUtil.ActiveDirectoryMSIString);
throw SQL.SettingNonInteractiveWithCredential(DbConnectionStringBuilderUtil.ActiveDirectoryMSIString);
}
else if (UsesActiveDirectoryDefault(connectionOptions))
{
throw SQL.SettingNonInteractiveWithCredential(DbConnectionStringBuilderUtil.ActiveDirectoryDefaultString);
}

CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);
Expand Down Expand Up @@ -878,7 +891,7 @@ public SqlCredential Credential
{
var connectionOptions = (SqlConnectionString)ConnectionOptions;
// Check for Credential being used with Authentication=ActiveDirectoryIntegrated | ActiveDirectoryInteractive |
// ActiveDirectoryDeviceCodeFlow | ActiveDirectoryManagedIdentity/ActiveDirectoryMSI. Since a different error string is used
// ActiveDirectoryDeviceCodeFlow | ActiveDirectoryManagedIdentity/ActiveDirectoryMSI | ActiveDirectoryDefault. Since a different error string is used
// for this case in ConnectionString setter vs in Credential setter, check for this error case before calling
// CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential, which is common to both setters.
if (UsesActiveDirectoryIntegrated(connectionOptions))
Expand All @@ -895,11 +908,15 @@ public SqlCredential Credential
}
else if (UsesActiveDirectoryManagedIdentity(connectionOptions))
{
throw SQL.SettingCredentialWithManagedIdentityInvalid(DbConnectionStringBuilderUtil.ActiveDirectoryManagedIdentityString);
throw SQL.SettingCredentialWithNonInteractiveInvalid(DbConnectionStringBuilderUtil.ActiveDirectoryManagedIdentityString);
}
else if (UsesActiveDirectoryMSI(connectionOptions))
{
throw SQL.SettingCredentialWithManagedIdentityInvalid(DbConnectionStringBuilderUtil.ActiveDirectoryMSIString);
throw SQL.SettingCredentialWithNonInteractiveInvalid(DbConnectionStringBuilderUtil.ActiveDirectoryMSIString);
}
else if (UsesActiveDirectoryDefault(connectionOptions))
{
throw SQL.SettingCredentialWithNonInteractiveInvalid(DbConnectionStringBuilderUtil.ActiveDirectoryDefaultString);
}

CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);
Expand Down
Expand Up @@ -484,12 +484,17 @@ internal SqlConnectionString(string connectionString) : base(connectionString, G

if (Authentication == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity && HasPasswordKeyword)
{
throw SQL.ManagedIdentityWithPassword(DbConnectionStringBuilderUtil.ActiveDirectoryManagedIdentityString);
throw SQL.NonInteractiveWithPassword(DbConnectionStringBuilderUtil.ActiveDirectoryManagedIdentityString);
}

if (Authentication == SqlAuthenticationMethod.ActiveDirectoryMSI && HasPasswordKeyword)
{
throw SQL.ManagedIdentityWithPassword(DbConnectionStringBuilderUtil.ActiveDirectoryMSIString);
throw SQL.NonInteractiveWithPassword(DbConnectionStringBuilderUtil.ActiveDirectoryMSIString);
}

if (Authentication == SqlAuthenticationMethod.ActiveDirectoryDefault && HasPasswordKeyword)
{
throw SQL.NonInteractiveWithPassword(DbConnectionStringBuilderUtil.ActiveDirectoryDefaultString);
}
}

Expand Down
Expand Up @@ -1308,6 +1308,7 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword,
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryMSI
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryDefault
// Since AD Integrated may be acting like Windows integrated, additionally check _fedAuthRequired
|| (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && _fedAuthRequired))
{
Expand Down Expand Up @@ -2116,6 +2117,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryMSI
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryDefault
|| (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && _fedAuthRequired),
"Credentials aren't provided for calling MSAL");
Debug.Assert(fedAuthInfo != null, "info should not be null.");
Expand Down Expand Up @@ -2358,6 +2360,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
case SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow:
case SqlAuthenticationMethod.ActiveDirectoryManagedIdentity:
case SqlAuthenticationMethod.ActiveDirectoryMSI:
case SqlAuthenticationMethod.ActiveDirectoryDefault:
if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
_fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
Expand Down

0 comments on commit 4316474

Please sign in to comment.