diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs index fcadbdc152..4c8193519b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; +using System.Diagnostics; using System.Runtime.InteropServices; using System.Text; using Microsoft.Data.Common; @@ -398,7 +400,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, - byte[] spnBuffer, + ref string spn, byte[] instanceName, bool fOverrideCache, bool fSync, @@ -436,13 +438,59 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port; - if (spnBuffer != null) + if (spn != null) { - fixed (byte* pin_spnBuffer = &spnBuffer[0]) + // An empty string implies we need to find the SPN so we supply a buffer for the max size + if (spn.Length == 0) { - clientConsumerInfo.szSPN = pin_spnBuffer; - clientConsumerInfo.cchSPN = (uint)spnBuffer.Length; - return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + var array = ArrayPool.Shared.Rent(SniMaxComposedSpnLength); + array.AsSpan().Clear(); + + try + { + fixed (byte* pin_spnBuffer = array) + { + clientConsumerInfo.szSPN = pin_spnBuffer; + clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength; + + var result = SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + + if (result == 0) + { + spn = Encoding.Unicode.GetString(array).TrimEnd('\0'); + } + + return result; + } + } + finally + { + ArrayPool.Shared.Return(array); + } + } + + // We have a value of the SPN, so we marshal that and send it to the native layer + else + { + var writer = SqlObjectPools.BufferWriter.Rent(); + + // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. + Encoding.Unicode.GetBytes(spn, writer); + Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size."); + + try + { + fixed (byte* pin_spnBuffer = writer.WrittenSpan) + { + clientConsumerInfo.szSPN = pin_spnBuffer; + clientConsumerInfo.cchSPN = (uint)writer.WrittenCount; + return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + } + } + finally + { + SqlObjectPools.BufferWriter.Return(writer); + } } } else @@ -471,26 +519,37 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int } } - internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, Span outBuff, out uint sendLength, byte[] serverUserName) + internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, Span outBuff, out uint sendLength, string serverUserName) { sendLength = (uint)outBuff.Length; - fixed (byte* pin_serverUserName = &serverUserName[0]) - fixed (byte* pInBuff = inBuff) - fixed (byte* pOutBuff = outBuff) + var serverWriter = SqlObjectPools.BufferWriter.Rent(); + + try + { + Encoding.Unicode.GetBytes(serverUserName, serverWriter); + + fixed (byte* pin_serverUserName = serverWriter.WrittenSpan) + fixed (byte* pInBuff = inBuff) + fixed (byte* pOutBuff = outBuff) + { + bool local_fDone; + return SNISecGenClientContextWrapper( + pConnectionObject, + pInBuff, + (uint)inBuff.Length, + pOutBuff, + ref sendLength, + out local_fDone, + pin_serverUserName, + (uint)serverWriter.WrittenCount, + null, + null); + } + } + finally { - bool local_fDone; - return SNISecGenClientContextWrapper( - pConnectionObject, - pInBuff, - (uint)inBuff.Length, - pOutBuff, - ref sendLength, - out local_fDone, - pin_serverUserName, - (uint)serverUserName.Length, - null, - null); + SqlObjectPools.BufferWriter.Return(serverWriter); } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 54a5179e26..8536b19eaa 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; +using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Net; @@ -33,9 +34,9 @@ internal class SNIProxy /// SSPI client context status /// Receive buffer /// Writer for send buffer - /// Service Principal Name buffer + /// Service Principal Name buffer /// SNI error code - internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory receivedBuff, IBufferWriter sendWriter, byte[][] serverName) + internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory receivedBuff, IBufferWriter sendWriter, string[] serverNames) { // TODO: this should use ReadOnlyMemory all the way through byte[] array = null; @@ -46,10 +47,10 @@ internal static void GenSspiClientContext(SspiClientContextStatus sspiClientCont receivedBuff.CopyTo(array); } - GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverName); + GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverNames); } - private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter sendWriter, byte[][] serverName) + private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter sendWriter, string[] serverSPNs) { SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext; ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags; @@ -81,11 +82,6 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte | ContextFlagsPal.Delegate | ContextFlagsPal.MutualAuth; - string[] serverSPNs = new string[serverName.Length]; - for (int i = 0; i < serverName.Length; i++) - { - serverSPNs[i] = Encoding.Unicode.GetString(serverName[i]); - } SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext( credentialsHandle, ref securityContext, @@ -164,7 +160,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) string fullServerName, TimeoutTimer timeout, out byte[] instanceName, - ref byte[][] spnBuffer, + ref string[] spnBuffer, string serverSPN, bool flushCache, bool async, @@ -228,12 +224,12 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) return sniHandle; } - private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN) + private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN) { Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName)); if (!string.IsNullOrWhiteSpace(serverSPN)) { - return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) }; + return new[] { serverSPN }; } string hostName = dataSource.ServerName; @@ -251,7 +247,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol); } - private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol) + private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol) { Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress)); IPHostEntry hostEntry = null; @@ -282,12 +278,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}"; // Set both SPNs with and without Port as Port is optional for default instance SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort); - return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) }; + return new[] { serverSpn, serverSpnWithDefaultPort }; } // else Named Pipes do not need to valid port SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn); - return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) }; + return new[] { serverSpn }; } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index b5fe52fff4..9632106ef3 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -134,7 +134,7 @@ internal static void Assert(string message) private bool _is2022 = false; - private byte[][] _sniSpnBuffer = null; + private string[] _sniSpn = null; // SqlStatistics private SqlStatistics _statistics = null; @@ -404,7 +404,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) } else { - _sniSpnBuffer = null; + _sniSpn = null; SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler._objectID, authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString()); } @@ -416,7 +416,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) SqlClientEventSource.Log.TryTraceEvent(" Encryption will be disabled as target server is a SQL Local DB instance."); } - _sniSpnBuffer = null; + _sniSpn = null; _authenticationProvider = null; // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server @@ -455,7 +455,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) serverInfo.ExtendedServerName, timeout, out instanceName, - ref _sniSpnBuffer, + ref _sniSpn, false, true, fParallel, @@ -468,7 +468,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) hostNameInCertificate, serverCertificateFilename); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn); if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { @@ -554,7 +554,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, timeout, out instanceName, - ref _sniSpnBuffer, + ref _sniSpn, true, true, fParallel, @@ -567,8 +567,6 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) hostNameInCertificate, serverCertificateFilename); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -576,6 +574,8 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) ThrowExceptionAndWarning(_physicalStateObj); } + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn); + uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId); Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId"); @@ -12850,7 +12850,7 @@ internal string TraceString() _fMARS ? bool.TrueString : bool.FalseString, null == _sessionPool ? "(null)" : _sessionPool.TraceString(), _is2005 ? bool.TrueString : bool.FalseString, - null == _sniSpnBuffer ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null), + null == _sniSpn ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null), diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index 520367963d..021e982f30 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -186,7 +186,7 @@ private void ResetCancelAndProcessAttention() string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref byte[][] spnBuffer, + ref string[] spn, bool flushCache, bool async, bool fParallel, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 7d879ae4d5..7c8c0c5856 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -81,7 +81,7 @@ protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref byte[][] spnBuffer, + ref string[] spn, bool flushCache, bool async, bool parallel, @@ -94,7 +94,7 @@ protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref string hostNameInCertificate, string serverCertificateFilename) { - SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN, + SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spn, serverSPN, flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index 6f26250072..c078fd2992 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -143,7 +143,7 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref byte[][] spnBuffer, + ref string[] spn, bool flushCache, bool async, bool fParallel, @@ -156,22 +156,18 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) string hostNameInCertificate, string serverCertificateFilename) { - // We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer - spnBuffer = new byte[1][]; if (isIntegratedSecurity) { // now allocate proper length of buffer if (!string.IsNullOrEmpty(serverSPN)) { // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. - byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN); - Trace.Assert(srvSPN.Length <= SNINativeMethodWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size."); - spnBuffer[0] = srvSPN; SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN); } else { - spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength]; + // This will signal to the interop layer that we need to retrieve the SPN + serverSPN = string.Empty; } } @@ -179,8 +175,9 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) SQLDNSInfo cachedDNSInfo; bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo); - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName, + _sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName, flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); + spn = new[] { serverSPN.TrimEnd() }; } protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs index ba6c98aa82..6835dee8ff 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -1109,7 +1110,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, - byte[] spnBuffer, + ref string spn, byte[] instanceName, bool fOverrideCache, bool fSync, @@ -1161,13 +1162,59 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port; - if (spnBuffer != null) + if (spn != null) { - fixed (byte* pin_spnBuffer = &spnBuffer[0]) + // An empty string implies we need to find the SPN so we supply a buffer for the max size + if (spn.Length == 0) { - clientConsumerInfo.szSPN = pin_spnBuffer; - clientConsumerInfo.cchSPN = (uint)spnBuffer.Length; - return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + var array = ArrayPool.Shared.Rent(SniMaxComposedSpnLength); + array.AsSpan().Clear(); + + try + { + fixed (byte* pin_spnBuffer = array) + { + clientConsumerInfo.szSPN = pin_spnBuffer; + clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength; + + var result = SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + + if (result == 0) + { + spn = Encoding.Unicode.GetString(array).TrimEnd('\0'); + } + + return result; + } + } + finally + { + ArrayPool.Shared.Return(array); + } + } + + // We have a value of the SPN, so we marshal that and send it to the native layer + else + { + var writer = SqlObjectPools.BufferWriter.Rent(); + + // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. + Encoding.Unicode.GetBytes(spn, writer); + Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size."); + + try + { + fixed (byte* pin_spnBuffer = writer.WrittenSpan) + { + clientConsumerInfo.szSPN = pin_spnBuffer; + clientConsumerInfo.cchSPN = (uint)writer.WrittenCount; + return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + } + } + finally + { + SqlObjectPools.BufferWriter.Return(writer); + } } } else @@ -1381,23 +1428,34 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int } } - internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, Span outBuff, out uint sendLength, byte[] serverUserName) + internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, Span outBuff, out uint sendLength, string serverUserName) { sendLength = (uint)outBuff.Length; - fixed (byte* pin_serverUserName = &serverUserName[0]) + var serverNameWriter = SqlObjectPools.BufferWriter.Rent(); + + try + { + Encoding.Unicode.GetBytes(serverUserName, serverNameWriter); + + fixed (byte* pin_serverUserName = serverNameWriter.WrittenSpan) + { + bool local_fDone; + return SNISecGenClientContextWrapper( + pConnectionObject, + inBuff, + outBuff, + ref sendLength, + out local_fDone, + pin_serverUserName, + (uint)serverNameWriter.WrittenCount, + null, + null); + } + } + finally { - bool local_fDone; - return SNISecGenClientContextWrapper( - pConnectionObject, - inBuff, - outBuff, - ref sendLength, - out local_fDone, - pin_serverUserName, - (uint)serverUserName.Length, - null, - null); + SqlObjectPools.BufferWriter.Return(serverNameWriter); } } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 78605f14f5..4b550cf258 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -221,7 +221,7 @@ internal static void Assert(string message) private bool _is2022 = false; - private byte[] _sniSpnBuffer = null; + private string _sniSpn = null; // UNDONE - need to have some for both instances - both command and default??? @@ -543,27 +543,23 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { - _authenticationProvider = _physicalStateObj.CreateSSPIContextProvider(); - if (!string.IsNullOrEmpty(serverInfo.ServerSPN)) { - // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. - byte[] srvSPN = Encoding.Unicode.GetBytes(serverInfo.ServerSPN); - Trace.Assert(srvSPN.Length <= SNINativeMethodWrapper.SniMaxComposedSpnLength, "The provided SPN length exceeded the buffer size."); - _sniSpnBuffer = srvSPN; + _sniSpn = serverInfo.ServerSPN; SqlClientEventSource.Log.TryTraceEvent(" Server SPN `{0}` from the connection string is used.", serverInfo.ServerSPN); } else { - // now allocate proper length of buffer - _sniSpnBuffer = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength]; + _sniSpn = string.Empty; } + + _authenticationProvider = _physicalStateObj.CreateSSPIContextProvider(); SqlClientEventSource.Log.TryTraceEvent(" SSPI or Active Directory Authentication Library for SQL Server based integrated authentication"); } else { _authenticationProvider = null; - _sniSpnBuffer = null; + _sniSpn = null; switch (authType) { @@ -642,7 +638,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) serverInfo.ExtendedServerName, timeout, out instanceName, - _sniSpnBuffer, + ref _sniSpn, false, true, fParallel, @@ -652,7 +648,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) FQDNforDNSCache, hostNameInCertificate); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn); if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { @@ -749,7 +745,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) serverInfo.ExtendedServerName, timeout, out instanceName, - _sniSpnBuffer, + ref _sniSpn, true, true, fParallel, @@ -759,8 +755,6 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) serverInfo.ResolvedServerName, hostNameInCertificate); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -769,6 +763,8 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) ThrowExceptionAndWarning(_physicalStateObj); } + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn); + UInt32 retCode = SNINativeMethodWrapper.SniGetConnectionId(_physicalStateObj.Handle, ref _connHandler._clientConnectionId); Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId"); SqlClientEventSource.Log.TryTraceEvent(" Sending prelogin handshake"); @@ -13706,7 +13702,7 @@ internal string TraceString() _is2000 ? bool.TrueString : bool.FalseString, _is2000SP1 ? bool.TrueString : bool.FalseString, _is2005 ? bool.TrueString : bool.FalseString, - null == _sniSpnBuffer ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null), + null == _sniSpn ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null), diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs index 8c8e2deea9..297a3f03b7 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs @@ -256,7 +256,7 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) string serverName, TimeoutTimer timeout, out byte[] instanceName, - byte[] spnBuffer, + ref string spn, bool flushCache, bool async, bool fParallel, @@ -273,7 +273,7 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) _ = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out SQLDNSInfo cachedDNSInfo); - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, timeout.MillisecondsRemainingInt, + _sessionHandle = new SNIHandle(myInfo, serverName, ref spn, timeout.MillisecondsRemainingInt, out instanceName, flushCache, !async, fParallel, transparentNetworkResolutionState, totalTimeout, ipPreference, cachedDNSInfo, hostNameInCertificate); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/ManagedSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/ManagedSSPIContextProvider.cs index 1012e666bf..5d65483f8f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/ManagedSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/ManagedSSPIContextProvider.cs @@ -12,11 +12,11 @@ internal sealed class ManagedSSPIContextProvider : SSPIContextProvider { private SspiClientContextStatus? _sspiClientContextStatus; - protected override void GenerateSspiClientContext(ReadOnlyMemory incomingBlob, IBufferWriter outgoingBlobWriter, byte[][] _sniSpnBuffer) + protected override void GenerateSspiClientContext(ReadOnlyMemory incomingBlob, IBufferWriter outgoingBlobWriter) { _sspiClientContextStatus ??= new SspiClientContextStatus(); - SNIProxy.GenSspiClientContext(_sspiClientContextStatus, incomingBlob, outgoingBlobWriter, _sniSpnBuffer); + SNIProxy.GenSspiClientContext(_sspiClientContextStatus, incomingBlob, outgoingBlobWriter, new[] { AuthenticationParameters.ServerName }); SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}", _physicalStateObj.SessionId); } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs index 82ddb4ff16..3860a89987 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs @@ -49,7 +49,7 @@ private void LoadSSPILibrary() } } - protected override void GenerateSspiClientContext(ReadOnlyMemory incomingBlob, IBufferWriter outgoingBlobWriter, byte[][] _sniSpnBuffer) + protected override void GenerateSspiClientContext(ReadOnlyMemory incomingBlob, IBufferWriter outgoingBlobWriter) { #if NETFRAMEWORK SNIHandle handle = _physicalStateObj.Handle; @@ -60,7 +60,7 @@ protected override void GenerateSspiClientContext(ReadOnlyMemory incomingB var outBuff = outgoingBlobWriter.GetSpan((int)s_maxSSPILength); - if (0 != SNINativeMethodWrapper.SNISecGenClientContext(handle, incomingBlob.Span, outBuff, out var sendLength, _sniSpnBuffer[0])) + if (0 != SNINativeMethodWrapper.SNISecGenClientContext(handle, incomingBlob.Span, outBuff, out var sendLength, AuthenticationParameters.ServerName)) { throw new InvalidOperationException(SQLMessage.SSPIGenerateError()); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs index 4554a5ffa6..b8ef33c574 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs @@ -1,9 +1,9 @@ #if NET7_0_OR_GREATER using System; -using System.Text; using System.Net.Security; using System.Buffers; +using System.Text; #nullable enable @@ -13,9 +13,9 @@ internal sealed class NegotiateSSPIContextProvider : SSPIContextProvider { private NegotiateAuthentication? _negotiateAuth = null; - protected override void GenerateSspiClientContext(ReadOnlyMemory incomingBlob, IBufferWriter outgoingBlobWriter, byte[][] _sniSpnBuffer) + protected override void GenerateSspiClientContext(ReadOnlyMemory incomingBlob, IBufferWriter outgoingBlobWriter) { - _negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = Encoding.Unicode.GetString(_sniSpnBuffer[0]) }); + _negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = AuthenticationParameters.ServerName }); var result = _negotiateAuth.GetOutgoingBlob(incomingBlob.Span, out NegotiateAuthenticationStatusCode statusCode)!; SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}, StatusCode={1}", _physicalStateObj.SessionId, statusCode); if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs index 48e16cb718..6f8483c2f7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs @@ -12,13 +12,34 @@ internal abstract class SSPIContextProvider private TdsParser _parser = null!; private ServerInfo _serverInfo = null!; private protected TdsParserStateObject _physicalStateObj = null!; + private SqlAuthenticationParameters? _parameters; - internal void Initialize(ServerInfo serverInfo, TdsParserStateObject physicalStateObj, TdsParser parser) + internal void Initialize(ServerInfo serverInfo, TdsParserStateObject physicalStateObj, TdsParser parser, string[] serverNames) + { + Debug.Assert(serverNames.Length > 0); + + Initialize(serverInfo, physicalStateObj, parser, serverNames[0]); + } + + internal void Initialize(ServerInfo serverInfo, TdsParserStateObject physicalStateObj, TdsParser parser, string serverName) { _parser = parser; _physicalStateObj = physicalStateObj; _serverInfo = serverInfo; + var options = parser.Connection.ConnectionOptions; + + _parameters = new SqlAuthenticationParameters.Builder( + authenticationMethod: parser.Connection.ConnectionOptions.Authentication, + resource: serverName, + authority: null, + serverName: options.DataSource, + databaseName: options.InitialCatalog) + .WithConnectionId(parser.Connection.ClientConnectionId) + .WithConnectionTimeout(options.ConnectTimeout) + .WithUserId(options.UserID) + .WithPassword(options.Password); + Initialize(); } @@ -26,16 +47,18 @@ private protected virtual void Initialize() { } - protected abstract void GenerateSspiClientContext(ReadOnlyMemory incomingBlob, IBufferWriter outgoingBlobWriter, byte[][] _sniSpnBuffer); + /// + /// Gets the authentication parameters for the current connection. + /// + protected SqlAuthenticationParameters AuthenticationParameters => _parameters ?? throw new InvalidOperationException("SSPI context provider has not been initialized"); - internal void SSPIData(ReadOnlyMemory receivedBuff, IBufferWriter outgoingBlobWriter, byte[] sniSpnBuffer) - => SSPIData(receivedBuff, outgoingBlobWriter, new[] { sniSpnBuffer }); + protected abstract void GenerateSspiClientContext(ReadOnlyMemory incomingBlob, IBufferWriter outgoingBlobWriter); - internal void SSPIData(ReadOnlyMemory receivedBuff, IBufferWriter outgoingBlobWriter, byte[][] sniSpnBuffer) + internal void SSPIData(ReadOnlyMemory receivedBuff, IBufferWriter outgoingBlobWriter) { try { - GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, sniSpnBuffer); + GenerateSspiClientContext(receivedBuff, outgoingBlobWriter); } catch (Exception e) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs index d5cf2398ec..4042934222 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs @@ -3,11 +3,54 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Diagnostics; +using System.Text; using System.Threading; namespace Microsoft.Data.SqlClient { + // This is a collection of general object pools that can be reused as needed. + internal static class SqlObjectPools + { + private static SqlObjectPool> _bufferWriter; + + internal static SqlObjectPool> BufferWriter + { + get + { + if (_bufferWriter is null) + { + Interlocked.CompareExchange(ref _bufferWriter, new(20, () => new(), a => a.Clear()), null); + } + + return _bufferWriter; + } + } + } + +#if NETSTANDARD || NETFRAMEWORK + internal static class BufferWriterExtensions + { + internal static long GetBytes(this Encoding encoding, string str, IBufferWriter bufferWriter) + { + var count = encoding.GetByteCount(str); + var array = ArrayPool.Shared.Rent(count); + + try + { + encoding.GetBytes(str, 0, str.Length, array, 0); + bufferWriter.Write(array); + return count; + } + finally + { + ArrayPool.Shared.Return(array); + } + } + } +#endif + // this is a very simple threadsafe pool derived from the aspnet/extensions default pool implementation // https://github.com/dotnet/extensions/blob/release/3.1/src/ObjectPool/src/DefaultObjectPool.cs internal sealed class SqlObjectPool where T : class diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index f6e34e1462..67766e06cd 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -8,8 +8,6 @@ namespace Microsoft.Data.SqlClient { internal partial class TdsParser { - private static readonly SqlObjectPool> _writers = new(20, () => new(), a => a.Clear()); - internal void ProcessSSPI(int receivedLength) { Debug.Assert(_authenticationProvider is not null); @@ -28,15 +26,15 @@ internal void ProcessSSPI(int receivedLength) } // allocate send buffer and initialize length - var writer = _writers.Rent(); + var writer = SqlObjectPools.BufferWriter.Rent(); // make call for SSPI data - _authenticationProvider!.SSPIData(receivedBuff.AsMemory(0, receivedLength), writer, _sniSpnBuffer); + _authenticationProvider!.SSPIData(receivedBuff.AsMemory(0, receivedLength), writer); // DO NOT SEND LENGTH - TDS DOC INCORRECT! JUST SEND SSPI DATA! _physicalStateObj.WriteByteSpan(writer.WrittenSpan); - _writers.Return(writer); + SqlObjectPools.BufferWriter.Return(writer); ArrayPool.Shared.Return(receivedBuff, clearArray: true); // set message type so server knows its a SSPI response @@ -182,14 +180,14 @@ internal void ProcessSSPI(int receivedLength) { if (rec.useSSPI) { - sspiWriter = _writers.Rent(); + sspiWriter = SqlObjectPools.BufferWriter.Rent(); // Call helper function for SSPI data and actual length. // Since we don't have SSPI data from the server, send null for the // byte[] buffer and 0 for the int length. Debug.Assert(SniContext.Snix_Login == _physicalStateObj.SniContext, $"Unexpected SniContext. Expecting Snix_Login, actual value is '{_physicalStateObj.SniContext}'"); _physicalStateObj.SniContext = SniContext.Snix_LoginSspi; - _authenticationProvider.SSPIData(ReadOnlyMemory.Empty, sspiWriter, _sniSpnBuffer); + _authenticationProvider.SSPIData(ReadOnlyMemory.Empty, sspiWriter); _physicalStateObj.SniContext = SniContext.Snix_Login; @@ -222,7 +220,7 @@ internal void ProcessSSPI(int receivedLength) if (sspiWriter is { }) { - _writers.Return(sspiWriter); + SqlObjectPools.BufferWriter.Return(sspiWriter); } _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs index 8d276cd285..902cd0cf25 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs @@ -149,7 +149,7 @@ internal sealed class SNIHandle : SafeHandle internal SNIHandle( SNINativeMethodWrapper.ConsumerInfo myInfo, string serverName, - byte[] spnBuffer, + ref string spn, int timeout, out byte[] instanceName, bool flushCache, @@ -185,11 +185,11 @@ internal sealed class SNIHandle : SafeHandle #if NETFRAMEWORK int transparentNetworkResolutionStateNo = (int)transparentNetworkResolutionState; _status = SNINativeMethodWrapper.SNIOpenSyncEx(myInfo, serverName, ref base.handle, - spnBuffer, instanceName, flushCache, fSync, timeout, fParallel, transparentNetworkResolutionStateNo, totalTimeout, + ref spn, instanceName, flushCache, fSync, timeout, fParallel, transparentNetworkResolutionStateNo, totalTimeout, ADP.IsAzureSqlServerEndpoint(serverName), ipPreference, cachedDNSInfo, hostNameInCertificate); #else _status = SNINativeMethodWrapper.SNIOpenSyncEx(myInfo, serverName, ref base.handle, - spnBuffer, instanceName, flushCache, fSync, timeout, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); + ref spn, instanceName, flushCache, fSync, timeout, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); #endif // NETFRAMEWORK } }