Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose SqlAuthenticationParameters on SSPIContextProvider #2454

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -3,6 +3,8 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.Data.Common;
Expand Down Expand Up @@ -317,7 +319,7 @@ internal struct SNI_Error
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down Expand Up @@ -398,7 +400,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan
ConsumerInfo consumerInfo,
string constring,
ref IntPtr pConn,
byte[] spnBuffer,
ref string spn,
byte[] instanceName,
bool fOverrideCache,
bool fSync,
Expand Down Expand Up @@ -436,13 +438,59 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port;

if (spnBuffer != null)
if (spn != null)
{
fixed (byte* pin_spnBuffer = &spnBuffer[0])
// An empty string implies we need to find the SPN so we supply a buffer for the max size
if (spn.Length == 0)
{
clientConsumerInfo.szSPN = pin_spnBuffer;
clientConsumerInfo.cchSPN = (uint)spnBuffer.Length;
return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
var array = ArrayPool<byte>.Shared.Rent(SniMaxComposedSpnLength);
array.AsSpan().Clear();

try
{
fixed (byte* pin_spnBuffer = array)
{
clientConsumerInfo.szSPN = pin_spnBuffer;
clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength;

var result = SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);

if (result == 0)
{
spn = Encoding.Unicode.GetString(array).TrimEnd('\0');
}

return result;
}
}
finally
{
ArrayPool<byte>.Shared.Return(array);
}
}

// We have a value of the SPN, so we marshal that and send it to the native layer
else
{
var writer = SqlObjectPools.BufferWriter.Rent();

// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
Encoding.Unicode.GetBytes(spn, writer);
Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");

try
{
fixed (byte* pin_spnBuffer = writer.WrittenSpan)
{
clientConsumerInfo.szSPN = pin_spnBuffer;
clientConsumerInfo.cchSPN = (uint)writer.WrittenCount;
return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
}
}
finally
{
SqlObjectPools.BufferWriter.Return(writer);
}
}
}
else
Expand Down Expand Up @@ -471,23 +519,37 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
}
}

internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, string serverUserName)
{
fixed (byte* pin_serverUserName = &serverUserName[0])
fixed (byte* pInBuff = inBuff)
sendLength = (uint)outBuff.Length;

var serverWriter = SqlObjectPools.BufferWriter.Rent();

try
{
Encoding.Unicode.GetBytes(serverUserName, serverWriter);

fixed (byte* pin_serverUserName = serverWriter.WrittenSpan)
fixed (byte* pInBuff = inBuff)
fixed (byte* pOutBuff = outBuff)
{
bool local_fDone;
return SNISecGenClientContextWrapper(
pConnectionObject,
pInBuff,
(uint)inBuff.Length,
pOutBuff,
ref sendLength,
out local_fDone,
pin_serverUserName,
(uint)serverWriter.WrittenCount,
null,
null);
}
}
finally
{
bool local_fDone;
return SNISecGenClientContextWrapper(
pConnectionObject,
pInBuff,
(uint)inBuff.Length,
OutBuff,
ref sendLength,
out local_fDone,
pin_serverUserName,
(uint)serverUserName.Length,
null,
null);
SqlObjectPools.BufferWriter.Return(serverWriter);
}
}

Expand Down
Expand Up @@ -503,6 +503,9 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ArrayBufferWriter.cs">
<Link>Microsoft\Data\SqlClient\ArrayBufferWriter.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlParameter.cs">
<Link>Microsoft\Data\SqlClient\SqlParameter.cs</Link>
</Compile>
Expand Down
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
Expand Down Expand Up @@ -32,12 +33,12 @@ internal class SNIProxy
/// </summary>
/// <param name="sspiClientContextStatus">SSPI client context status</param>
/// <param name="receivedBuff">Receive buffer</param>
/// <param name="sendBuff">Send buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <param name="sendWriter">Writer for send buffer</param>
/// <param name="serverNames">Service Principal Name</param>
/// <returns>SNI error code</returns>
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, ref byte[] sendBuff, byte[][] serverName)
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> sendWriter, string[] serverNames)
{
// TODO: this should use ReadOnlyMemory all the way through
// TODO: this should use ReadOnlySpan all the way through
byte[] array = null;

if (!receivedBuff.IsEmpty)
Expand All @@ -46,10 +47,10 @@ internal static void GenSspiClientContext(SspiClientContextStatus sspiClientCont
receivedBuff.CopyTo(array);
}

GenSspiClientContext(sspiClientContextStatus, array, ref sendBuff, serverName);
GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverNames);
}

private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter<byte> sendWriter, string[] serverNames)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand Down Expand Up @@ -81,15 +82,10 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte
| ContextFlagsPal.Delegate
| ContextFlagsPal.MutualAuth;

string[] serverSPNs = new string[serverName.Length];
for (int i = 0; i < serverName.Length; i++)
{
serverSPNs[i] = Encoding.Unicode.GetString(serverName[i]);
}
SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext(
credentialsHandle,
ref securityContext,
serverSPNs,
serverNames,
requestedContextFlags,
inSecurityBufferArray,
outSecurityBuffer,
Expand All @@ -103,10 +99,9 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte
outSecurityBuffer.token = null;
}

sendBuff = outSecurityBuffer.token;
if (sendBuff == null)
if (outSecurityBuffer.token is { } token)
{
sendBuff = Array.Empty<byte>();
sendWriter.Write(token);
}

sspiClientContextStatus.SecurityContext = securityContext;
Expand Down Expand Up @@ -165,7 +160,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
string fullServerName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spnBuffer,
string serverSPN,
bool flushCache,
bool async,
Expand Down Expand Up @@ -229,12 +224,12 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
return sniHandle;
}

private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
{
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
if (!string.IsNullOrWhiteSpace(serverSPN))
{
return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) };
return new[] { serverSPN };
}

string hostName = dataSource.ServerName;
Expand All @@ -252,7 +247,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol);
}

private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
{
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
IPHostEntry hostEntry = null;
Expand Down Expand Up @@ -283,12 +278,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
// Set both SPNs with and without Port as Port is optional for default instance
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) };
return new[] { serverSpn, serverSpnWithDefaultPort };
}
// else Named Pipes do not need to valid port

SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) };
return new[] { serverSpn };
}

/// <summary>
Expand Down
Expand Up @@ -134,7 +134,7 @@ internal static void Assert(string message)

private bool _is2022 = false;

private byte[][] _sniSpnBuffer = null;
private string[] _sniSpn = null;

// SqlStatistics
private SqlStatistics _statistics = null;
Expand Down Expand Up @@ -404,7 +404,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
}
else
{
_sniSpnBuffer = null;
_sniSpn = null;
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler._objectID,
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
}
Expand All @@ -416,7 +416,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
}

_sniSpnBuffer = null;
_sniSpn = null;
_authenticationProvider = null;

// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
Expand Down Expand Up @@ -455,7 +455,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _sniSpnBuffer,
ref _sniSpn,
false,
true,
fParallel,
Expand All @@ -468,8 +468,6 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand All @@ -484,6 +482,8 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
Debug.Fail("SNI returned status != success, but no error thrown?");
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);

_server = serverInfo.ResolvedServerName;

if (null != connHandler.PoolGroupProviderInfo)
Expand Down Expand Up @@ -554,7 +554,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
_physicalStateObj.CreatePhysicalSNIHandle(
serverInfo.ExtendedServerName,
timeout, out instanceName,
ref _sniSpnBuffer,
ref _sniSpn,
true,
true,
fParallel,
Expand All @@ -567,15 +567,15 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj)
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
ThrowExceptionAndWarning(_physicalStateObj);
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);

uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId);

Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
Expand Down Expand Up @@ -8120,8 +8120,7 @@ internal int WriteSQLDNSCachingFeatureRequest(bool write /* if false just calcul
int length,
int featureExOffset,
string clientInterfaceName,
byte[] outSSPIBuff,
uint outSSPILength)
ReadOnlySpan<byte> outSSPI)
{
try
{
Expand Down Expand Up @@ -8289,8 +8288,8 @@ internal int WriteSQLDNSCachingFeatureRequest(bool write /* if false just calcul
WriteShort(offset, _physicalStateObj); // ibSSPI offset
if (rec.useSSPI)
{
WriteShort((int)outSSPILength, _physicalStateObj);
offset += (int)outSSPILength;
WriteShort(outSSPI.Length, _physicalStateObj);
offset += outSSPI.Length;
}
else
{
Expand Down Expand Up @@ -8345,7 +8344,7 @@ internal int WriteSQLDNSCachingFeatureRequest(bool write /* if false just calcul

// send over SSPI data if we are using SSPI
if (rec.useSSPI)
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
_physicalStateObj.WriteByteSpan(outSSPI);

WriteString(rec.attachDBFilename, _physicalStateObj);
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))
Expand Down Expand Up @@ -12849,7 +12848,7 @@ internal string TraceString()
_fMARS ? bool.TrueString : bool.FalseString,
null == _sessionPool ? "(null)" : _sessionPool.TraceString(),
_is2005 ? bool.TrueString : bool.FalseString,
null == _sniSpnBuffer ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
null == _sniSpn ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Expand Up @@ -186,7 +186,7 @@ private void ResetCancelAndProcessAttention()
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spn,
bool flushCache,
bool async,
bool fParallel,
Expand Down