Skip to content

Commit

Permalink
Expose SqlAuthenticationParameters on SSPIContextProvider
Browse files Browse the repository at this point in the history
This change updates the SSPI context provider to surface information to implementers via SqlAuthenticationParameters.

As part of this change,the internal storage of SPN is changed from byte[] to string values. Majority of implementations need the string value anyway so it makes things simpler for book keeping.
  • Loading branch information
twsouthwick committed Apr 10, 2024
1 parent e52f1c3 commit 80e90cd
Show file tree
Hide file tree
Showing 16 changed files with 288 additions and 118 deletions.
Expand Up @@ -3,6 +3,8 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.Data.Common;
Expand Down Expand Up @@ -398,7 +400,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan
ConsumerInfo consumerInfo,
string constring,
ref IntPtr pConn,
byte[] spnBuffer,
ref string spn,
byte[] instanceName,
bool fOverrideCache,
bool fSync,
Expand Down Expand Up @@ -436,13 +438,59 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port;

if (spnBuffer != null)
if (spn != null)
{
fixed (byte* pin_spnBuffer = &spnBuffer[0])
// An empty string implies we need to find the SPN so we supply a buffer for the max size
if (spn.Length == 0)
{
clientConsumerInfo.szSPN = pin_spnBuffer;
clientConsumerInfo.cchSPN = (uint)spnBuffer.Length;
return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
var array = ArrayPool<byte>.Shared.Rent(SniMaxComposedSpnLength);
array.AsSpan().Clear();

try
{
fixed (byte* pin_spnBuffer = array)
{
clientConsumerInfo.szSPN = pin_spnBuffer;
clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength;

var result = SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);

if (result == 0)
{
spn = Encoding.Unicode.GetString(array).TrimEnd('\0');
}

return result;
}
}
finally
{
ArrayPool<byte>.Shared.Return(array);
}
}

// We have a value of the SPN, so we marshal that and send it to the native layer
else
{
var writer = SqlObjectPools.BufferWriter.Rent();

// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
Encoding.Unicode.GetBytes(spn, writer);
Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");

try
{
fixed (byte* pin_spnBuffer = writer.WrittenSpan)
{
clientConsumerInfo.szSPN = pin_spnBuffer;
clientConsumerInfo.cchSPN = (uint)writer.WrittenCount;
return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
}
}
finally
{
SqlObjectPools.BufferWriter.Return(writer);
}
}
}
else
Expand Down Expand Up @@ -471,26 +519,37 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
}
}

internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, byte[] serverUserName)
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, string serverUserName)
{
sendLength = (uint)outBuff.Length;

fixed (byte* pin_serverUserName = &serverUserName[0])
fixed (byte* pInBuff = inBuff)
fixed (byte* pOutBuff = outBuff)
var serverWriter = SqlObjectPools.BufferWriter.Rent();

try
{
Encoding.Unicode.GetBytes(serverUserName, serverWriter);

fixed (byte* pin_serverUserName = serverWriter.WrittenSpan)
fixed (byte* pInBuff = inBuff)
fixed (byte* pOutBuff = outBuff)
{
bool local_fDone;
return SNISecGenClientContextWrapper(
pConnectionObject,
pInBuff,
(uint)inBuff.Length,
pOutBuff,
ref sendLength,
out local_fDone,
pin_serverUserName,
(uint)serverWriter.WrittenCount,
null,
null);
}
}
finally
{
bool local_fDone;
return SNISecGenClientContextWrapper(
pConnectionObject,
pInBuff,
(uint)inBuff.Length,
pOutBuff,
ref sendLength,
out local_fDone,
pin_serverUserName,
(uint)serverUserName.Length,
null,
null);
SqlObjectPools.BufferWriter.Return(serverWriter);
}
}

Expand Down
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
Expand Down Expand Up @@ -33,9 +34,9 @@ internal class SNIProxy
/// <param name="sspiClientContextStatus">SSPI client context status</param>
/// <param name="receivedBuff">Receive buffer</param>
/// <param name="sendWriter">Writer for send buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <param name="serverNames">Service Principal Name buffer</param>
/// <returns>SNI error code</returns>
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, IBufferWriter<byte> sendWriter, string[] serverNames)
{
// TODO: this should use ReadOnlyMemory all the way through
byte[] array = null;
Expand All @@ -46,10 +47,10 @@ internal static void GenSspiClientContext(SspiClientContextStatus sspiClientCont
receivedBuff.CopyTo(array);
}

GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverName);
GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverNames);
}

private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter<byte> sendWriter, string[] serverSPNs)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand Down Expand Up @@ -81,11 +82,6 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte
| ContextFlagsPal.Delegate
| ContextFlagsPal.MutualAuth;

string[] serverSPNs = new string[serverName.Length];
for (int i = 0; i < serverName.Length; i++)
{
serverSPNs[i] = Encoding.Unicode.GetString(serverName[i]);
}
SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext(
credentialsHandle,
ref securityContext,
Expand Down Expand Up @@ -164,7 +160,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
string fullServerName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spnBuffer,
string serverSPN,
bool flushCache,
bool async,
Expand Down Expand Up @@ -228,12 +224,12 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
return sniHandle;
}

private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
{
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
if (!string.IsNullOrWhiteSpace(serverSPN))
{
return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) };
return new[] { serverSPN };
}

string hostName = dataSource.ServerName;
Expand All @@ -251,7 +247,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol);
}

private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
{
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
IPHostEntry hostEntry = null;
Expand Down Expand Up @@ -282,12 +278,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
// Set both SPNs with and without Port as Port is optional for default instance
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) };
return new[] { serverSpn, serverSpnWithDefaultPort };
}
// else Named Pipes do not need to valid port

SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) };
return new[] { serverSpn };
}

/// <summary>
Expand Down
Expand Up @@ -134,7 +134,7 @@ internal static void Assert(string message)

private bool _is2022 = false;

private byte[][] _sniSpnBuffer = null;
private string[] _sniSpn = null;

// SqlStatistics
private SqlStatistics _statistics = null;
Expand Down Expand Up @@ -404,7 +404,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
}
else
{
_sniSpnBuffer = null;
_sniSpn = null;
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler._objectID,
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
}
Expand All @@ -416,7 +416,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
}

_sniSpnBuffer = null;
_sniSpn = null;
_authenticationProvider = null;

// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
Expand Down Expand Up @@ -455,7 +455,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _sniSpnBuffer,
ref _sniSpn,
false,
true,
fParallel,
Expand All @@ -468,7 +468,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
Expand Down Expand Up @@ -554,7 +554,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
_physicalStateObj.CreatePhysicalSNIHandle(
serverInfo.ExtendedServerName,
timeout, out instanceName,
ref _sniSpnBuffer,
ref _sniSpn,
true,
true,
fParallel,
Expand All @@ -567,15 +567,15 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
ThrowExceptionAndWarning(_physicalStateObj);
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);

uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId);

Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
Expand Down Expand Up @@ -12850,7 +12850,7 @@ internal string TraceString()
_fMARS ? bool.TrueString : bool.FalseString,
null == _sessionPool ? "(null)" : _sessionPool.TraceString(),
_is2005 ? bool.TrueString : bool.FalseString,
null == _sniSpnBuffer ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
null == _sniSpn ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Expand Up @@ -186,7 +186,7 @@ private void ResetCancelAndProcessAttention()
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spn,
bool flushCache,
bool async,
bool fParallel,
Expand Down
Expand Up @@ -81,7 +81,7 @@ protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spn,
bool flushCache,
bool async,
bool parallel,
Expand All @@ -94,7 +94,7 @@ protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref
string hostNameInCertificate,
string serverCertificateFilename)
{
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN,
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spn, serverSPN,
flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst,
hostNameInCertificate, serverCertificateFilename);

Expand Down
Expand Up @@ -143,7 +143,7 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async)
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spn,
bool flushCache,
bool async,
bool fParallel,
Expand All @@ -156,31 +156,28 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async)
string hostNameInCertificate,
string serverCertificateFilename)
{
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
spnBuffer = new byte[1][];
if (isIntegratedSecurity)
{
// now allocate proper length of buffer
if (!string.IsNullOrEmpty(serverSPN))
{
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
Trace.Assert(srvSPN.Length <= SNINativeMethodWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
spnBuffer[0] = srvSPN;
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
}
else
{
spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
// This will signal to the interop layer that we need to retrieve the SPN
serverSPN = string.Empty;
}
}

SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async);
SQLDNSInfo cachedDNSInfo;
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);

_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName,
_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
spn = new[] { serverSPN.TrimEnd() };
}

protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
Expand Down

0 comments on commit 80e90cd

Please sign in to comment.