diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs index c8591a8c11..e680f2853a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs @@ -387,9 +387,9 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info(); native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; - native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; - native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.CachedIPv4Address?.ToString(); + native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.CachedIPv6Address?.ToString(); + native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo?.Port.ToString(); return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo); } @@ -399,7 +399,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan string constring, ref IntPtr pConn, byte[] spnBuffer, - byte[] instanceName, + Span instanceName, bool fOverrideCache, bool fSync, int timeout, @@ -409,7 +409,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan string hostNameInCertificate) { - fixed (byte* pin_instanceName = &instanceName[0]) + fixed (byte* pin_instanceName = instanceName) { SNI_CLIENT_CONSUMER_INFO clientConsumerInfo = new SNI_CLIENT_CONSUMER_INFO(); @@ -432,9 +432,9 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan clientConsumerInfo.ipAddressPreference = ipPreference; clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.CachedIPv4Address?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.CachedIPv6Address?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port == null ? null : cachedDNSInfo.Port.ToString(); if (spnBuffer != null) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index 964d332ae4..d8ab44f5ce 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -13,6 +13,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.ProviderBase; +using System.Net.Sockets; namespace Microsoft.Data.SqlClient.SNI { @@ -45,55 +46,50 @@ internal enum SNIProviders /// /// SMUX packet header /// - internal sealed class SNISMUXHeader + internal struct SNISMUXHeader { public const int HEADER_LENGTH = 16; - public byte SMID; - public byte flags; - public ushort sessionId; - public uint length; - public uint sequenceNumber; - public uint highwater; + public byte Flags; + public ushort SessionId; + public uint Length; + public uint SequenceNumber; + public uint Highwater; - public void Read(byte[] bytes) + public SNISMUXHeader(byte flags, ushort sessionId, uint length, uint sequenceNumber, uint highwater) { - SMID = bytes[0]; - flags = bytes[1]; - Span span = bytes.AsSpan(); - sessionId = BinaryPrimitives.ReadUInt16LittleEndian(span.Slice(2)); - length = BinaryPrimitives.ReadUInt32LittleEndian(span.Slice(4)) - SNISMUXHeader.HEADER_LENGTH; - sequenceNumber = BinaryPrimitives.ReadUInt32LittleEndian(span.Slice(8)); - highwater = BinaryPrimitives.ReadUInt32LittleEndian(span.Slice(12)); + Flags = flags; + SessionId = sessionId; + Length = length; + SequenceNumber = sequenceNumber; + Highwater = highwater; } - public void Write(Span bytes) + public void Read(Span bytes) + { + // As per the MC-SMP spec, the first byte of the header will always be 0x53 + Debug.Assert(bytes[0] == 0x53, "First byte of the SNI SMUX header was not 0x53"); + + Flags = bytes[1]; + SessionId = BinaryPrimitives.ReadUInt16LittleEndian(bytes.Slice(2)); + Length = BinaryPrimitives.ReadUInt32LittleEndian(bytes.Slice(4)) - SNISMUXHeader.HEADER_LENGTH; + SequenceNumber = BinaryPrimitives.ReadUInt32LittleEndian(bytes.Slice(8)); + Highwater = BinaryPrimitives.ReadUInt32LittleEndian(bytes.Slice(12)); + } + + public readonly void Write(Span bytes) { - uint value = highwater; // access the highest element first to cause the largest range check in the jit, then fill in the rest of the value and carry on as normal - bytes[15] = (byte)((value >> 24) & 0xff); - bytes[12] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.highwater).CopyTo(headerBytes, 12); - bytes[13] = (byte)((value >> 8) & 0xff); - bytes[14] = (byte)((value >> 16) & 0xff); - - bytes[0] = SMID; // BitConverter.GetBytes(_currentHeader.SMID).CopyTo(headerBytes, 0); - bytes[1] = flags; // BitConverter.GetBytes(_currentHeader.flags).CopyTo(headerBytes, 1); - - value = sessionId; - bytes[2] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.sessionId).CopyTo(headerBytes, 2); - bytes[3] = (byte)((value >> 8) & 0xff); - - value = length; - bytes[4] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.length).CopyTo(headerBytes, 4); - bytes[5] = (byte)((value >> 8) & 0xff); - bytes[6] = (byte)((value >> 16) & 0xff); - bytes[7] = (byte)((value >> 24) & 0xff); - - value = sequenceNumber; - bytes[8] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.sequenceNumber).CopyTo(headerBytes, 8); - bytes[9] = (byte)((value >> 8) & 0xff); - bytes[10] = (byte)((value >> 16) & 0xff); - bytes[11] = (byte)((value >> 24) & 0xff); + BinaryPrimitives.WriteUInt32LittleEndian(bytes.Slice(12), Highwater); + + bytes[0] = 0x53; // BitConverter.GetBytes(_currentHeader.SMID).CopyTo(headerBytes, 0); + bytes[1] = Flags; // BitConverter.GetBytes(_currentHeader.flags).CopyTo(headerBytes, 1); + + BinaryPrimitives.WriteUInt16LittleEndian(bytes.Slice(2), SessionId); + + BinaryPrimitives.WriteUInt32LittleEndian(bytes.Slice(4), Length); + + BinaryPrimitives.WriteUInt32LittleEndian(bytes.Slice(8), SequenceNumber); } } @@ -101,8 +97,7 @@ public void Write(Span bytes) /// /// SMUX packet flags /// - [Flags] - internal enum SNISMUXFlags + internal enum SNISMUXFlags : byte { SMUX_SYN = 1, // Begin SMUX connection SMUX_ACK = 2, // Acknowledge SMUX packets @@ -332,7 +327,15 @@ internal static bool ValidateSslServerCertificate(X509Certificate clientCert, X5 } } - internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer timeout) + /// + /// Returns array of IP addresses for the given server name, sorted according to the given preference. + /// + /// Thrown when ipPreference is not supported +#if NET6_0_OR_GREATER + internal static async ValueTask GetDnsIpAddresses(string serverName, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, bool async) +#else + internal static ValueTask GetDnsIpAddresses(string serverName, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, bool async) +#endif { using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) { @@ -342,20 +345,92 @@ internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer ti args0: serverName, args1: remainingTimeout); using CancellationTokenSource cts = new CancellationTokenSource(remainingTimeout); + +#if NET6_0_OR_GREATER + Task task = Dns.GetHostAddressesAsync(serverName, cts.Token); + + if (async) + { + return SortIpAddressesByPreference(await task.ConfigureAwait(false), ipPreference); + } + else + { + task.Wait(); + return SortIpAddressesByPreference(task.Result, ipPreference); + } +#else // using this overload to support netstandard Task task = Dns.GetHostAddressesAsync(serverName); - task.ConfigureAwait(false); + task.Wait(cts.Token); - return task.Result; + return new ValueTask(SortIpAddressesByPreference(task.Result, ipPreference)); +#endif } } - internal static IPAddress[] GetDnsIpAddresses(string serverName) + /// + /// Returns array of IP addresses for the given server name, sorted according to the given preference. + /// + /// Thrown when ipPreference is not supported + internal static async ValueTask GetDnsIpAddresses(string serverName, SqlConnectionIPAddressPreference ipPreference, bool async) { using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0}.", args0: serverName); - return Dns.GetHostAddresses(serverName); + + return SortIpAddressesByPreference(async + ? await Dns.GetHostAddressesAsync(serverName) + : Dns.GetHostAddresses(serverName), + ipPreference); + } + } + + private static IPAddress[] SortIpAddressesByPreference(IPAddress[] dnsIPAddresses, SqlConnectionIPAddressPreference ipPreference) + { + AddressFamily? prioritiesFamily = ipPreference switch + { + SqlConnectionIPAddressPreference.IPv4First => AddressFamily.InterNetwork, + SqlConnectionIPAddressPreference.IPv6First => AddressFamily.InterNetworkV6, + SqlConnectionIPAddressPreference.UsePlatformDefault => null, + _ => throw ADP.NotSupportedEnumerationValue(typeof(SqlConnectionIPAddressPreference), ipPreference.ToString(), nameof(SortIpAddressesByPreference)) + }; + + if (prioritiesFamily == null) + { + return dnsIPAddresses; + } + else + { + int resultArrayIndex = 0; + IPAddress[] ipAddresses = new IPAddress[dnsIPAddresses.Length]; + + // Return addresses of the preferred family first + for (int i = 0; i < dnsIPAddresses.Length; i++) + { + if (dnsIPAddresses[i].AddressFamily == prioritiesFamily) + { + ipAddresses[resultArrayIndex++] = dnsIPAddresses[i]; + } + } + + // Return addresses of the other family + for (int i = 0; i < dnsIPAddresses.Length; i++) + { + if (dnsIPAddresses[i].AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6 + && dnsIPAddresses[i].AddressFamily != prioritiesFamily) + { + ipAddresses[resultArrayIndex++] = dnsIPAddresses[i]; + } + } + + // If the DNS resolution returned records of types other than A and AAAA, the original array size will be + // too large, and must thus be resized. This is very unlikely, so we only try to do this post-hoc. + if (resultArrayIndex + 1 < ipAddresses.Length) + { + Array.Resize(ref ipAddresses, resultArrayIndex + 1); + } + + return ipAddresses; } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs index f929a1ba32..4e30c0d687 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -17,7 +17,7 @@ internal class SNIMarsConnection private readonly Guid _connectionId; private readonly Dictionary _sessions; private readonly byte[] _headerBytes; - private readonly SNISMUXHeader _currentHeader; + private SNISMUXHeader _currentHeader; private readonly object _sync; private SNIHandle _lowerHandle; private ushort _nextSessionId; @@ -53,7 +53,7 @@ public SNIMarsConnection(SNIHandle lowerHandle) _lowerHandle.SetAsyncCallbacks(HandleReceiveComplete, HandleSendComplete); } - public SNIMarsHandle CreateMarsSession(object callbackObject, bool async) + public SNIMarsHandle CreateMarsSession(TdsParserStateObject callbackObject, bool async) { lock (DemuxerSync) { @@ -204,7 +204,6 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) { using (TrySNIEventScope.Create(nameof(SNIMarsConnection))) { - SNISMUXHeader currentHeader = null; SNIPacket currentPacket = null; SNIMarsHandle currentSession = null; @@ -224,7 +223,6 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) { if (_currentHeaderByteCount != SNISMUXHeader.HEADER_LENGTH) { - currentHeader = null; currentPacket = null; currentSession = null; @@ -250,17 +248,16 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) } _currentHeader.Read(_headerBytes); - _dataBytesLeft = (int)_currentHeader.length; - _currentPacket = _lowerHandle.RentPacket(headerSize: 0, dataSize: (int)_currentHeader.length); + _dataBytesLeft = (int)_currentHeader.Length; + _currentPacket = _lowerHandle.RentPacket(headerSize: 0, dataSize: (int)_currentHeader.Length); #if DEBUG - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, _dataBytesLeft {1}, _currentPacket {2}, Reading data of length: _currentHeader.length {3}", args0: _lowerHandle?.ConnectionId, args1: _dataBytesLeft, args2: currentPacket?._id, args3: _currentHeader?.length); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, _dataBytesLeft {1}, _currentPacket {2}, Reading data of length: _currentHeader.length {3}", args0: _lowerHandle?.ConnectionId, args1: _dataBytesLeft, args2: currentPacket?._id, args3: currentHeader.Length); #endif } - currentHeader = _currentHeader; currentPacket = _currentPacket; - if (_currentHeader.flags == (byte)SNISMUXFlags.SMUX_DATA) + if (_currentHeader.Flags == (byte)SNISMUXFlags.SMUX_DATA) { if (_dataBytesLeft > 0) { @@ -286,44 +283,44 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) _currentHeaderByteCount = 0; - if (!_sessions.ContainsKey(_currentHeader.sessionId)) + if (!_sessions.ContainsKey(_currentHeader.SessionId)) { SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.SMUX_PROV, 0, SNICommon.InvalidParameterError, Strings.SNI_ERROR_5); HandleReceiveError(packet); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "Current Header Session Id {0} not found, MARS Session Id {1} will be destroyed, New SNI error created: {2}", args0: _currentHeader?.sessionId, args1: _lowerHandle?.ConnectionId, args2: sniErrorCode); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "Current Header Session Id {0} not found, MARS Session Id {1} will be destroyed, New SNI error created: {2}", args0: _currentHeader.SessionId, args1: _lowerHandle?.ConnectionId, args2: sniErrorCode); _lowerHandle.Dispose(); _lowerHandle = null; return; } - if (_currentHeader.flags == (byte)SNISMUXFlags.SMUX_FIN) + if (_currentHeader.Flags == (byte)SNISMUXFlags.SMUX_FIN) { - _sessions.Remove(_currentHeader.sessionId); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_FIN | MARS Session Id {0}, SMUX_FIN flag received, Current Header Session Id {1} removed", args0: _lowerHandle?.ConnectionId, args1: _currentHeader?.sessionId); + _sessions.Remove(_currentHeader.SessionId); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_FIN | MARS Session Id {0}, SMUX_FIN flag received, Current Header Session Id {1} removed", args0: _lowerHandle?.ConnectionId, args1: _currentHeader.SessionId); } else { - currentSession = _sessions[_currentHeader.sessionId]; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, Current Session assigned to Session Id {1}", args0: _lowerHandle?.ConnectionId, args1: _currentHeader?.sessionId); + currentSession = _sessions[_currentHeader.SessionId]; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, Current Session assigned to Session Id {1}", args0: _lowerHandle?.ConnectionId, args1: _currentHeader.SessionId); } } - if (currentHeader.flags == (byte)SNISMUXFlags.SMUX_DATA) + if (_currentHeader.Flags == (byte)SNISMUXFlags.SMUX_DATA) { - currentSession.HandleReceiveComplete(currentPacket, currentHeader); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_DATA | MARS Session Id {0}, Current Session {1} completed receiving Data", args0: _lowerHandle?.ConnectionId, args1: _currentHeader?.sessionId); + currentSession.HandleReceiveComplete(currentPacket, in _currentHeader); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_DATA | MARS Session Id {0}, Current Session {1} completed receiving Data", args0: _lowerHandle?.ConnectionId, args1: _currentHeader.SessionId); } - if (_currentHeader.flags == (byte)SNISMUXFlags.SMUX_ACK) + if (_currentHeader.Flags == (byte)SNISMUXFlags.SMUX_ACK) { try { - currentSession.HandleAck(currentHeader.highwater); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_ACK | MARS Session Id {0}, Current Session {1} handled ack", args0: _lowerHandle?.ConnectionId, args1: _currentHeader?.sessionId); + currentSession.HandleAck(_currentHeader.Highwater); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_ACK | MARS Session Id {0}, Current Session {1} handled ack", args0: _lowerHandle?.ConnectionId, args1: _currentHeader.SessionId); } catch (Exception e) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "SMUX_ACK | MARS Session Id {0}, Exception occurred: {2}", args0: _currentHeader?.sessionId, args1: e?.Message); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "SMUX_ACK | MARS Session Id {0}, Exception occurred: {2}", args0: _currentHeader.SessionId, args1: e?.Message); SNICommon.ReportSNIError(SNIProviders.SMUX_PROV, SNICommon.InternalExceptionError, e); } #if DEBUG diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs index 8246ce3d6f..5997af6c0d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs @@ -20,12 +20,11 @@ internal sealed class SNIMarsHandle : SNIHandle private readonly uint _status = TdsEnums.SNI_UNINITIALIZED; private readonly Queue _receivedPacketQueue = new Queue(); private readonly Queue _sendPacketQueue = new Queue(); - private readonly object _callbackObject; + private readonly TdsParserStateObject _callbackObject; private readonly Guid _connectionId; private readonly ushort _sessionId; private readonly ManualResetEventSlim _packetEvent = new ManualResetEventSlim(false); private readonly ManualResetEventSlim _ackEvent = new ManualResetEventSlim(false); - private readonly SNISMUXHeader _currentHeader = new SNISMUXHeader(); private readonly SNIAsyncCallback _handleSendCompleteCallback; private uint _sendHighwater = 4; @@ -76,7 +75,7 @@ public override void Dispose() /// MARS session ID /// Callback object /// true if connection is asynchronous - public SNIMarsHandle(SNIMarsConnection connection, ushort sessionId, object callbackObject, bool async) + public SNIMarsHandle(SNIMarsConnection connection, ushort sessionId, TdsParserStateObject callbackObject, bool async) { _sessionId = sessionId; _connection = connection; @@ -102,9 +101,7 @@ private void SendControlPacket(SNISMUXFlags flags) #endif lock (this) { - SetupSMUXHeader(0, flags); - _currentHeader.Write(packet.GetHeaderBuffer(SNISMUXHeader.HEADER_LENGTH)); - packet.SetHeaderActive(); + SetupSMUXHeader(packet, flags); } _connection.Send(packet); @@ -116,17 +113,19 @@ private void SendControlPacket(SNISMUXFlags flags) } } - private void SetupSMUXHeader(int length, SNISMUXFlags flags) + private void SetupSMUXHeader(SNIPacket packet, SNISMUXFlags flags) { Debug.Assert(Monitor.IsEntered(this), "must take lock on self before updating smux header"); + Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to smux packet without smux reservation"); + + SNISMUXHeader header = new((byte)flags, _sessionId, SNISMUXHeader.HEADER_LENGTH + (uint)packet.Length, + sequenceNumber: ((flags == SNISMUXFlags.SMUX_FIN) || (flags == SNISMUXFlags.SMUX_ACK)) ? _sequenceNumber - 1 : _sequenceNumber++, + _receiveHighwater); - _currentHeader.SMID = 83; - _currentHeader.flags = (byte)flags; - _currentHeader.sessionId = _sessionId; - _currentHeader.length = (uint)SNISMUXHeader.HEADER_LENGTH + (uint)length; - _currentHeader.sequenceNumber = ((flags == SNISMUXFlags.SMUX_FIN) || (flags == SNISMUXFlags.SMUX_ACK)) ? _sequenceNumber - 1 : _sequenceNumber++; - _currentHeader.highwater = _receiveHighwater; - _receiveHighwaterLastAck = _currentHeader.highwater; + _receiveHighwaterLastAck = _receiveHighwater; + + header.Write(packet.GetHeaderBuffer(SNISMUXHeader.HEADER_LENGTH)); + packet.SetHeaderActive(); } /// @@ -136,11 +135,7 @@ private void SetupSMUXHeader(int length, SNISMUXFlags flags) /// The packet with the SMUx header set. private SNIPacket SetPacketSMUXHeader(SNIPacket packet) { - Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to smux packet without smux reservation"); - - SetupSMUXHeader(packet.Length, SNISMUXFlags.SMUX_DATA); - _currentHeader.Write(packet.GetHeaderBuffer(SNISMUXHeader.HEADER_LENGTH)); - packet.SetHeaderActive(); + SetupSMUXHeader(packet, SNISMUXFlags.SMUX_DATA); #if DEBUG SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsHandle), EventType.INFO, "MARS Session Id {0}, Setting SMUX_DATA header in current header for packet {1}", args0: ConnectionId, args1: packet?._id); #endif @@ -348,7 +343,7 @@ public void HandleReceiveError(SNIPacket packet) _packetEvent.Set(); } - ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 1); + _callbackObject.ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 1); } } @@ -365,7 +360,7 @@ public void HandleSendComplete(SNIPacket packet, uint sniErrorCode) { Debug.Assert(_callbackObject != null); - ((TdsParserStateObject)_callbackObject).WriteAsyncCallback(PacketHandle.FromManagedPacket(packet), sniErrorCode); + _callbackObject.WriteAsyncCallback(PacketHandle.FromManagedPacket(packet), sniErrorCode); } _connection.ReturnPacket(packet); #if DEBUG @@ -399,16 +394,16 @@ public void HandleAck(uint highwater) /// /// SNI packet /// SMUX header - public void HandleReceiveComplete(SNIPacket packet, SNISMUXHeader header) + public void HandleReceiveComplete(SNIPacket packet, in SNISMUXHeader header) { using (TrySNIEventScope.Create(nameof(SNIMarsHandle))) { lock (this) { - if (_sendHighwater != header.highwater) + if (_sendHighwater != header.Highwater) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsHandle), EventType.INFO, "MARS Session Id {0}, header.highwater {1}, _sendHighwater {2}, Handle Ack with header.highwater", args0: ConnectionId, args1: header?.highwater, args2: _sendHighwater); - HandleAck(header.highwater); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsHandle), EventType.INFO, "MARS Session Id {0}, header.highwater {1}, _sendHighwater {2}, Handle Ack with header.highwater", args0: ConnectionId, args1: header.Highwater, args2: _sendHighwater); + HandleAck(header.Highwater); } lock (_receivedPacketQueue) @@ -425,7 +420,7 @@ public void HandleReceiveComplete(SNIPacket packet, SNISMUXHeader header) Debug.Assert(_callbackObject != null); SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsHandle), EventType.INFO, "MARS Session Id {0}, _sequenceNumber {1}, _sendHighwater {2}, _asyncReceives {3}", args0: ConnectionId, args1: _sequenceNumber, args2: _sendHighwater, args3: _asyncReceives); - ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 0); + _callbackObject.ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 0); } _connection.ReturnPacket(packet); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index d39f382bd6..cdab2f8eee 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -132,7 +132,6 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) /// /// Full server name from connection string /// Timer expiration - /// Instance name /// SPN /// pre-defined SPN /// Flush packet cache @@ -149,7 +148,6 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) internal static SNIHandle CreateConnectionHandle( string fullServerName, TimeoutTimer timeout, - out byte[] instanceName, ref byte[][] spnBuffer, string serverSPN, bool flushCache, @@ -163,8 +161,6 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) string hostNameInCertificate, string serverCertificateFilename) { - instanceName = new byte[1]; - bool errorWithLocalDBProcessing; string localDBDataSource = GetLocalDBDataSource(fullServerName, out errorWithLocalDBProcessing); @@ -317,9 +313,13 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr { try { - details.ResolvedPort = port = isAdminConnection ? + SSRP.SSRPResult portDetails = isAdminConnection ? SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference) : SSRP.GetPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference); + + + details.ResolvedPort = port = portDetails.Port; + pendingDNSInfo = new SQLDNSInfo(hostName, portDetails.ResolvedIPAddresses); } catch (SocketException se) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.ValueTask.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.ValueTask.cs index f5f38f0efe..ff31c9183b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.ValueTask.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.ValueTask.cs @@ -66,7 +66,11 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + if (SynchronizeIO) + { + await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + } + try { return await base.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); @@ -78,7 +82,10 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation } finally { - _readAsyncSemaphore.Release(); + if (SynchronizeIO) + { + _readAsyncSemaphore.Release(); + } } } @@ -90,7 +97,11 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { - await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + if (SynchronizeIO) + { + await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + } + try { await base.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); @@ -102,7 +113,10 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella } finally { - _writeAsyncSemaphore.Release(); + if (SynchronizeIO) + { + _writeAsyncSemaphore.Release(); + } } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs index 389f25eeae..46e2f3af53 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs @@ -37,5 +37,9 @@ public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocke _writeAsyncSemaphore = new ConcurrentQueueSemaphore(1); _readAsyncSemaphore = new ConcurrentQueueSemaphore(1); } + + // This class is often wrapped in an SNISslStream, which also performs its own synchronisation. + // Setting this to false will disable the inner layer, since it's always synchronised by the wrapper. + public bool SynchronizeIO { get; set; } = true; } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index d12e91ad62..fb7249b390 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -28,7 +28,7 @@ internal sealed class SNITCPHandle : SNIPhysicalHandle private readonly string _targetServer; private readonly object _sendSync; private readonly Socket _socket; - private NetworkStream _tcpStream; + private SNINetworkStream _tcpStream; private readonly string _hostNameInCertificate; private readonly string _serverCertificateFilename; private readonly bool _tlsFirst; @@ -187,21 +187,21 @@ public override int ProtocolVersion } else { - int portRetry = string.IsNullOrEmpty(cachedDNSInfo.Port) ? port : int.Parse(cachedDNSInfo.Port); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Retrying with cached DNS IP Address {1} and port {2}", args0: _connectionId, args1: cachedDNSInfo.AddrIPv4, args2: cachedDNSInfo.Port); + int portRetry = cachedDNSInfo.Port == 0 ? port : cachedDNSInfo.Port; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Retrying with cached DNS IP Address {1} and port {2}", args0: _connectionId, args1: cachedDNSInfo.CachedIPv4Address, args2: cachedDNSInfo.Port); - string firstCachedIP; - string secondCachedIP; + IPAddress[] firstCachedIP; + IPAddress[] secondCachedIP; if (SqlConnectionIPAddressPreference.IPv6First == ipPreference) { - firstCachedIP = cachedDNSInfo.AddrIPv6; - secondCachedIP = cachedDNSInfo.AddrIPv4; + firstCachedIP = new[] { cachedDNSInfo.CachedIPv6Address }; + secondCachedIP = new[] { cachedDNSInfo.CachedIPv4Address }; } else { - firstCachedIP = cachedDNSInfo.AddrIPv4; - secondCachedIP = cachedDNSInfo.AddrIPv6; + firstCachedIP = new[] { cachedDNSInfo.CachedIPv4Address }; + secondCachedIP = new[] { cachedDNSInfo.CachedIPv6Address }; } try @@ -274,7 +274,7 @@ public override int ProtocolVersion _sslOverTdsStream = new SslOverTdsStream(_tcpStream, _connectionId); stream = _sslOverTdsStream; } - _sslStream = new SNISslStream(stream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate)); + _sslStream = new SNISslStream(stream, true, ValidateServerCertificate); } catch (SocketException se) { @@ -300,13 +300,10 @@ public override int ProtocolVersion // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. private Socket TryConnectParallel(string hostName, int port, TimeoutTimer timeout, ref bool callerReportError, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { - Socket availableSocket = null; - Task connectTask; - bool isInfiniteTimeOut = timeout.IsInfinite; - - IPAddress[] serverAddresses = isInfiniteTimeOut - ? SNICommon.GetDnsIpAddresses(hostName) - : SNICommon.GetDnsIpAddresses(hostName, timeout); + IPAddress[] serverAddresses = pendingDNSInfo?.SpeculativeIPAddresses ?? + (timeout.IsInfinite + ? SNICommon.GetDnsIpAddresses(hostName, SqlConnectionIPAddressPreference.UsePlatformDefault, false).Result + : SNICommon.GetDnsIpAddresses(hostName, timeout, SqlConnectionIPAddressPreference.UsePlatformDefault, false).Result); if (serverAddresses.Length > MaxParallelIpAddresses) { @@ -314,32 +311,40 @@ private Socket TryConnectParallel(string hostName, int port, TimeoutTimer timeou callerReportError = false; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "Connection Id {0} serverAddresses.Length {1} Exception: {2}", args0: _connectionId, args1: serverAddresses.Length, args2: Strings.SNI_ERROR_47); ReportTcpSNIError(0, SNICommon.MultiSubnetFailoverWithMoreThan64IPs, Strings.SNI_ERROR_47); - return availableSocket; + return null; } - string IPv4String = null; - string IPv6String = null; + return TryConnectParallel(serverAddresses, port, timeout, ref callerReportError, cachedFQDN, ref pendingDNSInfo); + } + + private Socket TryConnectParallel(IPAddress[] serverAddresses, int port, TimeoutTimer timeout, ref bool callerReportError, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) + { + Socket availableSocket = null; + Task connectTask; + bool isInfiniteTimeOut = timeout.IsInfinite; + IPAddress ipv4Address = null; + IPAddress ipv6Address = null; foreach (IPAddress ipAddress in serverAddresses) { if (ipAddress.AddressFamily == AddressFamily.InterNetwork) { - IPv4String = ipAddress.ToString(); + ipv4Address = ipAddress; } else if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6) { - IPv6String = ipAddress.ToString(); + ipv6Address = ipAddress; } } - if (IPv4String != null || IPv6String != null) + if (ipv4Address != null || ipv6Address != null) { - pendingDNSInfo = new SQLDNSInfo(cachedFQDN, IPv4String, IPv6String, port.ToString()); + pendingDNSInfo = new SQLDNSInfo(cachedFQDN, ipv4Address, ipv6Address, port); } connectTask = ParallelConnectAsync(serverAddresses, port); - if (!(connectTask.Wait(isInfiniteTimeOut ? -1: timeout.MillisecondsRemainingInt))) + if (connectTask.Status != TaskStatus.RanToCompletion && !(connectTask.Wait(isInfiniteTimeOut ? -1: timeout.MillisecondsRemainingInt))) { callerReportError = false; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "Connection Id {0} Connection timed out, Exception: {1}", args0: _connectionId, args1: Strings.SNI_ERROR_40); @@ -351,61 +356,36 @@ private Socket TryConnectParallel(string hostName, int port, TimeoutTimer timeou return availableSocket; } - /// - /// Returns array of IP addresses for the given server name, sorted according to the given preference. - /// - /// Thrown when ipPreference is not supported - private static IEnumerable GetHostAddressesSortedByPreference(string serverName, SqlConnectionIPAddressPreference ipPreference) - { - IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName); - AddressFamily? prioritiesFamily = ipPreference switch - { - SqlConnectionIPAddressPreference.IPv4First => AddressFamily.InterNetwork, - SqlConnectionIPAddressPreference.IPv6First => AddressFamily.InterNetworkV6, - SqlConnectionIPAddressPreference.UsePlatformDefault => null, - _ => throw ADP.NotSupportedEnumerationValue(typeof(SqlConnectionIPAddressPreference), ipPreference.ToString(), nameof(GetHostAddressesSortedByPreference)) - }; - - // Return addresses of the preferred family first - if (prioritiesFamily != null) - { - foreach (IPAddress ipAddress in ipAddresses) - { - if (ipAddress.AddressFamily == prioritiesFamily) - { - yield return ipAddress; - } - } - } - - // Return addresses of the other family - foreach (IPAddress ipAddress in ipAddresses) - { - if (ipAddress.AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6) - { - if (prioritiesFamily == null || ipAddress.AddressFamily != prioritiesFamily) - { - yield return ipAddress; - } - } - } - } - // Connect to server with hostName and port. // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. private static Socket Connect(string serverName, int port, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) + { + IPAddress[] ipAddresses = pendingDNSInfo?.SpeculativeIPAddresses ?? SNICommon.GetDnsIpAddresses(serverName, ipPreference, false).Result; + + return Connect(ipAddresses, port, timeout, ipPreference, cachedFQDN, ref pendingDNSInfo); + } + + private static Socket Connect(IPAddress[] ipAddresses, int port, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference)); bool isInfiniteTimeout = timeout.IsInfinite; - - IEnumerable ipAddresses = GetHostAddressesSortedByPreference(serverName, ipPreference); + IPEndPoint ipEndPoint = null; foreach (IPAddress ipAddress in ipAddresses) { bool isSocketSelected = false; Socket socket = null; + if (ipEndPoint == null) + { + ipEndPoint = new IPEndPoint(ipAddress, port); + } + else + { + ipEndPoint.Address = ipAddress; + } + try { socket = new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp) @@ -423,12 +403,15 @@ private static Socket Connect(string serverName, int port, TimeoutTimer timeout, ipAddress.AddressFamily, isInfiniteTimeout); + CancellationTokenSource timeoutConnectionCancellationTokenSource = null; + int remainingTimeout = timeout.MillisecondsRemainingInt; bool isConnected; + try // catching SocketException with SocketErrorCode == WouldBlock to run Socket.Select { if (isInfiniteTimeout) { - socket.Connect(ipAddress, port); + socket.Connect(ipEndPoint); } else { @@ -436,30 +419,50 @@ private static Socket Connect(string serverName, int port, TimeoutTimer timeout, { return null; } + +#if NET6_0_OR_GREATER + timeoutConnectionCancellationTokenSource = new CancellationTokenSource(remainingTimeout); + ValueTask socketConnectValueTask = socket.ConnectAsync(ipEndPoint, timeoutConnectionCancellationTokenSource.Token); + + if (! socketConnectValueTask.IsCompleted) + { + try + { + socketConnectValueTask.AsTask().Wait(); + } + catch (OperationCanceledException) + { + throw ADP.TimeoutException($"The socket couldn't connect during the expected {remainingTimeout} remaining time."); + } + } + +#else // Socket.Connect does not support infinite timeouts, so we use Task to simulate it - Task socketConnectTask = new Task(() => socket.Connect(ipAddress, port)); + Task socketConnectTask = new Task(() => socket.Connect(ipEndPoint)); socketConnectTask.ConfigureAwait(false); socketConnectTask.Start(); - int remainingTimeout = timeout.MillisecondsRemainingInt; + if (!socketConnectTask.Wait(remainingTimeout)) { + timeoutConnectionCancellationTokenSource.Cancel(); throw ADP.TimeoutException($"The socket couldn't connect during the expected {remainingTimeout} remaining time."); } + throw SQL.SocketDidNotThrow(); +#endif } isConnected = true; } - catch (AggregateException aggregateException) when (!isInfiniteTimeout - && aggregateException.InnerException is SocketException socketException - && socketException.SocketErrorCode == SocketError.WouldBlock) + catch (SocketException socketException) when (!isInfiniteTimeout + && socketException.SocketErrorCode == SocketError.WouldBlock) { // https://github.com/dotnet/SqlClient/issues/826#issuecomment-736224118 // Socket.Select is used because it supports timeouts, while Socket.Connect does not - List checkReadLst; - List checkWriteLst; - List checkErrorLst; + List checkReadLst = new (1); + List checkWriteLst = new(1); + List checkErrorLst = new(1); // Repeating Socket.Select several times if our timeout is greater // than int.MaxValue microseconds because of @@ -469,15 +472,16 @@ private static Socket Connect(string serverName, int port, TimeoutTimer timeout, { if (timeout.IsExpired) { - return null; + throw ADP.TimeoutException($"The socket couldn't connect during the expected {remainingTimeout} remaining time."); + //return null; } int socketSelectTimeout = checked((int)(Math.Min(timeout.MillisecondsRemainingInt, int.MaxValue / 1000) * 1000)); - checkReadLst = new List(1) { socket }; - checkWriteLst = new List(1) { socket }; - checkErrorLst = new List(1) { socket }; + checkReadLst.Add(socket); + checkWriteLst.Add(socket); + checkErrorLst.Add(socket); SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Determining the status of the socket during the remaining timeout of {0} microseconds.", @@ -490,21 +494,29 @@ private static Socket Connect(string serverName, int port, TimeoutTimer timeout, // workaround: false positive socket.Connected on linux: https://github.com/dotnet/runtime/issues/55538 isConnected = socket.Connected && checkErrorLst.Count == 0; } + finally + { + timeoutConnectionCancellationTokenSource?.Dispose(); + } if (isConnected) { socket.Blocking = true; - string iPv4String = null; - string iPv6String = null; - if (socket.AddressFamily == AddressFamily.InterNetwork) + + IPAddress ipv4Address = null; + IPAddress ipv6Address = null; + + if (ipAddress.AddressFamily == AddressFamily.InterNetwork) { - iPv4String = ipAddress.ToString(); + ipv4Address = ipAddress; } else { - iPv6String = ipAddress.ToString(); + ipv6Address = ipAddress; } - pendingDNSInfo = new SQLDNSInfo(cachedFQDN, iPv4String, iPv6String, port.ToString()); + + pendingDNSInfo = new SQLDNSInfo(cachedFQDN, ipv4Address, ipv6Address, port); + isSocketSelected = true; return socket; } @@ -525,7 +537,7 @@ private static Socket Connect(string serverName, int port, TimeoutTimer timeout, return null; } - private static Task ParallelConnectAsync(IPAddress[] serverAddresses, int port) + private static async Task ParallelConnectAsync(IPAddress[] serverAddresses, int port) { if (serverAddresses == null) { @@ -536,93 +548,128 @@ private static Task ParallelConnectAsync(IPAddress[] serverAddresses, in throw new ArgumentOutOfRangeException(nameof(serverAddresses)); } - var sockets = new List(serverAddresses.Length); - var connectTasks = new List(serverAddresses.Length); - var tcs = new TaskCompletionSource(); - var lastError = new StrongBox(); - var pendingCompleteCount = new StrongBox(serverAddresses.Length); + using var connectCancellationTokenSource = new CancellationTokenSource(); + Exception lastException = null; + IPEndPoint ipEndPoint = null; + + Task lastTask = null; + + List emptySocketList = new List(); + List socketErrorCheckList = new List(1); + Dictionary socketConnectionTasks = new(serverAddresses.Length); + Socket completedSocket; foreach (IPAddress address in serverAddresses) { var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp); - sockets.Add(socket); + + if (ipEndPoint == null) + { + ipEndPoint = new IPEndPoint(address, port); + } + else + { + ipEndPoint.Address = address; + } // Start all connection tasks now, to prevent possible race conditions with // calling ConnectAsync on disposed sockets. try { - connectTasks.Add(socket.ConnectAsync(address, port)); +#if NET6_0_OR_GREATER + lastTask = socket.ConnectAsync(ipEndPoint, connectCancellationTokenSource.Token).AsTask(); +#else + lastTask = socket.ConnectAsync(ipEndPoint); +#endif + socketConnectionTasks.Add(lastTask, socket); } catch (Exception e) { - connectTasks.Add(Task.FromException(e)); + socketConnectionTasks.Add(Task.FromException(e), socket); } } - for (int i = 0; i < sockets.Count; i++) - { - ParallelConnectHelper(sockets[i], connectTasks[i], tcs, pendingCompleteCount, lastError, sockets); - } - - return tcs.Task; - } - - private static async void ParallelConnectHelper( - Socket socket, - Task connectTask, - TaskCompletionSource tcs, - StrongBox pendingCompleteCount, - StrongBox lastError, - List sockets) - { - bool success = false; try { - // Try to connect. If we're successful, store this task into the result task. - await connectTask.ConfigureAwait(false); - success = tcs.TrySetResult(socket); - if (success) - { - // Whichever connection completes the return task is responsible for disposing - // all of the sockets (except for whichever one is stored into the result task). - // This ensures that only one thread will attempt to dispose of a socket. - // This is also the closest thing we have to canceling connect attempts. - foreach (Socket otherSocket in sockets) + while (socketConnectionTasks.Count > 0) + { + Task completedTask; + + // If there's only one IP address, we can avoid the implicit Task allocation of Task.WhenAny + if (socketConnectionTasks.Count == 1) { - if (otherSocket != socket) + completedTask = lastTask; + + try { - otherSocket.Dispose(); + if (completedTask.Status != TaskStatus.Faulted) + { + await completedTask.ConfigureAwait(false); + } + } + catch (Exception connectException) + { + // This exception is silently swallowed here, but is thrown later in the method + SqlClientEventSource.Log.TryAdvancedTraceEvent( + $"{nameof(SNITCPHandle)}.{nameof(ParallelConnectAsync)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {connectException}"); } - } - } - } - catch (Exception e) - { - // Store an exception to be published if no connection succeeds - Interlocked.Exchange(ref lastError.Value, e); - } - finally - { - // If we didn't successfully transition the result task to completed, - // then someone else did and they would have cleaned up, so there's nothing - // more to do. Otherwise, no one completed it yet or we failed; either way, - // see if we're the last outstanding connection, and if we are, try to complete - // the task, and if we're successful, it's our responsibility to dispose all of the sockets. - if (!success && Interlocked.Decrement(ref pendingCompleteCount.Value) == 0) - { - if (lastError.Value != null) - { - tcs.TrySetException(lastError.Value); } else { - tcs.TrySetCanceled(); + completedTask = await Task.WhenAny(socketConnectionTasks.Keys).ConfigureAwait(false); } - foreach (Socket s in sockets) + Socket taskSocket = socketConnectionTasks[completedTask]; + + if (completedTask.Status == TaskStatus.RanToCompletion) { - s.Dispose(); + // workaround: false positive socket.Connected on linux: https://github.com/dotnet/runtime/issues/55538 + if (socketErrorCheckList.Count > 0) + { + socketErrorCheckList.Clear(); + } + socketErrorCheckList.Add(taskSocket); + + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Determining the status of the socket following completion of ConnectAsync."); + + Socket.Select(emptySocketList, emptySocketList, socketErrorCheckList, 0); + if (taskSocket.Connected && socketErrorCheckList.Count == 0) + { + + completedSocket = taskSocket; + connectCancellationTokenSource.Cancel(); + socketConnectionTasks.Remove(completedTask); + + lastException = null; + return completedSocket; + } } + else + { + if (completedTask.Status == TaskStatus.Faulted) + { + lastException = completedTask.Exception; + } + } + + taskSocket.Dispose(); + socketConnectionTasks.Remove(completedTask); + } + + if (lastException != null) + throw lastException; + + connectCancellationTokenSource.Token.ThrowIfCancellationRequested(); + // This return statement can never be reached. socketConnectionsTasks would need to have been drained, but this can only happen if all tasks + // inside it have succeeded, failed or have been cancelled. Success will result in an early return, failure results in a thrown exception, + // cancellation results in another exception. + return null; + } + finally + { + foreach (KeyValuePair socketConnectionTaskMapping in socketConnectionTasks) + { + socketConnectionTaskMapping.Value.Dispose(); } } } @@ -647,9 +694,11 @@ public override uint EnableSsl(uint options) // TODO: Resolve whether to send _serverNameIndication or _targetServer. _serverNameIndication currently results in error. Why? _sslStream.AuthenticateAsClient(_targetServer, null, s_supportedProtocols, false); } + _sslStream.Flush(); if (_sslOverTdsStream is not null) { _sslOverTdsStream.FinishHandshake(); + _sslOverTdsStream.Flush(); } } catch (AuthenticationException aue) @@ -664,6 +713,7 @@ public override uint EnableSsl(uint options) } _stream = _sslStream; + _tcpStream.SynchronizeIO = false; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, SSL enabled successfully.", args0: _connectionId); return TdsEnums.SNI_SUCCESS; } @@ -674,11 +724,15 @@ public override uint EnableSsl(uint options) /// public override void DisableSsl() { + _tcpStream.SynchronizeIO = true; + _sslStream.Flush(); _sslStream.Dispose(); _sslStream = null; + _sslOverTdsStream?.Flush(); _sslOverTdsStream?.Dispose(); _sslOverTdsStream = null; _stream = _tcpStream; + _stream.Flush(); SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, SSL Disabled. Communication will continue on TCP Stream.", args0: _connectionId); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index 3cad605caa..6ad565c193 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers.Binary; using System.Collections.Generic; using System.Diagnostics; using System.Net; @@ -10,12 +11,31 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Azure; using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { internal sealed class SSRP { + public sealed class SSRPResult + { + internal byte[] Buffer { get; set; } + + public ushort Port { get; set; } + + public IPAddress[] ResolvedIPAddresses { get; } + + public SSRPResult(IPAddress[] resolvedIPAddresses, byte[] buffer) + { + ResolvedIPAddresses = resolvedIPAddresses; + Buffer = buffer; + } + } + + private static readonly TimeSpan s_sendTimeout = TimeSpan.FromSeconds(1.0); + private static readonly TimeSpan s_receiveTimeout = TimeSpan.FromSeconds(1.0); + private const char SemicolonSeparator = ';'; private const int SqlServerBrowserPort = 1434; //port SQL Server Browser private const int RecieveMAXTimeoutsForCLNT_BCAST_EX = 15000; //Default max time for response wait @@ -28,23 +48,40 @@ internal sealed class SSRP /// /// Finds instance port number for given instance name. /// - /// SQL Sever Browser hostname + /// SQL Server Browser hostname /// instance name to find port number /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference - /// port number for given instance name - internal static int GetPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + /// port number and resolved IP addresses for given instance name + internal static SSRPResult GetPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + => GetPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, false).Result; + + /// + /// Finds instance port number for given instance name. + /// + /// SQL Server Browser hostname + /// instance name to find port number + /// Connection timer expiration + /// query all resolved IP addresses in parallel + /// IP address preference + /// port number and resolved IP addresses for given instance name + internal static ValueTask GetPortByInstanceNameAsync(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + => GetPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, true); + + private static async ValueTask GetPortByInstanceNameCore(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); using (TrySNIEventScope.Create(nameof(SSRP))) { byte[] instanceInfoRequest = CreateInstanceInfoRequest(instanceName); + SSRPResult response = null; byte[] responsePacket = null; try { - responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timeout, allIPsInParallel, ipPreference); + response = await SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timeout, allIPsInParallel, ipPreference, async); + responsePacket = response?.Buffer; } catch (SocketException se) { @@ -74,7 +111,10 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc throw new SocketException(); } - return ushort.Parse(elements[tcpIndex + 1]); + response.Port = ushort.Parse(elements[tcpIndex + 1]); + response.Buffer = null; + + return response; } } @@ -89,12 +129,12 @@ private static byte[] CreateInstanceInfoRequest(string instanceName) using (TrySNIEventScope.Create(nameof(SSRP))) { const byte ClntUcastInst = 0x04; - instanceName += char.MinValue; int byteCount = Encoding.ASCII.GetByteCount(instanceName); - byte[] requestPacket = new byte[byteCount + 1]; + byte[] requestPacket = new byte[byteCount + 1 + 1]; requestPacket[0] = ClntUcastInst; Encoding.ASCII.GetBytes(instanceName, 0, instanceName.Length, requestPacket, 1); + requestPacket[byteCount + 1] = 0; return requestPacket; } @@ -103,19 +143,35 @@ private static byte[] CreateInstanceInfoRequest(string instanceName) /// /// Finds DAC port for given instance name. /// - /// SQL Sever Browser hostname + /// SQL Server Browser hostname /// instance name to lookup DAC port /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference /// DAC port for given instance name - internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + internal static SSRPResult GetDacPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + => GetDacPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, false).Result; + + /// + /// Finds DAC port for given instance name. + /// + /// SQL Server Browser hostname + /// instance name to lookup DAC port + /// Connection timer expiration + /// query all resolved IP addresses in parallel + /// IP address preference + /// DAC port for given instance name + internal static ValueTask GetDacPortByInstanceNameAsync(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + => GetDacPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, true); + + private static async ValueTask GetDacPortByInstanceNameCore(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); byte[] dacPortInfoRequest = CreateDacPortInfoRequest(instanceName); - byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timeout, allIPsInParallel, ipPreference); + SSRPResult response = await SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timeout, allIPsInParallel, ipPreference, async); + byte[] responsePacket = response?.Buffer; const byte SvrResp = 0x05; const byte ProtocolVersion = 0x01; @@ -126,8 +182,9 @@ internal static int GetDacPortByInstanceName(string browserHostName, string inst throw new SocketException(); } - int dacPort = BitConverter.ToUInt16(responsePacket, 4); - return dacPort; + response.Port = BitConverter.ToUInt16(responsePacket, 4); + response.Buffer = null; + return response; } /// @@ -141,23 +198,17 @@ private static byte[] CreateDacPortInfoRequest(string instanceName) const byte ClntUcastDac = 0x0F; const byte ProtocolVersion = 0x01; - instanceName += char.MinValue; int byteCount = Encoding.ASCII.GetByteCount(instanceName); - byte[] requestPacket = new byte[byteCount + 2]; + byte[] requestPacket = new byte[byteCount + 2 + 1]; requestPacket[0] = ClntUcastDac; requestPacket[1] = ProtocolVersion; Encoding.ASCII.GetBytes(instanceName, 0, instanceName.Length, requestPacket, 2); + requestPacket[2 + byteCount] = 0; return requestPacket; } - private class SsrpResult - { - public byte[] ResponsePacket; - public Exception Error; - } - /// /// Sends request to server, and receives response from server by UDP. /// @@ -167,8 +218,9 @@ private class SsrpResult /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference + /// If true, this method will be run asynchronously /// response packet from UDP server - private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + private static async ValueTask SendUDPRequest(string browserHostname, int port, byte[] requestPacket, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) { using (TrySNIEventScope.Create(nameof(SSRP))) { @@ -176,96 +228,96 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re Debug.Assert(port >= 0 && port <= 65535, "Invalid port"); Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array"); + IPAddress[] ipAddresses; + if (IPAddress.TryParse(browserHostname, out IPAddress address)) { - SsrpResult response = SendUDPRequest(new IPAddress[] { address }, port, requestPacket, allIPsInParallel); - if (response != null && response.ResponsePacket != null) - return response.ResponsePacket; - else if (response != null && response.Error != null) - throw response.Error; - else - return null; + ipAddresses = new IPAddress[1] { address }; + } + else + { + ipAddresses = await (timeout.IsInfinite + ? SNICommon.GetDnsIpAddresses(browserHostname, ipPreference, async) + : SNICommon.GetDnsIpAddresses(browserHostname, timeout, ipPreference, async)); } - - IPAddress[] ipAddresses = timeout.IsInfinite - ? SNICommon.GetDnsIpAddresses(browserHostname) - : SNICommon.GetDnsIpAddresses(browserHostname, timeout); Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); - IPAddress[] ipv4Addresses = null; - IPAddress[] ipv6Addresses = null; + + byte[] response = null; + Exception responseException = null; + switch (ipPreference) { case SqlConnectionIPAddressPreference.IPv4First: + case SqlConnectionIPAddressPreference.IPv6First: { - SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses); - - SsrpResult response4 = SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel); - if (response4 != null && response4.ResponsePacket != null) + // If the ipPreference has been specified and IP addresses of a certain address family are attempted first, + // then slice the array at the point where the address family changes. + AddressFamily previousAddressFamily = ipAddresses[0].AddressFamily; + int firstAddressFamilyLength = 0; + Memory primaryIpAddressList; + Memory secondaryIpAddressList = Memory.Empty; + + for (int i = 0; i < ipAddresses.Length; i++) { - return response4.ResponsePacket; - } + if (ipAddresses[i].AddressFamily != previousAddressFamily) + { + firstAddressFamilyLength = i; + break; + } - SsrpResult response6 = SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel); - if (response6 != null && response6.ResponsePacket != null) - { - return response6.ResponsePacket; + previousAddressFamily = ipAddresses[i].AddressFamily; } + primaryIpAddressList = firstAddressFamilyLength == 0 + ? ipAddresses.AsMemory() + : ipAddresses.AsMemory(0, firstAddressFamilyLength); - // No responses so throw first error - if (response4 != null && response4.Error != null) - { - throw response4.Error; - } - else if (response6 != null && response6.Error != null) + try { - throw response6.Error; - } - - break; - } - case SqlConnectionIPAddressPreference.IPv6First: - { - SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses); + response = await SendUDPRequest(primaryIpAddressList, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); - SsrpResult response6 = SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel); - if (response6 != null && response6.ResponsePacket != null) - { - return response6.ResponsePacket; + if (response != null) + { + return new SSRPResult(ipAddresses, response); + } } + catch (Exception e) + { responseException ??= e; } - SsrpResult response4 = SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel); - if (response4 != null && response4.ResponsePacket != null) + + if (firstAddressFamilyLength > 0) { - return response4.ResponsePacket; + secondaryIpAddressList = ipAddresses.AsMemory(firstAddressFamilyLength); + + try + { + response = await SendUDPRequest(secondaryIpAddressList, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); + + if (response != null) + { + return new SSRPResult(ipAddresses, response); + } + } + catch (Exception e) + { responseException ??= e; } } // No responses so throw first error - if (response6 != null && response6.Error != null) - { - throw response6.Error; - } - else if (response4 != null && response4.Error != null) + if (responseException != null) { - throw response4.Error; + throw responseException; } break; } default: - { - SsrpResult response = SendUDPRequest(ipAddresses, port, requestPacket, true); // allIPsInParallel); - if (response != null && response.ResponsePacket != null) - { - return response.ResponsePacket; - } - else if (response != null && response.Error != null) - { - throw response.Error; - } + byte[] buffer = await SendUDPRequest(ipAddresses, port, requestPacket, true, async).ConfigureAwait(false); - break; + if (response != null) + { + return new SSRPResult(ipAddresses, response); } + break; } return null; @@ -279,186 +331,211 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re /// UDP server port /// request packet /// query all resolved IP addresses in parallel + /// If true, this method will be run asynchronously /// response packet from UDP server - private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel) + private static async ValueTask SendUDPRequest(Memory ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel, bool async) { - if (ipAddresses.Length == 0) + if (ipAddresses.IsEmpty) return null; - if (allIPsInParallel) // Used for MultiSubnetFailover + IPEndPoint endPoint = new IPEndPoint(ipAddresses.Span[0], port); + + if (allIPsInParallel && ipAddresses.Length > 1) // Used for MultiSubnetFailover { - List> tasks = new(ipAddresses.Length); + List> tasks = new(ipAddresses.Length); + Task firstFailedTask = null; CancellationTokenSource cts = new CancellationTokenSource(); + // Cache the UdpClients for each of the address families to save disposing them + UdpClient ipv4UdpClient = null; + UdpClient ipv6UdpClient = null; + for (int i = 0; i < ipAddresses.Length; i++) { - IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port); - tasks.Add(Task.Factory.StartNew(() => SendUDPRequest(endPoint, requestPacket), cts.Token)); - } + if (i > 0) + { + endPoint.Address = ipAddresses.Span[i]; + } - List> completedTasks = new(); - while (tasks.Count > 0) - { - int first = Task.WaitAny(tasks.ToArray()); - if (tasks[first].Result.ResponsePacket != null) + if (endPoint.AddressFamily == AddressFamily.InterNetwork) { - cts.Cancel(); - return tasks[first].Result; + ipv4UdpClient ??= new UdpClient(AddressFamily.InterNetwork); + + tasks.Add(SendUDPRequest(endPoint, ipv4UdpClient, requestPacket, async, cts.Token)); } - else + else if (endPoint.AddressFamily == AddressFamily.InterNetworkV6) { - completedTasks.Add(tasks[first]); - tasks.Remove(tasks[first]); + ipv6UdpClient ??= new UdpClient(AddressFamily.InterNetworkV6); + + tasks.Add(SendUDPRequest(endPoint, ipv4UdpClient, requestPacket, async, cts.Token)); } } - Debug.Assert(completedTasks.Count > 0, "completedTasks should never be 0"); + using (ipv4UdpClient) + using (ipv6UdpClient) + { + while (tasks.Count > 0) + { + Task completedTask; + + if (async) + { + completedTask = await Task.WhenAny(tasks).ConfigureAwait(false); - // All tasks failed. Return the error from the first failure. - return completedTasks[0].Result; + if (completedTask.Status == TaskStatus.RanToCompletion) + { + cts.Cancel(); + return completedTask.Result; + } + } + else + { + int completedTaskIndex = Task.WaitAny(tasks.ToArray()); + + completedTask = tasks[completedTaskIndex]; + if (completedTask.Status == TaskStatus.RanToCompletion) + { + cts.Cancel(); + return completedTask.Result; + } + } + + if (completedTask.Status == TaskStatus.Faulted) + { + tasks.Remove(completedTask); + firstFailedTask ??= completedTask; + } + } + + Debug.Assert(firstFailedTask != null, "firstFailedTask should never be null"); + + // All tasks failed. Return the error from the first failure. + throw firstFailedTask.Exception; + } } else { - // If not parallel, use the first IP address provided - IPEndPoint endPoint = new IPEndPoint(ipAddresses[0], port); - return SendUDPRequest(endPoint, requestPacket); + using (UdpClient oneShotUdpClient = new UdpClient(endPoint.AddressFamily)) + { + // If not parallel, use the first IP address provided + return await SendUDPRequest(endPoint, oneShotUdpClient, requestPacket, async, CancellationToken.None); + } } } - private static SsrpResult SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket) +#if NET6_0_OR_GREATER + private static async Task SendUDPRequest(IPEndPoint endPoint, UdpClient client, byte[] requestPacket, bool async, CancellationToken token) +#else + private static Task SendUDPRequest(IPEndPoint endPoint, UdpClient client, byte[] requestPacket, bool async, CancellationToken token) +#endif { - const int sendTimeOutMs = 1000; - const int receiveTimeOutMs = 1000; - - SsrpResult result = new(); + byte[] responsePacket = null; try { - using (UdpClient client = new UdpClient(endPoint.AddressFamily)) + + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch Port info."); + + using (CancellationTokenSource sendTimeoutCancellationTokenSource = new CancellationTokenSource(s_sendTimeout)) + using (CancellationTokenSource sendCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, sendTimeoutCancellationTokenSource.Token)) { +#if NET6_0_OR_GREATER + ValueTask sendTask = client.SendAsync(requestPacket.AsMemory(), endPoint, sendCancellationTokenSource.Token); + + if (async) + { + await sendTask.ConfigureAwait(false); + } + else + { + if (!sendTask.IsCompleted) + { + sendTask.AsTask().Wait(); + } + } +#else Task sendTask = client.SendAsync(requestPacket, requestPacket.Length, endPoint); - Task receiveTask = null; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch Port info."); - if (sendTask.Wait(sendTimeOutMs) && (receiveTask = client.ReceiveAsync()).Wait(receiveTimeOutMs)) + sendTask.Wait(sendCancellationTokenSource.Token); +#endif + } + + UdpReceiveResult receiveResult; + + using (CancellationTokenSource receiveTimeoutCancellationTokenSource = new CancellationTokenSource(s_receiveTimeout)) + using (CancellationTokenSource receiveCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, receiveTimeoutCancellationTokenSource.Token)) + { +#if NET6_0_OR_GREATER + ValueTask receiveTask = client.ReceiveAsync(receiveCancellationTokenSource.Token); + + if (async) + { + receiveResult = await receiveTask.ConfigureAwait(false); + } + else { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received Port info from UDP Client."); - result.ResponsePacket = receiveTask.Result.Buffer; + if (!receiveTask.IsCompleted) + { + receiveTask.AsTask().Wait(); + } + + receiveResult = receiveTask.Result; } +#else + Task receiveTask = client.ReceiveAsync(); + + receiveTask.Wait(receiveCancellationTokenSource.Token); + receiveResult = receiveTask.Result; +#endif + + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received Port info from UDP Client."); + responsePacket = receiveResult.Buffer; } + + } + catch (AggregateException ae) when (ae.InnerException is OperationCanceledException) + { + responsePacket = null; } catch (AggregateException ae) { if (ae.InnerExceptions.Count > 0) { + Exception firstSocketException = null; + // Log all errors foreach (Exception e in ae.InnerExceptions) { // Favor SocketException for returned error if (e is SocketException) { - result.Error = e; + firstSocketException = e; } SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, - "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint.ToString(), args1: e.Message); + "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint, args1: e.Message); } - // Return first error if we didn't find a SocketException - result.Error = result.Error == null ? ae.InnerExceptions[0] : result.Error; + // Throw first error if we didn't find a SocketException + throw firstSocketException ?? ae.InnerExceptions[0]; } else { - result.Error = ae; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, - "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint.ToString(), args1: ae.Message); + "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint, args1: ae.Message); + throw; } } catch (Exception e) { - result.Error = e; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, - "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint.ToString(), args1: e.Message); - } - - return result; - } - - /// - /// Sends request to server, and recieves response from server (SQLBrowser) on port 1434 by UDP - /// Request (https://docs.microsoft.com/en-us/openspecs/windows_protocols/mc-sqlr/a3035afa-c268-4699-b8fd-4f351e5c8e9e) - /// Response (https://docs.microsoft.com/en-us/openspecs/windows_protocols/mc-sqlr/2e1560c9-5097-4023-9f5e-72b9ff1ec3b1) - /// - /// string constaning list of SVR_RESP(just RESP_DATA) - internal static string SendBroadcastUDPRequest() - { - StringBuilder response = new StringBuilder(); - byte[] CLNT_BCAST_EX_Request = new byte[1] { CLNT_BCAST_EX }; //0x02 - // Waits 5 seconds for the first response and every 1 second up to 15 seconds - // https://docs.microsoft.com/en-us/openspecs/windows_protocols/mc-sqlr/f2640a2d-3beb-464b-a443-f635842ebc3e#Appendix_A_3 - int currentTimeOut = FirstTimeoutForCLNT_BCAST_EX; - - using (TrySNIEventScope.Create(nameof(SSRP))) - { - using (UdpClient clientListener = new UdpClient()) - { - Task sendTask = clientListener.SendAsync(CLNT_BCAST_EX_Request, CLNT_BCAST_EX_Request.Length, new IPEndPoint(IPAddress.Broadcast, SqlServerBrowserPort)); - Task receiveTask = null; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch list of instances."); - Stopwatch sw = new Stopwatch(); //for waiting until 15 sec elapsed - sw.Start(); - try - { - while ((receiveTask = clientListener.ReceiveAsync()).Wait(currentTimeOut) && sw.ElapsedMilliseconds <= RecieveMAXTimeoutsForCLNT_BCAST_EX && receiveTask != null) - { - currentTimeOut = RecieveTimeoutsForCLNT_BCAST_EX; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received instnace info from UDP Client."); - if (receiveTask.Result.Buffer.Length < ValidResponseSizeForCLNT_BCAST_EX) //discard invalid response - { - response.Append(Encoding.ASCII.GetString(receiveTask.Result.Buffer, ServerResponseHeaderSizeForCLNT_BCAST_EX, receiveTask.Result.Buffer.Length - ServerResponseHeaderSizeForCLNT_BCAST_EX)); //RESP_DATA(VARIABLE) - 3 (RESP_SIZE + SVR_RESP) - } - } - } - finally - { - sw.Stop(); - } - } + "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint, args1: e.Message); + throw; } - return response.ToString(); - } - - private static void SplitIPv4AndIPv6(IPAddress[] input, out IPAddress[] ipv4Addresses, out IPAddress[] ipv6Addresses) - { - ipv4Addresses = Array.Empty(); - ipv6Addresses = Array.Empty(); - - if (input != null && input.Length > 0) - { - List v4 = new List(1); - List v6 = new List(0); - - for (int index = 0; index < input.Length; index++) - { - switch (input[index].AddressFamily) - { - case AddressFamily.InterNetwork: - v4.Add(input[index]); - break; - case AddressFamily.InterNetworkV6: - v6.Add(input[index]); - break; - } - } - - if (v4.Count > 0) - { - ipv4Addresses = v4.ToArray(); - } - if (v6.Count > 0) - { - ipv6Addresses = v6.ToArray(); - } - } +#if NET6_0_OR_GREATER + return responsePacket; +#else + return Task.FromResult(responsePacket); +#endif } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs index be8d1a0160..3e1b9e7019 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; +using System.Buffers.Binary; using System.Threading; using System.Threading.Tasks; @@ -63,7 +64,7 @@ public override int Read(Span buffer) } while (headerBytesRead < TdsEnums.HEADER_LEN); // read the packet data size from the header and store it in case it is needed for a subsequent call - _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + _packetBytes = BinaryPrimitives.ReadUInt16BigEndian(headerBytes.Slice(TdsEnums.HEADER_LEN_FIELD_OFFSET)) - TdsEnums.HEADER_LEN; // read as much from the packet as the caller can accept int packetBytesRead = _stream.Read(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes))); @@ -149,7 +150,7 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation } while (headerBytesRead < TdsEnums.HEADER_LEN); // read the packet data size from the header and store it in case it is needed for a subsequent call - _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + _packetBytes = BinaryPrimitives.ReadUInt16BigEndian(headerBytes.AsSpan(TdsEnums.HEADER_LEN_FIELD_OFFSET)) - TdsEnums.HEADER_LEN; ArrayPool.Shared.Return(headerBytes, clearArray: true); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 7e04f6b9c5..b2f76135bd 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -729,6 +729,14 @@ internal void PutSession(TdsParserStateObject session) // prelogin so that we don't try to negotiate encryption again during ConsumePreLoginHandshake. _encryptionOption = EncryptionOptions.NOT_SUP; } + else if (encrypt == SqlConnectionEncryptOption.Mandatory) + { + _encryptionOption = EncryptionOptions.ON; + } + else + { + _encryptionOption = EncryptionOptions.OFF; + } // PreLoginHandshake buffer consists of: // 1) Standard header, with type = MT_PRELOGIN @@ -742,126 +750,94 @@ internal void PutSession(TdsParserStateObject session) // Initialize option offset into payload buffer // 5 bytes for each option (1 byte length, 2 byte offset, 2 byte payload length) - int offset = (int)PreLoginOptions.NUMOPT * 5 + 1; + ushort headerOffset = 0; + ushort headerLength = (ushort)PreLoginOptions.NUMOPT * 5; - byte[] payload = new byte[(int)PreLoginOptions.NUMOPT * 5 + TdsEnums.MAX_PRELOGIN_PAYLOAD_LENGTH]; - int payloadLength = 0; + ushort payloadStart = (ushort)(headerLength + 1); + ushort payloadOffset = payloadStart; + // The payload length is static for each connection string. The lengths of each option are well-known + // Version: 6 bytes; Encryption: 1 byte; Instance: (instanceName.Length + 1) bytes; + // Thread ID: 4 bytes; MARS enablement: 1 byte; Trace: (2 * GUID + 1 * uint); Federated Authentication: 1 byte + // End-of-payload marker: 1 byte + // .NET Core uses a static zero-length instance name, which suggests a 51-byte payload. + // .NET Framework allows a variable-length instance name (up to 254 bytes). This is up to 305 bytes at most. + int payloadLength = 6 + 1 + (instanceName.Length + 1) + 4 + 1 + (GUID_SIZE + GUID_SIZE + 4) + 1 + 1; - for (int option = (int)PreLoginOptions.VERSION; option < (int)PreLoginOptions.NUMOPT; option++) + int totalBufferLength = headerLength + payloadLength; + Span preLoginPacketBuffer = stackalloc byte[totalBufferLength]; + + for (byte option = (byte)PreLoginOptions.VERSION; option < (byte)PreLoginOptions.NUMOPT; option++) { - int optionDataSize = 0; + ushort optionDataSize = 0; + // Structure header: // Fill in the option - _physicalStateObj.WriteByte((byte)option); + preLoginPacketBuffer[headerOffset] = option; // Fill in the offset of the option data - _physicalStateObj.WriteByte((byte)((offset & 0xff00) >> 8)); // send upper order byte - _physicalStateObj.WriteByte((byte)(offset & 0x00ff)); // send lower order byte + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(headerOffset + 1), payloadOffset); switch (option) { - case (int)PreLoginOptions.VERSION: + case (byte)PreLoginOptions.VERSION: Version systemDataVersion = ADP.GetAssemblyVersion(); // Major and minor - payload[payloadLength++] = (byte)(systemDataVersion.Major & 0xff); - payload[payloadLength++] = (byte)(systemDataVersion.Minor & 0xff); + preLoginPacketBuffer[payloadOffset] = (byte)(systemDataVersion.Major & 0xff); + preLoginPacketBuffer[payloadOffset + 1] = (byte)(systemDataVersion.Minor & 0xff); // Build (Big Endian) - payload[payloadLength++] = (byte)((systemDataVersion.Build & 0xff00) >> 8); - payload[payloadLength++] = (byte)(systemDataVersion.Build & 0xff); + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(payloadOffset + 2), (ushort)(systemDataVersion.Build & 0xFFFF)); // Sub-build (Little Endian) - payload[payloadLength++] = (byte)(systemDataVersion.Revision & 0xff); - payload[payloadLength++] = (byte)((systemDataVersion.Revision & 0xff00) >> 8); - offset += 6; + BinaryPrimitives.WriteUInt16LittleEndian(preLoginPacketBuffer.Slice(payloadOffset + 2 + sizeof(ushort)), (ushort)(systemDataVersion.Revision & 0xFFFF)); + optionDataSize = 6; break; - case (int)PreLoginOptions.ENCRYPT: - if (_encryptionOption == EncryptionOptions.NOT_SUP) - { - //If OS doesn't support encryption and encryption is not required, inform server "not supported" by client. - payload[payloadLength] = (byte)EncryptionOptions.NOT_SUP; - } - else - { - // Else, inform server of user request. - if (encrypt == SqlConnectionEncryptOption.Mandatory) - { - payload[payloadLength] = (byte)EncryptionOptions.ON; - _encryptionOption = EncryptionOptions.ON; - } - else - { - payload[payloadLength] = (byte)EncryptionOptions.OFF; - _encryptionOption = EncryptionOptions.OFF; - } - } + case (byte)PreLoginOptions.ENCRYPT: + preLoginPacketBuffer[payloadOffset] = (byte)_encryptionOption; - payloadLength += 1; - offset += 1; optionDataSize = 1; break; - case (int)PreLoginOptions.INSTANCE: - int i = 0; - - while (instanceName[i] != 0) - { - payload[payloadLength] = instanceName[i]; - payloadLength++; - i++; - } - - payload[payloadLength] = 0; // null terminate - payloadLength++; - i++; + case (byte)PreLoginOptions.INSTANCE: + instanceName.CopyTo(preLoginPacketBuffer.Slice(payloadOffset)); + preLoginPacketBuffer[payloadOffset + instanceName.Length] = 0; - offset += i; - optionDataSize = i; + optionDataSize = (ushort)(instanceName.Length + 1); break; - case (int)PreLoginOptions.THREADID: + case (byte)PreLoginOptions.THREADID: int threadID = TdsParserStaticMethods.GetCurrentThreadIdForTdsLoginOnly(); - payload[payloadLength++] = (byte)((0xff000000 & threadID) >> 24); - payload[payloadLength++] = (byte)((0x00ff0000 & threadID) >> 16); - payload[payloadLength++] = (byte)((0x0000ff00 & threadID) >> 8); - payload[payloadLength++] = (byte)(0x000000ff & threadID); - offset += 4; + BinaryPrimitives.WriteInt32BigEndian(preLoginPacketBuffer.Slice(payloadOffset), threadID); + optionDataSize = 4; break; - case (int)PreLoginOptions.MARS: - payload[payloadLength++] = (byte)(_fMARS ? 1 : 0); - offset += 1; - optionDataSize += 1; + case (byte)PreLoginOptions.MARS: + preLoginPacketBuffer[payloadOffset] = (byte)(_fMARS ? 1 : 0); + + optionDataSize = 1; break; - case (int)PreLoginOptions.TRACEID: - FillGuidBytes(_connHandler._clientConnectionId, payload.AsSpan(payloadLength, GUID_SIZE)); - payloadLength += GUID_SIZE; - offset += GUID_SIZE; - optionDataSize = GUID_SIZE; + case (byte)PreLoginOptions.TRACEID: + FillGuidBytes(_connHandler._clientConnectionId, preLoginPacketBuffer.Slice(payloadOffset)); ActivityCorrelator.ActivityId actId = ActivityCorrelator.Next(); - FillGuidBytes(actId.Id, payload.AsSpan(payloadLength, GUID_SIZE)); - payloadLength += GUID_SIZE; - payload[payloadLength++] = (byte)(0x000000ff & actId.Sequence); - payload[payloadLength++] = (byte)((0x0000ff00 & actId.Sequence) >> 8); - payload[payloadLength++] = (byte)((0x00ff0000 & actId.Sequence) >> 16); - payload[payloadLength++] = (byte)((0xff000000 & actId.Sequence) >> 24); - int actIdSize = GUID_SIZE + sizeof(uint); - offset += actIdSize; - optionDataSize += actIdSize; + + FillGuidBytes(actId.Id, preLoginPacketBuffer.Slice(payloadOffset + GUID_SIZE)); + BinaryPrimitives.WriteUInt32LittleEndian(preLoginPacketBuffer.Slice(payloadOffset + GUID_SIZE + GUID_SIZE), actId.Sequence); + + optionDataSize = GUID_SIZE + GUID_SIZE + sizeof(uint); SqlClientEventSource.Log.TryTraceEvent(" ClientConnectionID {0}, ActivityID {1}", _connHandler?._clientConnectionId, actId); break; - case (int)PreLoginOptions.FEDAUTHREQUIRED: - payload[payloadLength++] = 0x01; - offset += 1; - optionDataSize += 1; + case (byte)PreLoginOptions.FEDAUTHREQUIRED: + preLoginPacketBuffer[payloadOffset] = 0x01; + + optionDataSize = 1; break; default: @@ -869,19 +845,22 @@ internal void PutSession(TdsParserStateObject session) break; } + payloadOffset += optionDataSize; + // Write data length - _physicalStateObj.WriteByte((byte)((optionDataSize & 0xff00) >> 8)); - _physicalStateObj.WriteByte((byte)(optionDataSize & 0x00ff)); + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(headerOffset + 1 + sizeof(ushort)), optionDataSize); + + headerOffset += 1 + sizeof(ushort) + sizeof(ushort); } // Write out last option - to let server know the second part of packet completed - _physicalStateObj.WriteByte((byte)PreLoginOptions.LASTOPT); + preLoginPacketBuffer[headerOffset] = (byte)PreLoginOptions.LASTOPT; - // Write out payload - _physicalStateObj.WriteByteArray(payload, payloadLength, 0); + // Write out the full byte buffer + _physicalStateObj.WriteByteSpan(preLoginPacketBuffer); // Flush packet - _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH); + _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH)?.Wait(); } private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integratedSecurity, string serverCertificateFilename) @@ -963,7 +942,11 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ { throw SQL.ParsingError(); } - byte[] payload = new byte[_physicalStateObj._inBytesPacket]; + // Most of the time, this response packet will be very small (less than 512 bytes.) + // In such a situation, borrow stack space rather than requesting an array. + Span payload = _physicalStateObj._inBytesPacket < 512 + ? stackalloc byte[_physicalStateObj._inBytesPacket] + : new byte[_physicalStateObj._inBytesPacket]; Debug.Assert(_physicalStateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); result = _physicalStateObj.TryReadByteArray(payload, payload.Length); @@ -979,24 +962,22 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ throw SQL.InvalidSQLServerVersionUnknown(); } - int offset = 0; - int payloadOffset = 0; - int payloadLength = 0; - int option = payload[offset++]; + int headerOffset = 0; + ushort payloadOffset = 0; + ushort payloadLength = 0; + byte option = payload[headerOffset++]; bool serverSupportsEncryption = false; while (option != (byte)PreLoginOptions.LASTOPT) { + payloadOffset = BinaryPrimitives.ReadUInt16BigEndian(payload.Slice(headerOffset)); + payloadLength = BinaryPrimitives.ReadUInt16BigEndian(payload.Slice(headerOffset + 2)); + switch (option) { - case (int)PreLoginOptions.VERSION: - payloadOffset = payload[offset++] << 8 | payload[offset++]; - payloadLength = payload[offset++] << 8 | payload[offset++]; - + case (byte)PreLoginOptions.VERSION: byte majorVersion = payload[payloadOffset]; byte minorVersion = payload[payloadOffset + 1]; - int level = (payload[payloadOffset + 2] << 8) | - payload[payloadOffset + 3]; is2005OrLater = majorVersion >= 9; if (!is2005OrLater) @@ -1006,17 +987,13 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ break; - case (int)PreLoginOptions.ENCRYPT: + case (byte)PreLoginOptions.ENCRYPT: if (tlsFirst) { // Can skip/ignore this option if we are doing TDS 8. - offset += 4; break; } - payloadOffset = payload[offset++] << 8 | payload[offset++]; - payloadLength = payload[offset++] << 8 | payload[offset++]; - EncryptionOptions serverOption = (EncryptionOptions)payload[payloadOffset]; /* internal enum EncryptionOptions { @@ -1069,11 +1046,8 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ break; - case (int)PreLoginOptions.INSTANCE: - payloadOffset = payload[offset++] << 8 | payload[offset++]; - payloadLength = payload[offset++] << 8 | payload[offset++]; - - byte ERROR_INST = 0x1; + case (byte)PreLoginOptions.INSTANCE: + const byte ERROR_INST = 0x1; byte instanceResult = payload[payloadOffset]; if (instanceResult == ERROR_INST) @@ -1086,29 +1060,21 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ break; - case (int)PreLoginOptions.THREADID: + case (byte)PreLoginOptions.THREADID: // DO NOTHING FOR THREADID - offset += 4; break; - case (int)PreLoginOptions.MARS: - payloadOffset = payload[offset++] << 8 | payload[offset++]; - payloadLength = payload[offset++] << 8 | payload[offset++]; - - marsCapable = (payload[payloadOffset] == 0 ? false : true); + case (byte)PreLoginOptions.MARS: + marsCapable = payload[payloadOffset] != 0; Debug.Assert(payload[payloadOffset] == 0 || payload[payloadOffset] == 1, "Value for Mars PreLoginHandshake option not equal to 1 or 0!"); break; - case (int)PreLoginOptions.TRACEID: + case (byte)PreLoginOptions.TRACEID: // DO NOTHING FOR TRACEID - offset += 4; break; case (int)PreLoginOptions.FEDAUTHREQUIRED: - payloadOffset = payload[offset++] << 8 | payload[offset++]; - payloadLength = payload[offset++] << 8 | payload[offset++]; - // Only 0x00 and 0x01 are accepted values from the server. if (payload[payloadOffset] != 0x00 && payload[payloadOffset] != 0x01) { @@ -1129,17 +1095,16 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ break; default: - Debug.Fail("UNKNOWN option in ConsumePreLoginHandshake, option:" + option); - // DO NOTHING FOR THESE UNKNOWN OPTIONS - offset += 4; - + Debug.Fail("UNKNOWN option in ConsumePreLoginHandshake, option:" + option); break; } - if (offset < payload.Length) + headerOffset += sizeof(ushort) + sizeof(ushort); + + if (headerOffset < payload.Length) { - option = payload[offset++]; + option = payload[headerOffset++]; } else { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 1e0141dd58..22c64395f0 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -19,6 +19,7 @@ namespace Microsoft.Data.SqlClient.SNI { internal sealed class TdsParserStateObjectManaged : TdsParserStateObject { + private static readonly byte[] s_staticInstanceName = Array.Empty(); private SNIMarsConnection? _marsConnection; private SNIHandle? _sessionHandle; #if NET7_0_OR_GREATER @@ -56,7 +57,7 @@ protected override void CreateSessionHandle(TdsParserStateObject physicalConnect } } - internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async) + internal SNIMarsHandle CreateMarsSession(TdsParserStateObject callbackObject, bool async) { SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.CreateMarsSession | Info | State Object Id {0}, Session Id {1}, Async = {2}", _objectID, _sessionHandle?.ConnectionId, async); if (_marsConnection is null) @@ -98,10 +99,12 @@ protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref string hostNameInCertificate, string serverCertificateFilename) { - SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN, + SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, ref spnBuffer, serverSPN, flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); + instanceName = s_staticInstanceName; + if (sessionHandle is not null) { _sessionHandle = sessionHandle; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index 59776956a1..dffe5ac250 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -101,17 +101,17 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache result = SNINativeMethodWrapper.SniGetConnectionIPString(Handle, ref IPStringFromSNI); Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString"); - pendingDNSInfo = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString()); + pendingDNSInfo = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI); if (IPAddress.TryParse(IPStringFromSNI, out IPFromSNI)) { if (System.Net.Sockets.AddressFamily.InterNetwork == IPFromSNI.AddressFamily) { - pendingDNSInfo.AddrIPv4 = IPStringFromSNI; + pendingDNSInfo.CachedIPv4Address = IPFromSNI; } else if (System.Net.Sockets.AddressFamily.InterNetworkV6 == IPFromSNI.AddressFamily) { - pendingDNSInfo.AddrIPv6 = IPStringFromSNI; + pendingDNSInfo.CachedIPv6Address = IPFromSNI; } } } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 0a0757731b..16dc8713be 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -22,14 +22,14 @@ full - + $([System.IO.Path]::Combine('$(IntermediateOutputPath)','$(GeneratedSourceFileName)')) - + True @@ -689,7 +689,7 @@ Resources\StringsHelper.cs - + Resources\Strings.Designer.cs True True diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs index b9af5849c4..1bce41ddd8 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs @@ -1095,9 +1095,9 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info(); native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; - native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; - native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.CachedIPv4Address?.ToString(); + native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.CachedIPv6Address?.ToString(); + native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo?.Port.ToString(); return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo); } @@ -1107,7 +1107,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan string constring, ref IntPtr pConn, byte[] spnBuffer, - byte[] instanceName, + Span instanceName, bool fOverrideCache, bool fSync, int timeout, @@ -1119,7 +1119,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan SQLDNSInfo cachedDNSInfo, string hostNameInCertificate) { - fixed (byte* pin_instanceName = &instanceName[0]) + fixed (byte* pin_instanceName = instanceName) { SNI_CLIENT_CONSUMER_INFO clientConsumerInfo = new SNI_CLIENT_CONSUMER_INFO(); @@ -1154,9 +1154,9 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan clientConsumerInfo.ipAddressPreference = ipPreference; clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.CachedIPv4Address?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.CachedIPv6Address?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo?.Port.ToString(); if (spnBuffer != null) { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 7a9bbfdfd3..504a555924 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -25,6 +25,7 @@ using Microsoft.Data.SqlTypes; using Microsoft.SqlServer.Server; using Microsoft.Data.ProviderBase; +using System.Buffers.Binary; namespace Microsoft.Data.SqlClient { @@ -854,17 +855,17 @@ internal void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey) result = SNINativeMethodWrapper.SniGetConnectionIPString(_physicalStateObj.Handle, ref IPStringFromSNI); Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString"); - _connHandler.pendingSQLDNSObject = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString()); + _connHandler.pendingSQLDNSObject = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI); if (IPAddress.TryParse(IPStringFromSNI, out IPFromSNI)) { if (System.Net.Sockets.AddressFamily.InterNetwork == IPFromSNI.AddressFamily) { - _connHandler.pendingSQLDNSObject.AddrIPv4 = IPStringFromSNI; + _connHandler.pendingSQLDNSObject.CachedIPv4Address = IPFromSNI; } else if (System.Net.Sockets.AddressFamily.InterNetworkV6 == IPFromSNI.AddressFamily) { - _connHandler.pendingSQLDNSObject.AddrIPv6 = IPStringFromSNI; + _connHandler.pendingSQLDNSObject.CachedIPv6Address = IPFromSNI; } } } @@ -1059,6 +1060,27 @@ internal void BestEffortCleanup() // prelogin so that we don't try to negotiate encryption again during ConsumePreLoginHandshake. _encryptionOption = EncryptionOptions.NOT_SUP; } + else + { + if (encrypt == SqlConnectionEncryptOption.Mandatory) + { + _encryptionOption = EncryptionOptions.ON; + } + else + { + _encryptionOption = EncryptionOptions.OFF; + } + + if (clientCertificate) + { + _encryptionOption |= EncryptionOptions.CLIENT_CERT; + } + } + + if (useCtaip) + { + _encryptionOption |= EncryptionOptions.CTAIP; + } // PreLoginHandshake buffer consists of: // 1) Standard header, with type = MT_PRELOGIN @@ -1072,147 +1094,100 @@ internal void BestEffortCleanup() // Initialize option offset into payload buffer // 5 bytes for each option (1 byte length, 2 byte offset, 2 byte payload length) - int offset = (int)PreLoginOptions.NUMOPT * 5 + 1; + ushort headerOffset = 0; + ushort headerLength = (ushort)PreLoginOptions.NUMOPT * 5; - byte[] payload = new byte[(int)PreLoginOptions.NUMOPT * 5 + TdsEnums.MAX_PRELOGIN_PAYLOAD_LENGTH]; - int payloadLength = 0; + ushort payloadStart = (ushort)(headerLength + 1); + ushort payloadOffset = payloadStart; + // The payload length is static for each connection string. The lengths of each option are well-known + // Version: 6 bytes; Encryption: 1 byte; Instance: (instanceName.Length + 1) bytes; + // Thread ID: 4 bytes; MARS enablement: 1 byte; Trace: (2 * GUID + 1 * uint); Federated Authentication: 1 byte + // End-of-payload marker: 1 byte + // .NET Core uses a static zero-length instance name, which suggests a 51-byte payload. + // .NET Framework allows a variable-length instance name (up to 254 bytes). This is up to 305 bytes at most. + int payloadLength = 6 + 1 + (instanceName.Length + 1) + 4 + 1 + (GUID_SIZE + GUID_SIZE + 4) + 1 + 1; - // UNDONE - need to do some length verification to ensure packet does not - // get too big!!! Not beyond it's max length! + int totalBufferLength = headerLength + payloadLength; + byte[] preLoginPacketBufferArray = new byte[totalBufferLength]; + Span preLoginPacketBuffer = preLoginPacketBufferArray; - for (int option = (int)PreLoginOptions.VERSION; option < (int)PreLoginOptions.NUMOPT; option++) + for (byte option = (byte)PreLoginOptions.VERSION; option < (byte)PreLoginOptions.NUMOPT; option++) { - int optionDataSize = 0; + ushort optionDataSize = 0; + // Structure header: // Fill in the option - _physicalStateObj.WriteByte((byte)option); + preLoginPacketBuffer[headerOffset] = option; // Fill in the offset of the option data - _physicalStateObj.WriteByte((byte)((offset & 0xff00) >> 8)); // send upper order byte - _physicalStateObj.WriteByte((byte)(offset & 0x00ff)); // send lower order byte + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(headerOffset + 1), payloadOffset); switch (option) { - case (int)PreLoginOptions.VERSION: + case (byte)PreLoginOptions.VERSION: Version systemDataVersion = ADP.GetAssemblyVersion(); // Major and minor - payload[payloadLength++] = (byte)(systemDataVersion.Major & 0xff); - payload[payloadLength++] = (byte)(systemDataVersion.Minor & 0xff); + preLoginPacketBuffer[payloadOffset] = (byte)(systemDataVersion.Major & 0xff); + preLoginPacketBuffer[payloadOffset + 1] = (byte)(systemDataVersion.Minor & 0xff); // Build (Big Endian) - payload[payloadLength++] = (byte)((systemDataVersion.Build & 0xff00) >> 8); - payload[payloadLength++] = (byte)(systemDataVersion.Build & 0xff); + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(payloadOffset + 2), (ushort)(systemDataVersion.Build & 0xFFFF)); // Sub-build (Little Endian) - payload[payloadLength++] = (byte)(systemDataVersion.Revision & 0xff); - payload[payloadLength++] = (byte)((systemDataVersion.Revision & 0xff00) >> 8); - offset += 6; + BinaryPrimitives.WriteUInt16LittleEndian(preLoginPacketBuffer.Slice(payloadOffset + 2 + sizeof(ushort)), (ushort)(systemDataVersion.Revision & 0xFFFF)); + optionDataSize = 6; break; - case (int)PreLoginOptions.ENCRYPT: - if (_encryptionOption == EncryptionOptions.NOT_SUP) - { - //If OS doesn't support encryption and encryption is not required, inform server "not supported" by client. - payload[payloadLength] = (byte)EncryptionOptions.NOT_SUP; - } - else - { - // Else, inform server of user request. - if (encrypt == SqlConnectionEncryptOption.Mandatory) - { - payload[payloadLength] = (byte)EncryptionOptions.ON; - _encryptionOption = EncryptionOptions.ON; - } - else - { - payload[payloadLength] = (byte)EncryptionOptions.OFF; - _encryptionOption = EncryptionOptions.OFF; - } - - // Inform server of user request. - if (clientCertificate) - { - payload[payloadLength] |= (byte)EncryptionOptions.CLIENT_CERT; - _encryptionOption |= EncryptionOptions.CLIENT_CERT; - } - } - - // Add CTAIP if requested. - if (useCtaip) - { - payload[payloadLength] |= (byte)EncryptionOptions.CTAIP; - _encryptionOption |= EncryptionOptions.CTAIP; - } + case (byte)PreLoginOptions.ENCRYPT: + preLoginPacketBuffer[payloadOffset] = (byte)_encryptionOption; - payloadLength += 1; - offset += 1; optionDataSize = 1; break; - case (int)PreLoginOptions.INSTANCE: - int i = 0; - - while (instanceName[i] != 0) - { - payload[payloadLength] = instanceName[i]; - payloadLength++; - i++; - } + case (byte)PreLoginOptions.INSTANCE: + instanceName.CopyTo(preLoginPacketBuffer.Slice(payloadOffset)); + preLoginPacketBuffer[payloadOffset + instanceName.Length] = 0; - payload[payloadLength] = 0; // null terminate - payloadLength++; - i++; - - offset += i; - optionDataSize = i; + optionDataSize = (ushort)(instanceName.Length + 1); break; - case (int)PreLoginOptions.THREADID: - Int32 threadID = TdsParserStaticMethods.GetCurrentThreadIdForTdsLoginOnly(); + case (byte)PreLoginOptions.THREADID: + int threadID = TdsParserStaticMethods.GetCurrentThreadIdForTdsLoginOnly(); + + BinaryPrimitives.WriteInt32BigEndian(preLoginPacketBuffer.Slice(payloadOffset), threadID); - payload[payloadLength++] = (byte)((0xff000000 & threadID) >> 24); - payload[payloadLength++] = (byte)((0x00ff0000 & threadID) >> 16); - payload[payloadLength++] = (byte)((0x0000ff00 & threadID) >> 8); - payload[payloadLength++] = (byte)(0x000000ff & threadID); - offset += 4; optionDataSize = 4; break; - case (int)PreLoginOptions.MARS: - payload[payloadLength++] = (byte)(_fMARS ? 1 : 0); - offset += 1; - optionDataSize += 1; + case (byte)PreLoginOptions.MARS: + preLoginPacketBuffer[payloadOffset] = (byte)(_fMARS ? 1 : 0); + + optionDataSize = 1; break; - case (int)PreLoginOptions.TRACEID: + case (byte)PreLoginOptions.TRACEID: byte[] connectionIdBytes = _connHandler._clientConnectionId.ToByteArray(); + Debug.Assert(GUID_SIZE == connectionIdBytes.Length); - Buffer.BlockCopy(connectionIdBytes, 0, payload, payloadLength, GUID_SIZE); - payloadLength += GUID_SIZE; - offset += GUID_SIZE; - optionDataSize = GUID_SIZE; + connectionIdBytes.CopyTo(preLoginPacketBuffer.Slice(payloadOffset)); ActivityCorrelator.ActivityId actId = ActivityCorrelator.Next(); + connectionIdBytes = actId.Id.ToByteArray(); - Buffer.BlockCopy(connectionIdBytes, 0, payload, payloadLength, GUID_SIZE); - payloadLength += GUID_SIZE; - payload[payloadLength++] = (byte)(0x000000ff & actId.Sequence); - payload[payloadLength++] = (byte)((0x0000ff00 & actId.Sequence) >> 8); - payload[payloadLength++] = (byte)((0x00ff0000 & actId.Sequence) >> 16); - payload[payloadLength++] = (byte)((0xff000000 & actId.Sequence) >> 24); - int actIdSize = GUID_SIZE + sizeof(UInt32); - offset += actIdSize; - optionDataSize += actIdSize; + connectionIdBytes.CopyTo(preLoginPacketBuffer.Slice(payloadOffset + GUID_SIZE)); + BinaryPrimitives.WriteUInt32LittleEndian(preLoginPacketBuffer.Slice(payloadOffset + GUID_SIZE + GUID_SIZE), actId.Sequence); + + optionDataSize = GUID_SIZE + GUID_SIZE + sizeof(uint); SqlClientEventSource.Log.TryTraceEvent(" ClientConnectionID {0}, ActivityID {1}", _connHandler._clientConnectionId, actId); break; - case (int)PreLoginOptions.FEDAUTHREQUIRED: - payload[payloadLength++] = 0x01; - offset += 1; - optionDataSize += 1; + case (byte)PreLoginOptions.FEDAUTHREQUIRED: + preLoginPacketBuffer[payloadOffset] = 0x01; + + optionDataSize = 1; break; default: @@ -1220,19 +1195,22 @@ internal void BestEffortCleanup() break; } + payloadOffset += optionDataSize; + // Write data length - _physicalStateObj.WriteByte((byte)((optionDataSize & 0xff00) >> 8)); - _physicalStateObj.WriteByte((byte)(optionDataSize & 0x00ff)); + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(headerOffset + 1 + sizeof(ushort)), optionDataSize); + + headerOffset += 1 + sizeof(ushort) + sizeof(ushort); } // Write out last option - to let server know the second part of packet completed - _physicalStateObj.WriteByte((byte)PreLoginOptions.LASTOPT); + preLoginPacketBuffer[headerOffset] = (byte)PreLoginOptions.LASTOPT; // Write out payload - _physicalStateObj.WriteByteArray(payload, payloadLength, 0); + _physicalStateObj.WriteByteArray(preLoginPacketBufferArray, preLoginPacketBufferArray.Length, 0); // Flush packet - _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH); + _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH)?.Wait(); } private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integratedSecurity, string serverCertificate, ServerCertificateValidationCallback serverCallback, ClientCertificateRetrievalCallback clientCallback) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs index 9d4136d01f..5732c83092 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Concurrent; +using System.Net; namespace Microsoft.Data.SqlClient { @@ -28,13 +29,12 @@ internal bool AddDNSInfo(SQLDNSInfo item) { if (null != item) { - if (DNSInfoCache.ContainsKey(item.FQDN)) - { - - DeleteDNSInfo(item.FQDN); - } - - return DNSInfoCache.TryAdd(item.FQDN, item); +#if NET6_0_OR_GREATER || NETSTANDARD2_1 + DNSInfoCache.AddOrUpdate(item.FQDN, static (key, state) => state, static (key, value, state) => state, item); +#else + DNSInfoCache.AddOrUpdate(item.FQDN, item, (key, value) => item); +#endif + return true; } return false; @@ -42,8 +42,7 @@ internal bool AddDNSInfo(SQLDNSInfo item) internal bool DeleteDNSInfo(string FQDN) { - SQLDNSInfo value; - return DNSInfoCache.TryRemove(FQDN, out value); + return DNSInfoCache.TryRemove(FQDN, out _); } internal bool GetDNSInfo(string FQDN, out SQLDNSInfo result) @@ -58,8 +57,8 @@ internal bool IsDuplicate(SQLDNSInfo newItem) SQLDNSInfo oldItem; if (GetDNSInfo(newItem.FQDN, out oldItem)) { - return (newItem.AddrIPv4 == oldItem.AddrIPv4 && - newItem.AddrIPv6 == oldItem.AddrIPv6 && + return (newItem.CachedIPv4Address == oldItem.CachedIPv4Address && + newItem.CachedIPv6Address == oldItem.CachedIPv6Address && newItem.Port == oldItem.Port); } } @@ -71,16 +70,23 @@ internal bool IsDuplicate(SQLDNSInfo newItem) internal sealed class SQLDNSInfo { public string FQDN { get; set; } - public string AddrIPv4 { get; set; } - public string AddrIPv6 { get; set; } - public string Port { get; set; } + public IPAddress CachedIPv4Address { get; set; } + public IPAddress CachedIPv6Address { get; set; } + public int Port { get; set; } + public IPAddress[] SpeculativeIPAddresses { get; set; } - internal SQLDNSInfo(string FQDN, string ipv4, string ipv6, string port) + internal SQLDNSInfo(string fqdn, IPAddress ipv4, IPAddress ipv6, int port) { - this.FQDN = FQDN; - AddrIPv4 = ipv4; - AddrIPv6 = ipv6; + FQDN = fqdn; + CachedIPv4Address = ipv4; + CachedIPv6Address = ipv6; Port = port; } + + internal SQLDNSInfo(string fqdn, IPAddress[] speculativeIPAddresses) + { + FQDN = fqdn; + SpeculativeIPAddresses = speculativeIPAddresses; + } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs index 8d276cd285..facb214a72 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs @@ -18,8 +18,8 @@ internal sealed partial class SNILoadHandle : SafeHandle { internal static readonly SNILoadHandle SingletonInstance = new SNILoadHandle(); - internal readonly SNINativeMethodWrapper.SqlAsyncCallbackDelegate ReadAsyncCallbackDispatcher = new SNINativeMethodWrapper.SqlAsyncCallbackDelegate(ReadDispatcher); - internal readonly SNINativeMethodWrapper.SqlAsyncCallbackDelegate WriteAsyncCallbackDispatcher = new SNINativeMethodWrapper.SqlAsyncCallbackDelegate(WriteDispatcher); + internal readonly SNINativeMethodWrapper.SqlAsyncCallbackDelegate ReadAsyncCallbackDispatcher = ReadDispatcher; + internal readonly SNINativeMethodWrapper.SqlAsyncCallbackDelegate WriteAsyncCallbackDispatcher = WriteDispatcher; private readonly uint _sniStatus = TdsEnums.SNI_UNINITIALIZED; private readonly EncryptionOptions _encryptionOption = EncryptionOptions.OFF; @@ -164,6 +164,8 @@ internal sealed class SNIHandle : SafeHandle string hostNameInCertificate) : base(IntPtr.Zero, true) { + Span instanceNameBuffer = stackalloc byte[256]; + #if !NET6_0_OR_GREATER RuntimeHelpers.PrepareConstrainedRegions(); #endif @@ -172,7 +174,6 @@ internal sealed class SNIHandle : SafeHandle finally { _fSync = fSync; - instanceName = new byte[256]; // Size as specified by netlibs. // Option ignoreSniOpenTimeout is no longer available //if (ignoreSniOpenTimeout) //{ @@ -185,12 +186,16 @@ internal sealed class SNIHandle : SafeHandle #if NETFRAMEWORK int transparentNetworkResolutionStateNo = (int)transparentNetworkResolutionState; _status = SNINativeMethodWrapper.SNIOpenSyncEx(myInfo, serverName, ref base.handle, - spnBuffer, instanceName, flushCache, fSync, timeout, fParallel, transparentNetworkResolutionStateNo, totalTimeout, + spnBuffer, instanceNameBuffer, flushCache, fSync, timeout, fParallel, transparentNetworkResolutionStateNo, totalTimeout, ADP.IsAzureSqlServerEndpoint(serverName), ipPreference, cachedDNSInfo, hostNameInCertificate); #else _status = SNINativeMethodWrapper.SNIOpenSyncEx(myInfo, serverName, ref base.handle, - spnBuffer, instanceName, flushCache, fSync, timeout, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); + spnBuffer, instanceNameBuffer, flushCache, fSync, timeout, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); #endif // NETFRAMEWORK + + int instanceNameLength = instanceNameBuffer.IndexOf((byte)0); + + instanceName = instanceNameLength < 1 ? Array.Empty() : instanceNameBuffer.Slice(0, instanceNameLength).ToArray(); } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 17729a4cc9..29e0308c7f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -931,12 +931,10 @@ internal bool TryProcessHeader() { // All read _partialHeaderBytesRead = 0; - _inBytesPacket = ((int)_partialHeaderBuffer[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | - (int)_partialHeaderBuffer[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - _inputHeaderLen; + _inBytesPacket = BinaryPrimitives.ReadUInt16BigEndian(_partialHeaderBuffer.AsSpan(TdsEnums.HEADER_LEN_FIELD_OFFSET)) - _inputHeaderLen; _messageStatus = _partialHeaderBuffer[1]; - _spid = _partialHeaderBuffer[TdsEnums.SPID_OFFSET] << 8 | - _partialHeaderBuffer[TdsEnums.SPID_OFFSET + 1]; + _spid = BinaryPrimitives.ReadUInt16BigEndian(_partialHeaderBuffer.AsSpan(TdsEnums.SPID_OFFSET)); SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObject.TryProcessHeader | ADV | State Object Id {0}, Client Connection Id {1}, Server process Id (SPID) {2}", _objectID, _parser?.Connection?.ClientConnectionId, _spid); } @@ -972,10 +970,8 @@ internal bool TryProcessHeader() { // normal header processing... _messageStatus = _inBuff[_inBytesUsed + 1]; - _inBytesPacket = (_inBuff[_inBytesUsed + TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | - _inBuff[_inBytesUsed + TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - _inputHeaderLen; - _spid = _inBuff[_inBytesUsed + TdsEnums.SPID_OFFSET] << 8 | - _inBuff[_inBytesUsed + TdsEnums.SPID_OFFSET + 1]; + _inBytesPacket = BinaryPrimitives.ReadUInt16BigEndian(_inBuff.AsSpan(_inBytesUsed + TdsEnums.HEADER_LEN_FIELD_OFFSET)) - _inputHeaderLen; + _spid = BinaryPrimitives.ReadUInt16BigEndian(_inBuff.AsSpan(_inBytesUsed + TdsEnums.SPID_OFFSET)); #if !NETFRAMEWORK SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObject.TryProcessHeader | ADV | State Object Id {0}, Client Connection Id {1}, Server process Id (SPID) {2}", _objectID, _parser?.Connection?.ClientConnectionId, _spid); #endif diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs index a99e5d0303..45eaf3d718 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs @@ -12025,7 +12025,7 @@ internal class Strings { } /// - /// Looks up a localized string similar to Specified type is not registered on the target server.{0}.. + /// Looks up a localized string similar to Specified type is not registered on the target server. {0}.. /// internal static string SQLUDT_InvalidSqlType { get { diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.es.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.es.resx index 821f8f1aa5..bc65877f2e 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.es.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.es.resx @@ -3559,7 +3559,7 @@ La propiedad UdtTypeName debe establecerse sólo para los parámetros UDT. - El tipo especificado no está registrado en el servidor de destino.{0}. + El tipo especificado no está registrado en el servidor de destino. {0}. No se permiten parámetros UDT en la cláusula where, salvo que formen parte de la clave principal. @@ -4740,4 +4740,4 @@ El certificado no está disponible al validar el certificado. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.it.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.it.resx index add941c9ce..9499ee5478 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.it.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.it.resx @@ -3559,7 +3559,7 @@ È necessario impostare la proprietà UdtTypeName property solo per i parametri UDT. - I tipo specificato non è registrato sul server di destinazione.{0}. + I tipo specificato non è registrato sul server di destinazione. {0}. Parametri UDT non consentiti nella clausola WHERE a meno che non facciano parte della chiave primaria. @@ -4740,4 +4740,4 @@ Certificato non disponibile durante la convalida del certificato. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.ko.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.ko.resx index d3d62df164..d38de35c97 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.ko.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.ko.resx @@ -3559,7 +3559,7 @@ UdtTypeName 속성은 UDT 매개 변수에 대해서만 설정해야 합니다. - 지정한 유형이 대상 서버에 등록되어 있지 않습니다.{0}. + 지정한 유형이 대상 서버에 등록되어 있지 않습니다. {0}. 기본 키의 일부가 아닌 경우 UDT 매개 변수는 where 절에서 허용되지 않습니다. @@ -4740,4 +4740,4 @@ 인증서의 유효성을 검사하는 동안에는 인증서를 사용할 수 없습니다. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.pt-BR.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.pt-BR.resx index 3cb1ae77ec..a807de587f 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.pt-BR.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.pt-BR.resx @@ -3559,7 +3559,7 @@ A propriedade UdtTypeName deve ser definida apenas como parâmetros UDT. - O tipo especificado não está registrado no servidor de destino.{0}. + O tipo especificado não está registrado no servidor de destino. {0}. Os parâmetros UDT não são permitidos na cláusula where, a menos que façam parte da chave primária. @@ -4740,4 +4740,4 @@ Certificado não disponível durante a validação do certificado. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx index 3712ece88f..993cbdb8d4 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx @@ -3559,7 +3559,7 @@ UdtTypeName property must be set only for UDT parameters. - Specified type is not registered on the target server.{0}. + Specified type is not registered on the target server. {0}. UDT parameters not permitted in the where clause unless part of the primary key. @@ -4740,4 +4740,4 @@ Certificate not available while validating the certificate. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.ru.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.ru.resx index 24fb60a6e4..0fff1724af 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.ru.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.ru.resx @@ -3559,7 +3559,7 @@ Свойство UdtTypeName должно быть установлено только для параметров UDT. - Указанный тип не зарегистрирован на сервере назначения.{0}. + Указанный тип не зарегистрирован на сервере назначения. {0}. Не допускается использование параметров UDT в составе конструкции Where, кроме случаев, когда они являются частью первичного ключа. @@ -4740,4 +4740,4 @@ Сертификат недоступен при проверке сертификата. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs index 4f83a8aeb7..8e5d292aac 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs @@ -29,8 +29,8 @@ internal static class ConnectionHelper private static FieldInfo s_enforcedTimeoutDelayInMilliSeconds = s_tdsParserStateObject.GetField("_enforcedTimeoutDelayInMilliSeconds", BindingFlags.Instance | BindingFlags.NonPublic); private static FieldInfo s_pendingSQLDNSObject = s_sqlInternalConnectionTds.GetField("pendingSQLDNSObject", BindingFlags.Instance | BindingFlags.NonPublic); private static PropertyInfo s_pendingSQLDNS_FQDN = s_SQLDNSInfo.GetProperty("FQDN", BindingFlags.Instance | BindingFlags.Public); - private static PropertyInfo s_pendingSQLDNS_AddrIPv4 = s_SQLDNSInfo.GetProperty("AddrIPv4", BindingFlags.Instance | BindingFlags.Public); - private static PropertyInfo s_pendingSQLDNS_AddrIPv6 = s_SQLDNSInfo.GetProperty("AddrIPv6", BindingFlags.Instance | BindingFlags.Public); + private static PropertyInfo s_pendingSQLDNS_AddrIPv4 = s_SQLDNSInfo.GetProperty("CachedIPv4Address", BindingFlags.Instance | BindingFlags.Public); + private static PropertyInfo s_pendingSQLDNS_AddrIPv6 = s_SQLDNSInfo.GetProperty("CachedIPv6Address", BindingFlags.Instance | BindingFlags.Public); private static PropertyInfo s_pendingSQLDNS_Port = s_SQLDNSInfo.GetProperty("Port", BindingFlags.Instance | BindingFlags.Public); private static PropertyInfo dbConnectionInternalIsTransRoot = s_dbConnectionInternal.GetProperty("IsTransactionRoot", BindingFlags.Instance | BindingFlags.NonPublic); private static PropertyInfo dbConnectionInternalEnlistedTrans = s_sqlInternalConnection.GetProperty("EnlistedTransaction", BindingFlags.Instance | BindingFlags.NonPublic); @@ -112,16 +112,16 @@ public static void SetEnforcedTimeout(this SqlConnection connection, bool enforc /// /// Active connection to extract the requested data /// FQDN, AddrIPv4, AddrIPv6, and Port in sequence - public static Tuple GetSQLDNSInfo(this SqlConnection connection) + public static Tuple GetSQLDNSInfo(this SqlConnection connection) { object internalConnection = GetInternalConnection(connection); VerifyObjectIsInternalConnection(internalConnection); object pendingSQLDNSInfo = s_pendingSQLDNSObject.GetValue(internalConnection); string fqdn = s_pendingSQLDNS_FQDN.GetValue(pendingSQLDNSInfo) as string; - string ipv4 = s_pendingSQLDNS_AddrIPv4.GetValue(pendingSQLDNSInfo) as string; - string ipv6 = s_pendingSQLDNS_AddrIPv6.GetValue(pendingSQLDNSInfo) as string; - string port = s_pendingSQLDNS_Port.GetValue(pendingSQLDNSInfo) as string; - return new Tuple(fqdn, ipv4, ipv6, port); + System.Net.IPAddress ipv4 = s_pendingSQLDNS_AddrIPv4.GetValue(pendingSQLDNSInfo) as System.Net.IPAddress; + System.Net.IPAddress ipv6 = s_pendingSQLDNS_AddrIPv6.GetValue(pendingSQLDNSInfo) as System.Net.IPAddress; + int port = (int)s_pendingSQLDNS_Port.GetValue(pendingSQLDNSInfo); + return new Tuple(fqdn, ipv4, ipv6, port); } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConfigurableIpPreferenceTest/ConfigurableIpPreferenceTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConfigurableIpPreferenceTest/ConfigurableIpPreferenceTest.cs index 8003660889..d32d90066b 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConfigurableIpPreferenceTest/ConfigurableIpPreferenceTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConfigurableIpPreferenceTest/ConfigurableIpPreferenceTest.cs @@ -57,7 +57,7 @@ public void ConfigurableIpPreference(string ipPreference) { connection.Open(); Assert.Equal(ConnectionState.Open, connection.State); - Tuple DNSInfo = connection.GetSQLDNSInfo(); + Tuple DNSInfo = connection.GetSQLDNSInfo(); if(ipPreference == CnnPrefIPv4) { Assert.NotNull(DNSInfo.Item2); //IPv4 @@ -95,8 +95,8 @@ private void TestCachedConfigurableIpPreference(string ipPreference, string cnnS SQLFallbackDNSCacheGetDNSInfo.Invoke(SQLFallbackDNSCacheInstance, parameters); var dnsCacheEntry = parameters[1]; - const string AddrIPv4Property = "AddrIPv4"; - const string AddrIPv6Property = "AddrIPv6"; + const string AddrIPv4Property = "CachedIPv6Address"; + const string AddrIPv6Property = "CachedIPv6Address"; const string FQDNProperty = "FQDN"; Assert.NotNull(dnsCacheEntry); diff --git a/src/Microsoft.Data.SqlClient/tests/PerformanceTests/runnerconfig.json b/src/Microsoft.Data.SqlClient/tests/PerformanceTests/runnerconfig.json index 47d1736127..76f073cb67 100644 --- a/src/Microsoft.Data.SqlClient/tests/PerformanceTests/runnerconfig.json +++ b/src/Microsoft.Data.SqlClient/tests/PerformanceTests/runnerconfig.json @@ -1,12 +1,12 @@ { - "ConnectionString": "Server=tcp:localhost; Integrated Security=true; Initial Catalog=sqlclient-perf-db;", + "ConnectionString": "Server=tcp:localhost; Integrated Security=true; Initial Catalog=sqlclient-perf-db; Trust Server Certificate=Yes;", "UseManagedSniOnWindows": false, "Benchmarks": { "SqlConnectionRunnerConfig": { "Enabled": true, "LaunchCount": 1, "IterationCount": 50, - "InvocationCount":30, + "InvocationCount": 30, "WarmupCount": 5, "RowCount": 0 }, diff --git a/tools/props/Versions.props b/tools/props/Versions.props index 9c58ec3889..5e1694626e 100644 --- a/tools/props/Versions.props +++ b/tools/props/Versions.props @@ -75,7 +75,7 @@ 170.8.0 10.50.1600.1 160.1000.6 - 0.13.2 + 0.13.12 6.0.0 6.0.1