From f297925fe8140ce579efd728cab462e64473a2c0 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Sun, 24 Mar 2024 19:49:27 +0000 Subject: [PATCH 01/20] Eliminated async-over-sync in the SSRP class. This isn't in use yet, but once TCP connection establishment is fully async, it'll be available. Also incremented the version of BenchmarkDotNet (so that the benchmarks can run with .NET 8.0.) --- .../Microsoft/Data/SqlClient/SNI/SNICommon.cs | 49 +++ .../src/Microsoft/Data/SqlClient/SNI/SSRP.cs | 331 +++++++++++------- tools/props/Versions.props | 2 +- 3 files changed, 249 insertions(+), 133 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index 964d332ae4..caf4af1717 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -350,6 +350,43 @@ internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer ti } } +#if NET6_0_OR_GREATER + internal static async ValueTask GetDnsIpAddresses(string serverName, TimeoutTimer timeout, bool async) +#else + internal static ValueTask GetDnsIpAddresses(string serverName, TimeoutTimer timeout, bool async) +#endif + { + using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) + { + int remainingTimeout = timeout.MillisecondsRemainingInt; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, + "Getting DNS host entries for serverName {0} within {1} milliseconds.", + args0: serverName, + args1: remainingTimeout); + using CancellationTokenSource cts = new CancellationTokenSource(remainingTimeout); + +#if NET6_0_OR_GREATER + Task task = Dns.GetHostAddressesAsync(serverName, cts.Token); + + if (async) + { + return await task.ConfigureAwait(false); + } + else + { + task.Wait(); + return task.Result; + } +#else + // using this overload to support netstandard + Task task = Dns.GetHostAddressesAsync(serverName); + + task.Wait(cts.Token); + return new ValueTask(task.Result); +#endif + } + } + internal static IPAddress[] GetDnsIpAddresses(string serverName) { using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) @@ -359,6 +396,18 @@ internal static IPAddress[] GetDnsIpAddresses(string serverName) } } + internal static async ValueTask GetDnsIpAddresses(string serverName, bool async) + { + using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) + { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0}.", args0: serverName); + + return async + ? await Dns.GetHostAddressesAsync(serverName) + : Dns.GetHostAddresses(serverName); + } + } + /// /// Sets last error encountered for SNI /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index 3cad605caa..81e9c99ad0 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -16,6 +16,9 @@ namespace Microsoft.Data.SqlClient.SNI { internal sealed class SSRP { + private static readonly TimeSpan s_sendTimeout = TimeSpan.FromSeconds(1.0); + private static readonly TimeSpan s_receiveTimeout = TimeSpan.FromSeconds(1.0); + private const char SemicolonSeparator = ';'; private const int SqlServerBrowserPort = 1434; //port SQL Server Browser private const int RecieveMAXTimeoutsForCLNT_BCAST_EX = 15000; //Default max time for response wait @@ -28,13 +31,28 @@ internal sealed class SSRP /// /// Finds instance port number for given instance name. /// - /// SQL Sever Browser hostname + /// SQL Server Browser hostname /// instance name to find port number /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference /// port number for given instance name internal static int GetPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + => GetPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, false).Result; + + /// + /// Finds instance port number for given instance name. + /// + /// SQL Server Browser hostname + /// instance name to find port number + /// Connection timer expiration + /// query all resolved IP addresses in parallel + /// IP address preference + /// port number for given instance name + internal static ValueTask GetPortByInstanceNameAsync(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + => GetPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, true); + + private static async ValueTask GetPortByInstanceNameCore(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); @@ -44,7 +62,7 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc byte[] responsePacket = null; try { - responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timeout, allIPsInParallel, ipPreference); + responsePacket = await SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timeout, allIPsInParallel, ipPreference, async); } catch (SocketException se) { @@ -103,19 +121,34 @@ private static byte[] CreateInstanceInfoRequest(string instanceName) /// /// Finds DAC port for given instance name. /// - /// SQL Sever Browser hostname + /// SQL Server Browser hostname /// instance name to lookup DAC port /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference /// DAC port for given instance name internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + => GetDacPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, false).Result; + + /// + /// Finds DAC port for given instance name. + /// + /// SQL Server Browser hostname + /// instance name to lookup DAC port + /// Connection timer expiration + /// query all resolved IP addresses in parallel + /// IP address preference + /// DAC port for given instance name + internal static ValueTask GetDacPortByInstanceNameAsync(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + => GetDacPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, true); + + private static async ValueTask GetDacPortByInstanceNameCore(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); byte[] dacPortInfoRequest = CreateDacPortInfoRequest(instanceName); - byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timeout, allIPsInParallel, ipPreference); + byte[] responsePacket = await SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timeout, allIPsInParallel, ipPreference, async); const byte SvrResp = 0x05; const byte ProtocolVersion = 0x01; @@ -152,12 +185,6 @@ private static byte[] CreateDacPortInfoRequest(string instanceName) return requestPacket; } - private class SsrpResult - { - public byte[] ResponsePacket; - public Exception Error; - } - /// /// Sends request to server, and receives response from server by UDP. /// @@ -167,8 +194,9 @@ private class SsrpResult /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference + /// If true, this method will be run asynchronously /// response packet from UDP server - private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + private static async ValueTask SendUDPRequest(string browserHostname, int port, byte[] requestPacket, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) { using (TrySNIEventScope.Create(nameof(SSRP))) { @@ -178,48 +206,56 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re if (IPAddress.TryParse(browserHostname, out IPAddress address)) { - SsrpResult response = SendUDPRequest(new IPAddress[] { address }, port, requestPacket, allIPsInParallel); - if (response != null && response.ResponsePacket != null) - return response.ResponsePacket; - else if (response != null && response.Error != null) - throw response.Error; - else - return null; + return await SendUDPRequest(new IPAddress[] { address }, port, requestPacket, allIPsInParallel, async); } - IPAddress[] ipAddresses = timeout.IsInfinite - ? SNICommon.GetDnsIpAddresses(browserHostname) - : SNICommon.GetDnsIpAddresses(browserHostname, timeout); + IPAddress[] ipAddresses = await (timeout.IsInfinite + ? SNICommon.GetDnsIpAddresses(browserHostname, async) + : SNICommon.GetDnsIpAddresses(browserHostname, timeout, async)); Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); IPAddress[] ipv4Addresses = null; + byte[] response4 = null; + IPAddress[] ipv6Addresses = null; + byte[] response6 = null; + + Exception responseException = null; + switch (ipPreference) { case SqlConnectionIPAddressPreference.IPv4First: { SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses); - - SsrpResult response4 = SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel); - if (response4 != null && response4.ResponsePacket != null) + + try { - return response4.ResponsePacket; + response4 = await SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); + + if (response4 != null) + { + return response4; + } } + catch(Exception e) + { responseException ??= e; } - SsrpResult response6 = SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel); - if (response6 != null && response6.ResponsePacket != null) + try { - return response6.ResponsePacket; + response6 = await SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); + + if (response6 != null) + { + return response6; + } } + catch (Exception e) + { responseException ??= e; } // No responses so throw first error - if (response4 != null && response4.Error != null) - { - throw response4.Error; - } - else if (response6 != null && response6.Error != null) + if (responseException != null) { - throw response6.Error; + throw responseException; } break; @@ -228,44 +264,40 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re { SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses); - SsrpResult response6 = SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel); - if (response6 != null && response6.ResponsePacket != null) + try { - return response6.ResponsePacket; + response6 = await SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); + + if (response6 != null) + { + return response6; + } } + catch (Exception e) + { responseException ??= e; } - SsrpResult response4 = SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel); - if (response4 != null && response4.ResponsePacket != null) + try { - return response4.ResponsePacket; + response4 = await SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); + + if (response4 != null) + { + return response4; + } } + catch (Exception e) + { responseException ??= e; } // No responses so throw first error - if (response6 != null && response6.Error != null) - { - throw response6.Error; - } - else if (response4 != null && response4.Error != null) + if (responseException != null) { - throw response4.Error; + throw responseException; } break; } default: - { - SsrpResult response = SendUDPRequest(ipAddresses, port, requestPacket, true); // allIPsInParallel); - if (response != null && response.ResponsePacket != null) - { - return response.ResponsePacket; - } - else if (response != null && response.Error != null) - { - throw response.Error; - } - - break; - } + return await SendUDPRequest(ipAddresses, port, requestPacket, true, async); // allIPsInParallel); } return null; @@ -279,151 +311,186 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re /// UDP server port /// request packet /// query all resolved IP addresses in parallel + /// If true, this method will be run asynchronously /// response packet from UDP server - private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel) + private static async ValueTask SendUDPRequest(IPAddress[] ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel, bool async) { if (ipAddresses.Length == 0) return null; if (allIPsInParallel) // Used for MultiSubnetFailover { - List> tasks = new(ipAddresses.Length); + List> tasks = new(ipAddresses.Length); + Task firstFailedTask = null; CancellationTokenSource cts = new CancellationTokenSource(); + for (int i = 0; i < ipAddresses.Length; i++) { IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port); - tasks.Add(Task.Factory.StartNew(() => SendUDPRequest(endPoint, requestPacket), cts.Token)); + tasks.Add(SendUDPRequest(endPoint, requestPacket, async, cts.Token)); } - List> completedTasks = new(); while (tasks.Count > 0) { - int first = Task.WaitAny(tasks.ToArray()); - if (tasks[first].Result.ResponsePacket != null) + Task completedTask; + + if (async) { - cts.Cancel(); - return tasks[first].Result; + completedTask = await Task.WhenAny(tasks).ConfigureAwait(false); + + if (completedTask.Status == TaskStatus.RanToCompletion) + { + cts.Cancel(); + return completedTask.Result; + } } else { - completedTasks.Add(tasks[first]); - tasks.Remove(tasks[first]); + int completedTaskIndex = Task.WaitAny(tasks.ToArray()); + + completedTask = tasks[completedTaskIndex]; + if (completedTask.Status == TaskStatus.RanToCompletion) + { + cts.Cancel(); + return completedTask.Result; + } + } + + if (completedTask.Status == TaskStatus.Faulted) + { + tasks.Remove(completedTask); + firstFailedTask ??= completedTask; } } - Debug.Assert(completedTasks.Count > 0, "completedTasks should never be 0"); + Debug.Assert(firstFailedTask != null, "firstFailedTask should never be null"); // All tasks failed. Return the error from the first failure. - return completedTasks[0].Result; + throw firstFailedTask.Exception; } else { // If not parallel, use the first IP address provided IPEndPoint endPoint = new IPEndPoint(ipAddresses[0], port); - return SendUDPRequest(endPoint, requestPacket); + return await SendUDPRequest(endPoint, requestPacket, async, CancellationToken.None); } } - private static SsrpResult SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket) +#if NET6_0_OR_GREATER + private static async Task SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket, bool async, CancellationToken token) +#else + private static Task SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket, bool async, CancellationToken token) +#endif { - const int sendTimeOutMs = 1000; - const int receiveTimeOutMs = 1000; - - SsrpResult result = new(); + byte[] responsePacket = null; try { using (UdpClient client = new UdpClient(endPoint.AddressFamily)) { - Task sendTask = client.SendAsync(requestPacket, requestPacket.Length, endPoint); - Task receiveTask = null; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch Port info."); - if (sendTask.Wait(sendTimeOutMs) && (receiveTask = client.ReceiveAsync()).Wait(receiveTimeOutMs)) + + using (CancellationTokenSource sendTimeoutCancellationTokenSource = new CancellationTokenSource(s_sendTimeout)) + using (CancellationTokenSource sendCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, sendTimeoutCancellationTokenSource.Token)) + { +#if NET6_0_OR_GREATER + ValueTask sendTask = client.SendAsync(requestPacket.AsMemory(), endPoint, sendCancellationTokenSource.Token); + + if (async) + { + await sendTask.ConfigureAwait(false); + } + else + { + if (!sendTask.IsCompleted) + { + sendTask.AsTask().Wait(); + } + } +#else + Task sendTask = client.SendAsync(requestPacket, requestPacket.Length, endPoint); + + sendTask.Wait(sendCancellationTokenSource.Token); +#endif + } + + UdpReceiveResult receiveResult; + + using (CancellationTokenSource receiveTimeoutCancellationTokenSource = new CancellationTokenSource(s_receiveTimeout)) + using (CancellationTokenSource receiveCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, receiveTimeoutCancellationTokenSource.Token)) { +#if NET6_0_OR_GREATER + ValueTask receiveTask = client.ReceiveAsync(receiveCancellationTokenSource.Token); + + if (async) + { + receiveResult = await receiveTask.ConfigureAwait(false); + } + else + { + if (!receiveTask.IsCompleted) + { + receiveTask.AsTask().Wait(); + } + + receiveResult = receiveTask.Result; + } +#else + Task receiveTask = client.ReceiveAsync(); + + receiveTask.Wait(receiveCancellationTokenSource.Token); + receiveResult = receiveTask.Result; +#endif + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received Port info from UDP Client."); - result.ResponsePacket = receiveTask.Result.Buffer; + responsePacket = receiveResult.Buffer; } } } + catch (OperationCanceledException) + { + responsePacket = null; + } catch (AggregateException ae) { if (ae.InnerExceptions.Count > 0) { + Exception firstSocketException = null; + // Log all errors foreach (Exception e in ae.InnerExceptions) { // Favor SocketException for returned error if (e is SocketException) { - result.Error = e; + firstSocketException = e; } SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint.ToString(), args1: e.Message); } - // Return first error if we didn't find a SocketException - result.Error = result.Error == null ? ae.InnerExceptions[0] : result.Error; + // Throw first error if we didn't find a SocketException + throw firstSocketException ?? ae.InnerExceptions[0]; } else { - result.Error = ae; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint.ToString(), args1: ae.Message); + throw; } } catch (Exception e) { - result.Error = e; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint.ToString(), args1: e.Message); + throw; } - return result; - } - - /// - /// Sends request to server, and recieves response from server (SQLBrowser) on port 1434 by UDP - /// Request (https://docs.microsoft.com/en-us/openspecs/windows_protocols/mc-sqlr/a3035afa-c268-4699-b8fd-4f351e5c8e9e) - /// Response (https://docs.microsoft.com/en-us/openspecs/windows_protocols/mc-sqlr/2e1560c9-5097-4023-9f5e-72b9ff1ec3b1) - /// - /// string constaning list of SVR_RESP(just RESP_DATA) - internal static string SendBroadcastUDPRequest() - { - StringBuilder response = new StringBuilder(); - byte[] CLNT_BCAST_EX_Request = new byte[1] { CLNT_BCAST_EX }; //0x02 - // Waits 5 seconds for the first response and every 1 second up to 15 seconds - // https://docs.microsoft.com/en-us/openspecs/windows_protocols/mc-sqlr/f2640a2d-3beb-464b-a443-f635842ebc3e#Appendix_A_3 - int currentTimeOut = FirstTimeoutForCLNT_BCAST_EX; - - using (TrySNIEventScope.Create(nameof(SSRP))) - { - using (UdpClient clientListener = new UdpClient()) - { - Task sendTask = clientListener.SendAsync(CLNT_BCAST_EX_Request, CLNT_BCAST_EX_Request.Length, new IPEndPoint(IPAddress.Broadcast, SqlServerBrowserPort)); - Task receiveTask = null; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch list of instances."); - Stopwatch sw = new Stopwatch(); //for waiting until 15 sec elapsed - sw.Start(); - try - { - while ((receiveTask = clientListener.ReceiveAsync()).Wait(currentTimeOut) && sw.ElapsedMilliseconds <= RecieveMAXTimeoutsForCLNT_BCAST_EX && receiveTask != null) - { - currentTimeOut = RecieveTimeoutsForCLNT_BCAST_EX; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received instnace info from UDP Client."); - if (receiveTask.Result.Buffer.Length < ValidResponseSizeForCLNT_BCAST_EX) //discard invalid response - { - response.Append(Encoding.ASCII.GetString(receiveTask.Result.Buffer, ServerResponseHeaderSizeForCLNT_BCAST_EX, receiveTask.Result.Buffer.Length - ServerResponseHeaderSizeForCLNT_BCAST_EX)); //RESP_DATA(VARIABLE) - 3 (RESP_SIZE + SVR_RESP) - } - } - } - finally - { - sw.Stop(); - } - } - } - return response.ToString(); +#if NET6_0_OR_GREATER + return responsePacket; +#else + return Task.FromResult(responsePacket); +#endif } private static void SplitIPv4AndIPv6(IPAddress[] input, out IPAddress[] ipv4Addresses, out IPAddress[] ipv6Addresses) diff --git a/tools/props/Versions.props b/tools/props/Versions.props index 9c58ec3889..5e1694626e 100644 --- a/tools/props/Versions.props +++ b/tools/props/Versions.props @@ -75,7 +75,7 @@ 170.8.0 10.50.1600.1 160.1000.6 - 0.13.2 + 0.13.12 6.0.0 6.0.1 From a66b8cef546115d6139d692f21aa2db80ca463de Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Sun, 24 Mar 2024 22:45:05 +0000 Subject: [PATCH 02/20] Tweaked SQLFallbackDNSCache to improve memory usage. Made the various fields strongly-typed to eliminate casting/parsing on the hot path. --- .../Interop/SNINativeMethodWrapper.Windows.cs | 12 +- .../Data/SqlClient/SNI/SNITcpHandle.cs | 115 +++++++++++------- .../SqlClient/TdsParserStateObjectNative.cs | 6 +- .../Data/Interop/SNINativeMethodWrapper.cs | 12 +- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 6 +- .../Data/SqlClient/SQLFallbackDNSCache.cs | 25 ++-- 6 files changed, 104 insertions(+), 72 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs index c8591a8c11..b9df5ed439 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs @@ -387,9 +387,9 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info(); native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; - native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; - native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4?.ToString(); + native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6?.ToString(); + native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo?.Port.ToString(); return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo); } @@ -432,9 +432,9 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan clientConsumerInfo.ipAddressPreference = ipPreference; clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo.Port.ToString(); if (spnBuffer != null) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index d12e91ad62..1f0afd4ac5 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -187,21 +187,21 @@ public override int ProtocolVersion } else { - int portRetry = string.IsNullOrEmpty(cachedDNSInfo.Port) ? port : int.Parse(cachedDNSInfo.Port); + int portRetry = cachedDNSInfo.Port == 0 ? port : cachedDNSInfo.Port; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Retrying with cached DNS IP Address {1} and port {2}", args0: _connectionId, args1: cachedDNSInfo.AddrIPv4, args2: cachedDNSInfo.Port); - string firstCachedIP; - string secondCachedIP; + IPAddress[] firstCachedIP; + IPAddress[] secondCachedIP; if (SqlConnectionIPAddressPreference.IPv6First == ipPreference) { - firstCachedIP = cachedDNSInfo.AddrIPv6; - secondCachedIP = cachedDNSInfo.AddrIPv4; + firstCachedIP = new[] { cachedDNSInfo.AddrIPv6 }; + secondCachedIP = new[] { cachedDNSInfo.AddrIPv4 }; } else { - firstCachedIP = cachedDNSInfo.AddrIPv4; - secondCachedIP = cachedDNSInfo.AddrIPv6; + firstCachedIP = new[] { cachedDNSInfo.AddrIPv4 }; + secondCachedIP = new[] { cachedDNSInfo.AddrIPv6 }; } try @@ -300,11 +300,7 @@ public override int ProtocolVersion // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. private Socket TryConnectParallel(string hostName, int port, TimeoutTimer timeout, ref bool callerReportError, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { - Socket availableSocket = null; - Task connectTask; - bool isInfiniteTimeOut = timeout.IsInfinite; - - IPAddress[] serverAddresses = isInfiniteTimeOut + IPAddress[] serverAddresses = timeout.IsInfinite ? SNICommon.GetDnsIpAddresses(hostName) : SNICommon.GetDnsIpAddresses(hostName, timeout); @@ -314,27 +310,35 @@ private Socket TryConnectParallel(string hostName, int port, TimeoutTimer timeou callerReportError = false; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "Connection Id {0} serverAddresses.Length {1} Exception: {2}", args0: _connectionId, args1: serverAddresses.Length, args2: Strings.SNI_ERROR_47); ReportTcpSNIError(0, SNICommon.MultiSubnetFailoverWithMoreThan64IPs, Strings.SNI_ERROR_47); - return availableSocket; + return null; } - string IPv4String = null; - string IPv6String = null; + return TryConnectParallel(serverAddresses, port, timeout, ref callerReportError, cachedFQDN, ref pendingDNSInfo); + } + + private Socket TryConnectParallel(IPAddress[] serverAddresses, int port, TimeoutTimer timeout, ref bool callerReportError, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) + { + Socket availableSocket = null; + Task connectTask; + bool isInfiniteTimeOut = timeout.IsInfinite; + IPAddress ipv4Address = null; + IPAddress ipv6Address = null; foreach (IPAddress ipAddress in serverAddresses) { if (ipAddress.AddressFamily == AddressFamily.InterNetwork) { - IPv4String = ipAddress.ToString(); + ipv4Address = ipAddress; } else if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6) { - IPv6String = ipAddress.ToString(); + ipv6Address = ipAddress; } } - if (IPv4String != null || IPv6String != null) + if (ipv4Address != null || ipv6Address != null) { - pendingDNSInfo = new SQLDNSInfo(cachedFQDN, IPv4String, IPv6String, port.ToString()); + pendingDNSInfo = new SQLDNSInfo(cachedFQDN, ipv4Address, ipv6Address, port); } connectTask = ParallelConnectAsync(serverAddresses, port); @@ -355,9 +359,10 @@ private Socket TryConnectParallel(string hostName, int port, TimeoutTimer timeou /// Returns array of IP addresses for the given server name, sorted according to the given preference. /// /// Thrown when ipPreference is not supported - private static IEnumerable GetHostAddressesSortedByPreference(string serverName, SqlConnectionIPAddressPreference ipPreference) + private static IPAddress[] GetHostAddressesSortedByPreference(string serverName, SqlConnectionIPAddressPreference ipPreference) { - IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName); + IPAddress[] dnsIPAddresses = Dns.GetHostAddresses(serverName); + IPAddress[] ipAddresses; AddressFamily? prioritiesFamily = ipPreference switch { SqlConnectionIPAddressPreference.IPv4First => AddressFamily.InterNetwork, @@ -366,41 +371,61 @@ private static IEnumerable GetHostAddressesSortedByPreference(string _ => throw ADP.NotSupportedEnumerationValue(typeof(SqlConnectionIPAddressPreference), ipPreference.ToString(), nameof(GetHostAddressesSortedByPreference)) }; - // Return addresses of the preferred family first - if (prioritiesFamily != null) + if (prioritiesFamily == null) + { + ipAddresses = dnsIPAddresses; + } + else { - foreach (IPAddress ipAddress in ipAddresses) + int resultArrayIndex = 0; + + ipAddresses = new IPAddress[dnsIPAddresses.Length]; + + // Return addresses of the preferred family first + for (int i = 0; i < dnsIPAddresses.Length; i++) { - if (ipAddress.AddressFamily == prioritiesFamily) + if (dnsIPAddresses[i].AddressFamily == prioritiesFamily) { - yield return ipAddress; + ipAddresses[resultArrayIndex++] = dnsIPAddresses[i]; } } - } - // Return addresses of the other family - foreach (IPAddress ipAddress in ipAddresses) - { - if (ipAddress.AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6) + // Return addresses of the other family + for (int i = 0; i < dnsIPAddresses.Length; i++) { - if (prioritiesFamily == null || ipAddress.AddressFamily != prioritiesFamily) + if (dnsIPAddresses[i].AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6 + && dnsIPAddresses[i].AddressFamily != prioritiesFamily) { - yield return ipAddress; + ipAddresses[resultArrayIndex++] = dnsIPAddresses[i]; } } + + // If the DNS resolution returned records of types other than A and AAAA, the original array size will be + // too large, and must thus be resized. This is very unlikely, so we only try to do this post-hoc. + if (resultArrayIndex + 1 < ipAddresses.Length) + { + Array.Resize(ref ipAddresses, resultArrayIndex + 1); + } } + + return ipAddresses; } // Connect to server with hostName and port. // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. private static Socket Connect(string serverName, int port, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) + { + IPAddress[] ipAddresses = GetHostAddressesSortedByPreference(serverName, ipPreference); + + return Connect(ipAddresses, port, timeout, ipPreference, cachedFQDN, ref pendingDNSInfo); + } + + private static Socket Connect(IPAddress[] ipAddresses, int port, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference)); bool isInfiniteTimeout = timeout.IsInfinite; - IEnumerable ipAddresses = GetHostAddressesSortedByPreference(serverName, ipPreference); - foreach (IPAddress ipAddress in ipAddresses) { bool isSocketSelected = false; @@ -436,8 +461,12 @@ private static Socket Connect(string serverName, int port, TimeoutTimer timeout, { return null; } +#if NET6_0_OR_GREATER + Task socketConnectTask = socket.ConnectAsync(ipAddress, port); +#else // Socket.Connect does not support infinite timeouts, so we use Task to simulate it Task socketConnectTask = new Task(() => socket.Connect(ipAddress, port)); +#endif socketConnectTask.ConfigureAwait(false); socketConnectTask.Start(); int remainingTimeout = timeout.MillisecondsRemainingInt; @@ -494,17 +523,21 @@ private static Socket Connect(string serverName, int port, TimeoutTimer timeout, if (isConnected) { socket.Blocking = true; - string iPv4String = null; - string iPv6String = null; - if (socket.AddressFamily == AddressFamily.InterNetwork) + + IPAddress ipv4Address = null; + IPAddress ipv6Address = null; + + if (ipAddress.AddressFamily == AddressFamily.InterNetwork) { - iPv4String = ipAddress.ToString(); + ipv4Address = ipAddress; } else { - iPv6String = ipAddress.ToString(); + ipv6Address = ipAddress; } - pendingDNSInfo = new SQLDNSInfo(cachedFQDN, iPv4String, iPv6String, port.ToString()); + + pendingDNSInfo = new SQLDNSInfo(cachedFQDN, ipv4Address, ipv6Address, port); + isSocketSelected = true; return socket; } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index 59776956a1..7104b22ce6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -101,17 +101,17 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache result = SNINativeMethodWrapper.SniGetConnectionIPString(Handle, ref IPStringFromSNI); Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString"); - pendingDNSInfo = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString()); + pendingDNSInfo = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI); if (IPAddress.TryParse(IPStringFromSNI, out IPFromSNI)) { if (System.Net.Sockets.AddressFamily.InterNetwork == IPFromSNI.AddressFamily) { - pendingDNSInfo.AddrIPv4 = IPStringFromSNI; + pendingDNSInfo.AddrIPv4 = IPFromSNI; } else if (System.Net.Sockets.AddressFamily.InterNetworkV6 == IPFromSNI.AddressFamily) { - pendingDNSInfo.AddrIPv6 = IPStringFromSNI; + pendingDNSInfo.AddrIPv6 = IPFromSNI; } } } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs index b9af5849c4..9fc2b4343e 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs @@ -1095,9 +1095,9 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info(); native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; - native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; - native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4?.ToString(); + native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6?.ToString(); + native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo?.Port.ToString(); return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo); } @@ -1154,9 +1154,9 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan clientConsumerInfo.ipAddressPreference = ipPreference; clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo?.Port.ToString(); if (spnBuffer != null) { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 7a9bbfdfd3..2170d02592 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -854,17 +854,17 @@ internal void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey) result = SNINativeMethodWrapper.SniGetConnectionIPString(_physicalStateObj.Handle, ref IPStringFromSNI); Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString"); - _connHandler.pendingSQLDNSObject = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString()); + _connHandler.pendingSQLDNSObject = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI); if (IPAddress.TryParse(IPStringFromSNI, out IPFromSNI)) { if (System.Net.Sockets.AddressFamily.InterNetwork == IPFromSNI.AddressFamily) { - _connHandler.pendingSQLDNSObject.AddrIPv4 = IPStringFromSNI; + _connHandler.pendingSQLDNSObject.AddrIPv4 = IPFromSNI; } else if (System.Net.Sockets.AddressFamily.InterNetworkV6 == IPFromSNI.AddressFamily) { - _connHandler.pendingSQLDNSObject.AddrIPv6 = IPStringFromSNI; + _connHandler.pendingSQLDNSObject.AddrIPv6 = IPFromSNI; } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs index 9d4136d01f..6ce612080e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Concurrent; +using System.Net; namespace Microsoft.Data.SqlClient { @@ -28,13 +29,12 @@ internal bool AddDNSInfo(SQLDNSInfo item) { if (null != item) { - if (DNSInfoCache.ContainsKey(item.FQDN)) - { - - DeleteDNSInfo(item.FQDN); - } - - return DNSInfoCache.TryAdd(item.FQDN, item); +#if NET6_0_OR_GREATER || NETSTANDARD2_1 + DNSInfoCache.AddOrUpdate(item.FQDN, static (key, state) => state, static (key, value, state) => state, item); +#else + DNSInfoCache.AddOrUpdate(item.FQDN, item, (key, value) => item); +#endif + return true; } return false; @@ -42,8 +42,7 @@ internal bool AddDNSInfo(SQLDNSInfo item) internal bool DeleteDNSInfo(string FQDN) { - SQLDNSInfo value; - return DNSInfoCache.TryRemove(FQDN, out value); + return DNSInfoCache.TryRemove(FQDN, out _); } internal bool GetDNSInfo(string FQDN, out SQLDNSInfo result) @@ -71,11 +70,11 @@ internal bool IsDuplicate(SQLDNSInfo newItem) internal sealed class SQLDNSInfo { public string FQDN { get; set; } - public string AddrIPv4 { get; set; } - public string AddrIPv6 { get; set; } - public string Port { get; set; } + public IPAddress AddrIPv4 { get; set; } + public IPAddress AddrIPv6 { get; set; } + public int Port { get; set; } - internal SQLDNSInfo(string FQDN, string ipv4, string ipv6, string port) + internal SQLDNSInfo(string FQDN, IPAddress ipv4, IPAddress ipv6, int port) { this.FQDN = FQDN; AddrIPv4 = ipv4; From bcbe8b3f0f3b2cf1c37803084d321ef3dc77fb8d Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Tue, 26 Mar 2024 17:42:01 +0000 Subject: [PATCH 03/20] Reworked ParallelConnectAsync slightly. This removes an explicit task continuation, slightly reduces memory usage and makes use of the newer .NET 6.0+ APIs to simplify socket connection cancellation --- .../Data/SqlClient/SNI/SNITcpHandle.cs | 147 ++++++++++-------- 1 file changed, 82 insertions(+), 65 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index 1f0afd4ac5..908a1dca5b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -343,7 +343,7 @@ private Socket TryConnectParallel(IPAddress[] serverAddresses, int port, Timeout connectTask = ParallelConnectAsync(serverAddresses, port); - if (!(connectTask.Wait(isInfiniteTimeOut ? -1: timeout.MillisecondsRemainingInt))) + if (connectTask.Status != TaskStatus.RanToCompletion && !(connectTask.Wait(isInfiniteTimeOut ? -1: timeout.MillisecondsRemainingInt))) { callerReportError = false; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "Connection Id {0} Connection timed out, Exception: {1}", args0: _connectionId, args1: Strings.SNI_ERROR_40); @@ -425,12 +425,22 @@ private static Socket Connect(IPAddress[] ipAddresses, int port, TimeoutTimer ti { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference)); bool isInfiniteTimeout = timeout.IsInfinite; + IPEndPoint ipEndPoint = null; foreach (IPAddress ipAddress in ipAddresses) { bool isSocketSelected = false; Socket socket = null; + if (ipEndPoint == null) + { + ipEndPoint = new IPEndPoint(ipAddress, port); + } + else + { + ipEndPoint.Address = ipAddress; + } + try { socket = new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp) @@ -453,7 +463,7 @@ private static Socket Connect(IPAddress[] ipAddresses, int port, TimeoutTimer ti { if (isInfiniteTimeout) { - socket.Connect(ipAddress, port); + socket.Connect(ipEndPoint); } else { @@ -462,10 +472,10 @@ private static Socket Connect(IPAddress[] ipAddresses, int port, TimeoutTimer ti return null; } #if NET6_0_OR_GREATER - Task socketConnectTask = socket.ConnectAsync(ipAddress, port); + Task socketConnectTask = socket.ConnectAsync(ipEndPoint); #else // Socket.Connect does not support infinite timeouts, so we use Task to simulate it - Task socketConnectTask = new Task(() => socket.Connect(ipAddress, port)); + Task socketConnectTask = new Task(() => socket.Connect(ipEndPoint)); #endif socketConnectTask.ConfigureAwait(false); socketConnectTask.Start(); @@ -558,7 +568,7 @@ private static Socket Connect(IPAddress[] ipAddresses, int port, TimeoutTimer ti return null; } - private static Task ParallelConnectAsync(IPAddress[] serverAddresses, int port) + private static async Task ParallelConnectAsync(IPAddress[] serverAddresses, int port) { if (serverAddresses == null) { @@ -569,93 +579,100 @@ private static Task ParallelConnectAsync(IPAddress[] serverAddresses, in throw new ArgumentOutOfRangeException(nameof(serverAddresses)); } - var sockets = new List(serverAddresses.Length); - var connectTasks = new List(serverAddresses.Length); - var tcs = new TaskCompletionSource(); - var lastError = new StrongBox(); - var pendingCompleteCount = new StrongBox(serverAddresses.Length); + using var connectCancellationTokenSource = new CancellationTokenSource(); + Exception lastException = null; + IPEndPoint ipEndPoint = null; + + List emptySocketList = new List(); + List socketErrorCheckList = new List(1); + Dictionary socketConnectionTasks = new(serverAddresses.Length); + Socket completedSocket; foreach (IPAddress address in serverAddresses) { var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp); - sockets.Add(socket); + + if (ipEndPoint == null) + { + ipEndPoint = new IPEndPoint(address, port); + } + else + { + ipEndPoint.Address = address; + } // Start all connection tasks now, to prevent possible race conditions with // calling ConnectAsync on disposed sockets. try { - connectTasks.Add(socket.ConnectAsync(address, port)); +#if NET6_0_OR_GREATER + socketConnectionTasks.Add(socket.ConnectAsync(ipEndPoint, connectCancellationTokenSource.Token).AsTask(), socket); +#else + socketConnectionTasks.Add(socket.ConnectAsync(ipEndPoint), socket); +#endif } catch (Exception e) { - connectTasks.Add(Task.FromException(e)); + socketConnectionTasks.Add(Task.FromException(e), socket); } } - for (int i = 0; i < sockets.Count; i++) - { - ParallelConnectHelper(sockets[i], connectTasks[i], tcs, pendingCompleteCount, lastError, sockets); - } - - return tcs.Task; - } - - private static async void ParallelConnectHelper( - Socket socket, - Task connectTask, - TaskCompletionSource tcs, - StrongBox pendingCompleteCount, - StrongBox lastError, - List sockets) - { - bool success = false; try { - // Try to connect. If we're successful, store this task into the result task. - await connectTask.ConfigureAwait(false); - success = tcs.TrySetResult(socket); - if (success) + while (socketConnectionTasks.Count > 0) { - // Whichever connection completes the return task is responsible for disposing - // all of the sockets (except for whichever one is stored into the result task). - // This ensures that only one thread will attempt to dispose of a socket. - // This is also the closest thing we have to canceling connect attempts. - foreach (Socket otherSocket in sockets) + Task completedTask = await Task.WhenAny(socketConnectionTasks.Keys).ConfigureAwait(false); + Socket taskSocket = socketConnectionTasks[completedTask]; + + if (completedTask.Status == TaskStatus.RanToCompletion) { - if (otherSocket != socket) + // workaround: false positive socket.Connected on linux: https://github.com/dotnet/runtime/issues/55538 + if (socketErrorCheckList.Count > 0) { - otherSocket.Dispose(); + socketErrorCheckList.Clear(); + } + socketErrorCheckList.Add(taskSocket); + + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Determining the status of the socket following completion of ConnectAsync."); + + Socket.Select(emptySocketList, emptySocketList, socketErrorCheckList, 0); + if (taskSocket.Connected && socketErrorCheckList.Count == 0) + { + + completedSocket = taskSocket; + connectCancellationTokenSource.Cancel(); + socketConnectionTasks.Remove(completedTask); + + lastException = null; + return completedSocket; + } + } + else + { + if (completedTask.Status == TaskStatus.Faulted) + { + lastException = completedTask.Exception; } } + + taskSocket.Dispose(); + socketConnectionTasks.Remove(completedTask); } - } - catch (Exception e) - { - // Store an exception to be published if no connection succeeds - Interlocked.Exchange(ref lastError.Value, e); + + if (lastException != null) + throw lastException; + + connectCancellationTokenSource.Token.ThrowIfCancellationRequested(); + // This return statement can never be reached. socketConnectionsTasks would need to have been drained, but this can only happen if all tasks + // inside it have succeeded, failed or have been cancelled. Success will result in an early return, failure results in a thrown exception, + // cancellation results in another exception. + return null; } finally { - // If we didn't successfully transition the result task to completed, - // then someone else did and they would have cleaned up, so there's nothing - // more to do. Otherwise, no one completed it yet or we failed; either way, - // see if we're the last outstanding connection, and if we are, try to complete - // the task, and if we're successful, it's our responsibility to dispose all of the sockets. - if (!success && Interlocked.Decrement(ref pendingCompleteCount.Value) == 0) + foreach (KeyValuePair socketConnectionTaskMapping in socketConnectionTasks) { - if (lastError.Value != null) - { - tcs.TrySetException(lastError.Value); - } - else - { - tcs.TrySetCanceled(); - } - - foreach (Socket s in sockets) - { - s.Dispose(); - } + socketConnectionTaskMapping.Value.Dispose(); } } } From 9666b5d1e5c423e3ca25f43585b794980046e877 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Tue, 26 Mar 2024 22:51:42 +0000 Subject: [PATCH 04/20] Rewrote the generation of pre-login packets, eliminating byte array allocation --- .../Interop/SNINativeMethodWrapper.Windows.cs | 4 +- .../Microsoft/Data/SqlClient/SNI/SNIProxy.cs | 4 - .../src/Microsoft/Data/SqlClient/TdsParser.cs | 151 ++++++++---------- .../SqlClient/TdsParserStateObjectManaged.cs | 5 +- .../Data/Interop/SNINativeMethodWrapper.cs | 4 +- .../SqlClient/TdsParserSafeHandles.Windows.cs | 11 +- 6 files changed, 81 insertions(+), 98 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs index b9df5ed439..568c9727d0 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs @@ -399,7 +399,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan string constring, ref IntPtr pConn, byte[] spnBuffer, - byte[] instanceName, + Span instanceName, bool fOverrideCache, bool fSync, int timeout, @@ -409,7 +409,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan string hostNameInCertificate) { - fixed (byte* pin_instanceName = &instanceName[0]) + fixed (byte* pin_instanceName = instanceName) { SNI_CLIENT_CONSUMER_INFO clientConsumerInfo = new SNI_CLIENT_CONSUMER_INFO(); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index d39f382bd6..a0e1310b94 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -132,7 +132,6 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) /// /// Full server name from connection string /// Timer expiration - /// Instance name /// SPN /// pre-defined SPN /// Flush packet cache @@ -149,7 +148,6 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) internal static SNIHandle CreateConnectionHandle( string fullServerName, TimeoutTimer timeout, - out byte[] instanceName, ref byte[][] spnBuffer, string serverSPN, bool flushCache, @@ -163,8 +161,6 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) string hostNameInCertificate, string serverCertificateFilename) { - instanceName = new byte[1]; - bool errorWithLocalDBProcessing; string localDBDataSource = GetLocalDBDataSource(fullServerName, out errorWithLocalDBProcessing); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 7e04f6b9c5..ba43282666 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -729,6 +729,14 @@ internal void PutSession(TdsParserStateObject session) // prelogin so that we don't try to negotiate encryption again during ConsumePreLoginHandshake. _encryptionOption = EncryptionOptions.NOT_SUP; } + else if (encrypt == SqlConnectionEncryptOption.Mandatory) + { + _encryptionOption = EncryptionOptions.ON; + } + else + { + _encryptionOption = EncryptionOptions.OFF; + } // PreLoginHandshake buffer consists of: // 1) Standard header, with type = MT_PRELOGIN @@ -742,126 +750,94 @@ internal void PutSession(TdsParserStateObject session) // Initialize option offset into payload buffer // 5 bytes for each option (1 byte length, 2 byte offset, 2 byte payload length) - int offset = (int)PreLoginOptions.NUMOPT * 5 + 1; + ushort headerOffset = 0; + ushort headerLength = (ushort)PreLoginOptions.NUMOPT * 5; - byte[] payload = new byte[(int)PreLoginOptions.NUMOPT * 5 + TdsEnums.MAX_PRELOGIN_PAYLOAD_LENGTH]; - int payloadLength = 0; + ushort payloadStart = (ushort)(headerLength + 1); + ushort payloadOffset = payloadStart; + // The payload length is static for each connection string. The lengths of each option are well-known + // Version: 6 bytes; Encryption: 1 byte; Instance: (instanceName.Length + 1) bytes; + // Thread ID: 4 bytes; MARS enablement: 1 byte; Trace: (2 * GUID + 1 * uint); Federated Authentication: 1 byte + // End-of-payload marker: 1 byte + // .NET Core uses a static zero-length instance name, which suggests a 51-byte payload. + // .NET Framework allows a variable-length instance name (up to 254 bytes). This is up to 305 bytes at most. + int payloadLength = 6 + 1 + (instanceName.Length + 1) + 4 + 1 + (GUID_SIZE + GUID_SIZE + 4) + 1 + 1; + + int totalBufferLength = headerLength + payloadLength; + Span preLoginPacketBuffer = stackalloc byte[totalBufferLength]; - for (int option = (int)PreLoginOptions.VERSION; option < (int)PreLoginOptions.NUMOPT; option++) + for (byte option = (byte)PreLoginOptions.VERSION; option < (byte)PreLoginOptions.NUMOPT; option++) { - int optionDataSize = 0; + ushort optionDataSize = 0; + // Structure header: // Fill in the option - _physicalStateObj.WriteByte((byte)option); + preLoginPacketBuffer[headerOffset] = option; // Fill in the offset of the option data - _physicalStateObj.WriteByte((byte)((offset & 0xff00) >> 8)); // send upper order byte - _physicalStateObj.WriteByte((byte)(offset & 0x00ff)); // send lower order byte + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(headerOffset + 1), payloadOffset); switch (option) { - case (int)PreLoginOptions.VERSION: + case (byte)PreLoginOptions.VERSION: Version systemDataVersion = ADP.GetAssemblyVersion(); // Major and minor - payload[payloadLength++] = (byte)(systemDataVersion.Major & 0xff); - payload[payloadLength++] = (byte)(systemDataVersion.Minor & 0xff); + preLoginPacketBuffer[payloadOffset] = (byte)(systemDataVersion.Major & 0xff); + preLoginPacketBuffer[payloadOffset + 1] = (byte)(systemDataVersion.Minor & 0xff); // Build (Big Endian) - payload[payloadLength++] = (byte)((systemDataVersion.Build & 0xff00) >> 8); - payload[payloadLength++] = (byte)(systemDataVersion.Build & 0xff); + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(payloadOffset + 2), (ushort)(systemDataVersion.Build & 0xFFFF)); // Sub-build (Little Endian) - payload[payloadLength++] = (byte)(systemDataVersion.Revision & 0xff); - payload[payloadLength++] = (byte)((systemDataVersion.Revision & 0xff00) >> 8); - offset += 6; + BinaryPrimitives.WriteUInt16LittleEndian(preLoginPacketBuffer.Slice(payloadOffset + 2 + sizeof(ushort)), (ushort)(systemDataVersion.Revision & 0xFFFF)); + optionDataSize = 6; break; - case (int)PreLoginOptions.ENCRYPT: - if (_encryptionOption == EncryptionOptions.NOT_SUP) - { - //If OS doesn't support encryption and encryption is not required, inform server "not supported" by client. - payload[payloadLength] = (byte)EncryptionOptions.NOT_SUP; - } - else - { - // Else, inform server of user request. - if (encrypt == SqlConnectionEncryptOption.Mandatory) - { - payload[payloadLength] = (byte)EncryptionOptions.ON; - _encryptionOption = EncryptionOptions.ON; - } - else - { - payload[payloadLength] = (byte)EncryptionOptions.OFF; - _encryptionOption = EncryptionOptions.OFF; - } - } + case (byte)PreLoginOptions.ENCRYPT: + preLoginPacketBuffer[payloadOffset] = (byte)_encryptionOption; - payloadLength += 1; - offset += 1; optionDataSize = 1; break; - case (int)PreLoginOptions.INSTANCE: - int i = 0; - - while (instanceName[i] != 0) - { - payload[payloadLength] = instanceName[i]; - payloadLength++; - i++; - } + case (byte)PreLoginOptions.INSTANCE: + instanceName.CopyTo(preLoginPacketBuffer.Slice(payloadOffset)); + preLoginPacketBuffer[payloadOffset + instanceName.Length] = 0; - payload[payloadLength] = 0; // null terminate - payloadLength++; - i++; - - offset += i; - optionDataSize = i; + optionDataSize = (ushort)(instanceName.Length + 1); break; - case (int)PreLoginOptions.THREADID: + case (byte)PreLoginOptions.THREADID: int threadID = TdsParserStaticMethods.GetCurrentThreadIdForTdsLoginOnly(); - payload[payloadLength++] = (byte)((0xff000000 & threadID) >> 24); - payload[payloadLength++] = (byte)((0x00ff0000 & threadID) >> 16); - payload[payloadLength++] = (byte)((0x0000ff00 & threadID) >> 8); - payload[payloadLength++] = (byte)(0x000000ff & threadID); - offset += 4; + BinaryPrimitives.WriteInt32BigEndian(preLoginPacketBuffer.Slice(payloadOffset), threadID); + optionDataSize = 4; break; - case (int)PreLoginOptions.MARS: - payload[payloadLength++] = (byte)(_fMARS ? 1 : 0); - offset += 1; - optionDataSize += 1; + case (byte)PreLoginOptions.MARS: + preLoginPacketBuffer[payloadOffset] = (byte)(_fMARS ? 1 : 0); + + optionDataSize = 1; break; - case (int)PreLoginOptions.TRACEID: - FillGuidBytes(_connHandler._clientConnectionId, payload.AsSpan(payloadLength, GUID_SIZE)); - payloadLength += GUID_SIZE; - offset += GUID_SIZE; - optionDataSize = GUID_SIZE; + case (byte)PreLoginOptions.TRACEID: + FillGuidBytes(_connHandler._clientConnectionId, preLoginPacketBuffer.Slice(payloadOffset)); ActivityCorrelator.ActivityId actId = ActivityCorrelator.Next(); - FillGuidBytes(actId.Id, payload.AsSpan(payloadLength, GUID_SIZE)); - payloadLength += GUID_SIZE; - payload[payloadLength++] = (byte)(0x000000ff & actId.Sequence); - payload[payloadLength++] = (byte)((0x0000ff00 & actId.Sequence) >> 8); - payload[payloadLength++] = (byte)((0x00ff0000 & actId.Sequence) >> 16); - payload[payloadLength++] = (byte)((0xff000000 & actId.Sequence) >> 24); - int actIdSize = GUID_SIZE + sizeof(uint); - offset += actIdSize; - optionDataSize += actIdSize; + + FillGuidBytes(actId.Id, preLoginPacketBuffer.Slice(payloadOffset + GUID_SIZE)); + BinaryPrimitives.WriteUInt32LittleEndian(preLoginPacketBuffer.Slice(payloadOffset + GUID_SIZE + GUID_SIZE), actId.Sequence); + + optionDataSize = GUID_SIZE + GUID_SIZE + sizeof(uint); SqlClientEventSource.Log.TryTraceEvent(" ClientConnectionID {0}, ActivityID {1}", _connHandler?._clientConnectionId, actId); break; - case (int)PreLoginOptions.FEDAUTHREQUIRED: - payload[payloadLength++] = 0x01; - offset += 1; - optionDataSize += 1; + case (byte)PreLoginOptions.FEDAUTHREQUIRED: + preLoginPacketBuffer[payloadOffset] = 0x01; + + optionDataSize = 1; break; default: @@ -869,16 +845,19 @@ internal void PutSession(TdsParserStateObject session) break; } + payloadOffset += optionDataSize; + // Write data length - _physicalStateObj.WriteByte((byte)((optionDataSize & 0xff00) >> 8)); - _physicalStateObj.WriteByte((byte)(optionDataSize & 0x00ff)); + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(headerOffset + 1 + sizeof(ushort)), optionDataSize); + + headerOffset += 1 + sizeof(ushort) + sizeof(ushort); } // Write out last option - to let server know the second part of packet completed - _physicalStateObj.WriteByte((byte)PreLoginOptions.LASTOPT); + preLoginPacketBuffer[headerOffset] = (byte)PreLoginOptions.LASTOPT; - // Write out payload - _physicalStateObj.WriteByteArray(payload, payloadLength, 0); + // Write out the full byte buffer + _physicalStateObj.WriteByteSpan(preLoginPacketBuffer); // Flush packet _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 1e0141dd58..71d36bf586 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -19,6 +19,7 @@ namespace Microsoft.Data.SqlClient.SNI { internal sealed class TdsParserStateObjectManaged : TdsParserStateObject { + private static readonly byte[] s_staticInstanceName = Array.Empty(); private SNIMarsConnection? _marsConnection; private SNIHandle? _sessionHandle; #if NET7_0_OR_GREATER @@ -98,10 +99,12 @@ protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref string hostNameInCertificate, string serverCertificateFilename) { - SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN, + SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, ref spnBuffer, serverSPN, flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); + instanceName = s_staticInstanceName; + if (sessionHandle is not null) { _sessionHandle = sessionHandle; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs index 9fc2b4343e..9c371692ca 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs @@ -1107,7 +1107,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan string constring, ref IntPtr pConn, byte[] spnBuffer, - byte[] instanceName, + Span instanceName, bool fOverrideCache, bool fSync, int timeout, @@ -1119,7 +1119,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan SQLDNSInfo cachedDNSInfo, string hostNameInCertificate) { - fixed (byte* pin_instanceName = &instanceName[0]) + fixed (byte* pin_instanceName = instanceName) { SNI_CLIENT_CONSUMER_INFO clientConsumerInfo = new SNI_CLIENT_CONSUMER_INFO(); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs index 8d276cd285..9ca1ec4def 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs @@ -164,6 +164,8 @@ internal sealed class SNIHandle : SafeHandle string hostNameInCertificate) : base(IntPtr.Zero, true) { + Span instanceNameBuffer = stackalloc byte[256]; + #if !NET6_0_OR_GREATER RuntimeHelpers.PrepareConstrainedRegions(); #endif @@ -172,7 +174,6 @@ internal sealed class SNIHandle : SafeHandle finally { _fSync = fSync; - instanceName = new byte[256]; // Size as specified by netlibs. // Option ignoreSniOpenTimeout is no longer available //if (ignoreSniOpenTimeout) //{ @@ -185,12 +186,16 @@ internal sealed class SNIHandle : SafeHandle #if NETFRAMEWORK int transparentNetworkResolutionStateNo = (int)transparentNetworkResolutionState; _status = SNINativeMethodWrapper.SNIOpenSyncEx(myInfo, serverName, ref base.handle, - spnBuffer, instanceName, flushCache, fSync, timeout, fParallel, transparentNetworkResolutionStateNo, totalTimeout, + spnBuffer, instanceNameBuffer, flushCache, fSync, timeout, fParallel, transparentNetworkResolutionStateNo, totalTimeout, ADP.IsAzureSqlServerEndpoint(serverName), ipPreference, cachedDNSInfo, hostNameInCertificate); #else _status = SNINativeMethodWrapper.SNIOpenSyncEx(myInfo, serverName, ref base.handle, - spnBuffer, instanceName, flushCache, fSync, timeout, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); + spnBuffer, instanceNameBuffer, flushCache, fSync, timeout, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); #endif // NETFRAMEWORK + + int instanceNameLength = instanceNameBuffer.IndexOf((byte)0); + + instanceName = instanceNameLength < 1 ? Array.Empty() : instanceNameBuffer.Slice(0, instanceNameLength).ToArray(); } } From 477c65d7e4420afc738fac4ecbc42fd9199338f4 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Wed, 27 Mar 2024 21:30:16 +0000 Subject: [PATCH 05/20] Two larger changes: * Forcing the flush of all streams before they get switched over to SSL/TLS streams. * Forcing the task sending the pre-login packet to be sent before we try to process the response. Minor tweaks here: * Removed two copies from SSRP. * Switched a few instances of manual bit-shifting to .NET intrinsics. --- .../Data/SqlClient/SNI/SNITcpHandle.cs | 55 ++++++++++++---- .../src/Microsoft/Data/SqlClient/SNI/SSRP.cs | 38 +++++------ .../src/Microsoft/Data/SqlClient/TdsParser.cs | 66 ++++++++----------- .../Data/SqlClient/TdsParserStateObject.cs | 12 ++-- 4 files changed, 88 insertions(+), 83 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index 908a1dca5b..37ad7217cd 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -458,7 +458,10 @@ private static Socket Connect(IPAddress[] ipAddresses, int port, TimeoutTimer ti ipAddress.AddressFamily, isInfiniteTimeout); + CancellationTokenSource timeoutConnectionCancellationTokenSource = null; + int remainingTimeout = timeout.MillisecondsRemainingInt; bool isConnected; + try // catching SocketException with SocketErrorCode == WouldBlock to run Socket.Select { if (isInfiniteTimeout) @@ -471,34 +474,50 @@ private static Socket Connect(IPAddress[] ipAddresses, int port, TimeoutTimer ti { return null; } + #if NET6_0_OR_GREATER - Task socketConnectTask = socket.ConnectAsync(ipEndPoint); + timeoutConnectionCancellationTokenSource = new CancellationTokenSource(remainingTimeout); + ValueTask socketConnectValueTask = socket.ConnectAsync(ipEndPoint, timeoutConnectionCancellationTokenSource.Token); + + if (! socketConnectValueTask.IsCompleted) + { + try + { + socketConnectValueTask.AsTask().Wait(); + } + catch (OperationCanceledException) + { + throw ADP.TimeoutException($"The socket couldn't connect during the expected {remainingTimeout} remaining time."); + } + } + #else // Socket.Connect does not support infinite timeouts, so we use Task to simulate it Task socketConnectTask = new Task(() => socket.Connect(ipEndPoint)); -#endif socketConnectTask.ConfigureAwait(false); socketConnectTask.Start(); - int remainingTimeout = timeout.MillisecondsRemainingInt; + if (!socketConnectTask.Wait(remainingTimeout)) { + timeoutConnectionCancellationTokenSource.Cancel(); throw ADP.TimeoutException($"The socket couldn't connect during the expected {remainingTimeout} remaining time."); } + throw SQL.SocketDidNotThrow(); +#endif } isConnected = true; } - catch (AggregateException aggregateException) when (!isInfiniteTimeout - && aggregateException.InnerException is SocketException socketException - && socketException.SocketErrorCode == SocketError.WouldBlock) + catch (SocketException socketException) when (!isInfiniteTimeout + && socketException.SocketErrorCode == SocketError.WouldBlock) { // https://github.com/dotnet/SqlClient/issues/826#issuecomment-736224118 // Socket.Select is used because it supports timeouts, while Socket.Connect does not - List checkReadLst; - List checkWriteLst; - List checkErrorLst; + List checkReadLst = new (1); + List checkWriteLst = new(1); + List checkErrorLst = new(1); // Repeating Socket.Select several times if our timeout is greater // than int.MaxValue microseconds because of @@ -508,15 +527,16 @@ private static Socket Connect(IPAddress[] ipAddresses, int port, TimeoutTimer ti { if (timeout.IsExpired) { - return null; + throw ADP.TimeoutException($"The socket couldn't connect during the expected {remainingTimeout} remaining time."); + //return null; } int socketSelectTimeout = checked((int)(Math.Min(timeout.MillisecondsRemainingInt, int.MaxValue / 1000) * 1000)); - checkReadLst = new List(1) { socket }; - checkWriteLst = new List(1) { socket }; - checkErrorLst = new List(1) { socket }; + checkReadLst.Add(socket); + checkWriteLst.Add(socket); + checkErrorLst.Add(socket); SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Determining the status of the socket during the remaining timeout of {0} microseconds.", @@ -529,6 +549,10 @@ private static Socket Connect(IPAddress[] ipAddresses, int port, TimeoutTimer ti // workaround: false positive socket.Connected on linux: https://github.com/dotnet/runtime/issues/55538 isConnected = socket.Connected && checkErrorLst.Count == 0; } + finally + { + timeoutConnectionCancellationTokenSource?.Dispose(); + } if (isConnected) { @@ -697,9 +721,11 @@ public override uint EnableSsl(uint options) // TODO: Resolve whether to send _serverNameIndication or _targetServer. _serverNameIndication currently results in error. Why? _sslStream.AuthenticateAsClient(_targetServer, null, s_supportedProtocols, false); } + _sslStream.Flush(); if (_sslOverTdsStream is not null) { _sslOverTdsStream.FinishHandshake(); + _sslOverTdsStream.Flush(); } } catch (AuthenticationException aue) @@ -724,11 +750,14 @@ public override uint EnableSsl(uint options) /// public override void DisableSsl() { + _sslStream.Flush(); _sslStream.Dispose(); _sslStream = null; + _sslOverTdsStream?.Flush(); _sslOverTdsStream?.Dispose(); _sslOverTdsStream = null; _stream = _tcpStream; + _stream.Flush(); SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, SSL Disabled. Communication will continue on TCP Stream.", args0: _connectionId); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index 81e9c99ad0..8122906ae3 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -16,6 +16,7 @@ namespace Microsoft.Data.SqlClient.SNI { internal sealed class SSRP { + private static readonly List s_emptyList = new(0); private static readonly TimeSpan s_sendTimeout = TimeSpan.FromSeconds(1.0); private static readonly TimeSpan s_receiveTimeout = TimeSpan.FromSeconds(1.0); @@ -214,10 +215,10 @@ private static async ValueTask SendUDPRequest(string browserHostname, in : SNICommon.GetDnsIpAddresses(browserHostname, timeout, async)); Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); - IPAddress[] ipv4Addresses = null; + List ipv4Addresses = null; byte[] response4 = null; - IPAddress[] ipv6Addresses = null; + List ipv6Addresses = null; byte[] response6 = null; Exception responseException = null; @@ -297,7 +298,7 @@ private static async ValueTask SendUDPRequest(string browserHostname, in break; } default: - return await SendUDPRequest(ipAddresses, port, requestPacket, true, async); // allIPsInParallel); + return await SendUDPRequest(ipAddresses, port, requestPacket, true, async).ConfigureAwait(false); // allIPsInParallel); } return null; @@ -313,18 +314,18 @@ private static async ValueTask SendUDPRequest(string browserHostname, in /// query all resolved IP addresses in parallel /// If true, this method will be run asynchronously /// response packet from UDP server - private static async ValueTask SendUDPRequest(IPAddress[] ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel, bool async) + private static async ValueTask SendUDPRequest(IList ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel, bool async) { - if (ipAddresses.Length == 0) + if (ipAddresses.Count == 0) return null; if (allIPsInParallel) // Used for MultiSubnetFailover { - List> tasks = new(ipAddresses.Length); + List> tasks = new(ipAddresses.Count); Task firstFailedTask = null; CancellationTokenSource cts = new CancellationTokenSource(); - for (int i = 0; i < ipAddresses.Length; i++) + for (int i = 0; i < ipAddresses.Count; i++) { IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port); tasks.Add(SendUDPRequest(endPoint, requestPacket, async, cts.Token)); @@ -493,15 +494,15 @@ private static Task SendUDPRequest(IPEndPoint endPoint, byte[] requestPa #endif } - private static void SplitIPv4AndIPv6(IPAddress[] input, out IPAddress[] ipv4Addresses, out IPAddress[] ipv6Addresses) + private static void SplitIPv4AndIPv6(IPAddress[] input, out List ipv4Addresses, out List ipv6Addresses) { - ipv4Addresses = Array.Empty(); - ipv6Addresses = Array.Empty(); + List v4 = null; + List v6 = null; if (input != null && input.Length > 0) { - List v4 = new List(1); - List v6 = new List(0); + v4 = new List(1); + v6 = new List(0); for (int index = 0; index < input.Length; index++) { @@ -515,17 +516,10 @@ private static void SplitIPv4AndIPv6(IPAddress[] input, out IPAddress[] ipv4Addr break; } } - - if (v4.Count > 0) - { - ipv4Addresses = v4.ToArray(); - } - - if (v6.Count > 0) - { - ipv6Addresses = v6.ToArray(); - } } + + ipv4Addresses = v4 ?? s_emptyList; + ipv6Addresses = v6 ?? s_emptyList; } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index ba43282666..b2f76135bd 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -860,7 +860,7 @@ internal void PutSession(TdsParserStateObject session) _physicalStateObj.WriteByteSpan(preLoginPacketBuffer); // Flush packet - _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH); + _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH)?.Wait(); } private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integratedSecurity, string serverCertificateFilename) @@ -942,7 +942,11 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ { throw SQL.ParsingError(); } - byte[] payload = new byte[_physicalStateObj._inBytesPacket]; + // Most of the time, this response packet will be very small (less than 512 bytes.) + // In such a situation, borrow stack space rather than requesting an array. + Span payload = _physicalStateObj._inBytesPacket < 512 + ? stackalloc byte[_physicalStateObj._inBytesPacket] + : new byte[_physicalStateObj._inBytesPacket]; Debug.Assert(_physicalStateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); result = _physicalStateObj.TryReadByteArray(payload, payload.Length); @@ -958,24 +962,22 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ throw SQL.InvalidSQLServerVersionUnknown(); } - int offset = 0; - int payloadOffset = 0; - int payloadLength = 0; - int option = payload[offset++]; + int headerOffset = 0; + ushort payloadOffset = 0; + ushort payloadLength = 0; + byte option = payload[headerOffset++]; bool serverSupportsEncryption = false; while (option != (byte)PreLoginOptions.LASTOPT) { + payloadOffset = BinaryPrimitives.ReadUInt16BigEndian(payload.Slice(headerOffset)); + payloadLength = BinaryPrimitives.ReadUInt16BigEndian(payload.Slice(headerOffset + 2)); + switch (option) { - case (int)PreLoginOptions.VERSION: - payloadOffset = payload[offset++] << 8 | payload[offset++]; - payloadLength = payload[offset++] << 8 | payload[offset++]; - + case (byte)PreLoginOptions.VERSION: byte majorVersion = payload[payloadOffset]; byte minorVersion = payload[payloadOffset + 1]; - int level = (payload[payloadOffset + 2] << 8) | - payload[payloadOffset + 3]; is2005OrLater = majorVersion >= 9; if (!is2005OrLater) @@ -985,17 +987,13 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ break; - case (int)PreLoginOptions.ENCRYPT: + case (byte)PreLoginOptions.ENCRYPT: if (tlsFirst) { // Can skip/ignore this option if we are doing TDS 8. - offset += 4; break; } - payloadOffset = payload[offset++] << 8 | payload[offset++]; - payloadLength = payload[offset++] << 8 | payload[offset++]; - EncryptionOptions serverOption = (EncryptionOptions)payload[payloadOffset]; /* internal enum EncryptionOptions { @@ -1048,11 +1046,8 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ break; - case (int)PreLoginOptions.INSTANCE: - payloadOffset = payload[offset++] << 8 | payload[offset++]; - payloadLength = payload[offset++] << 8 | payload[offset++]; - - byte ERROR_INST = 0x1; + case (byte)PreLoginOptions.INSTANCE: + const byte ERROR_INST = 0x1; byte instanceResult = payload[payloadOffset]; if (instanceResult == ERROR_INST) @@ -1065,29 +1060,21 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ break; - case (int)PreLoginOptions.THREADID: + case (byte)PreLoginOptions.THREADID: // DO NOTHING FOR THREADID - offset += 4; break; - case (int)PreLoginOptions.MARS: - payloadOffset = payload[offset++] << 8 | payload[offset++]; - payloadLength = payload[offset++] << 8 | payload[offset++]; - - marsCapable = (payload[payloadOffset] == 0 ? false : true); + case (byte)PreLoginOptions.MARS: + marsCapable = payload[payloadOffset] != 0; Debug.Assert(payload[payloadOffset] == 0 || payload[payloadOffset] == 1, "Value for Mars PreLoginHandshake option not equal to 1 or 0!"); break; - case (int)PreLoginOptions.TRACEID: + case (byte)PreLoginOptions.TRACEID: // DO NOTHING FOR TRACEID - offset += 4; break; case (int)PreLoginOptions.FEDAUTHREQUIRED: - payloadOffset = payload[offset++] << 8 | payload[offset++]; - payloadLength = payload[offset++] << 8 | payload[offset++]; - // Only 0x00 and 0x01 are accepted values from the server. if (payload[payloadOffset] != 0x00 && payload[payloadOffset] != 0x01) { @@ -1108,17 +1095,16 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ break; default: - Debug.Fail("UNKNOWN option in ConsumePreLoginHandshake, option:" + option); - // DO NOTHING FOR THESE UNKNOWN OPTIONS - offset += 4; - + Debug.Fail("UNKNOWN option in ConsumePreLoginHandshake, option:" + option); break; } - if (offset < payload.Length) + headerOffset += sizeof(ushort) + sizeof(ushort); + + if (headerOffset < payload.Length) { - option = payload[offset++]; + option = payload[headerOffset++]; } else { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 17729a4cc9..29e0308c7f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -931,12 +931,10 @@ internal bool TryProcessHeader() { // All read _partialHeaderBytesRead = 0; - _inBytesPacket = ((int)_partialHeaderBuffer[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | - (int)_partialHeaderBuffer[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - _inputHeaderLen; + _inBytesPacket = BinaryPrimitives.ReadUInt16BigEndian(_partialHeaderBuffer.AsSpan(TdsEnums.HEADER_LEN_FIELD_OFFSET)) - _inputHeaderLen; _messageStatus = _partialHeaderBuffer[1]; - _spid = _partialHeaderBuffer[TdsEnums.SPID_OFFSET] << 8 | - _partialHeaderBuffer[TdsEnums.SPID_OFFSET + 1]; + _spid = BinaryPrimitives.ReadUInt16BigEndian(_partialHeaderBuffer.AsSpan(TdsEnums.SPID_OFFSET)); SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObject.TryProcessHeader | ADV | State Object Id {0}, Client Connection Id {1}, Server process Id (SPID) {2}", _objectID, _parser?.Connection?.ClientConnectionId, _spid); } @@ -972,10 +970,8 @@ internal bool TryProcessHeader() { // normal header processing... _messageStatus = _inBuff[_inBytesUsed + 1]; - _inBytesPacket = (_inBuff[_inBytesUsed + TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | - _inBuff[_inBytesUsed + TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - _inputHeaderLen; - _spid = _inBuff[_inBytesUsed + TdsEnums.SPID_OFFSET] << 8 | - _inBuff[_inBytesUsed + TdsEnums.SPID_OFFSET + 1]; + _inBytesPacket = BinaryPrimitives.ReadUInt16BigEndian(_inBuff.AsSpan(_inBytesUsed + TdsEnums.HEADER_LEN_FIELD_OFFSET)) - _inputHeaderLen; + _spid = BinaryPrimitives.ReadUInt16BigEndian(_inBuff.AsSpan(_inBytesUsed + TdsEnums.SPID_OFFSET)); #if !NETFRAMEWORK SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObject.TryProcessHeader | ADV | State Object Id {0}, Client Connection Id {1}, Server process Id (SPID) {2}", _objectID, _parser?.Connection?.ClientConnectionId, _spid); #endif From 9423262f3197cd29499f266ec0459c9bf838f31c Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Wed, 27 Mar 2024 22:00:45 +0000 Subject: [PATCH 06/20] Adjusted synchronization on inner streams. If an SNINetworkStream is wrapped by an SNISslStream, there's already a layer of synchronisation - we don't need to keep the inner one enabled. --- .../SqlClient/SNI/SNIStreams.ValueTask.cs | 22 +++++++++++++++---- .../Data/SqlClient/SNI/SNIStreams.cs | 4 ++++ .../Data/SqlClient/SNI/SNITcpHandle.cs | 6 +++-- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.ValueTask.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.ValueTask.cs index f5f38f0efe..ff31c9183b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.ValueTask.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.ValueTask.cs @@ -66,7 +66,11 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + if (SynchronizeIO) + { + await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + } + try { return await base.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); @@ -78,7 +82,10 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation } finally { - _readAsyncSemaphore.Release(); + if (SynchronizeIO) + { + _readAsyncSemaphore.Release(); + } } } @@ -90,7 +97,11 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { - await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + if (SynchronizeIO) + { + await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + } + try { await base.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); @@ -102,7 +113,10 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella } finally { - _writeAsyncSemaphore.Release(); + if (SynchronizeIO) + { + _writeAsyncSemaphore.Release(); + } } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs index 389f25eeae..46e2f3af53 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs @@ -37,5 +37,9 @@ public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocke _writeAsyncSemaphore = new ConcurrentQueueSemaphore(1); _readAsyncSemaphore = new ConcurrentQueueSemaphore(1); } + + // This class is often wrapped in an SNISslStream, which also performs its own synchronisation. + // Setting this to false will disable the inner layer, since it's always synchronised by the wrapper. + public bool SynchronizeIO { get; set; } = true; } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index 37ad7217cd..796f00ade8 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -28,7 +28,7 @@ internal sealed class SNITCPHandle : SNIPhysicalHandle private readonly string _targetServer; private readonly object _sendSync; private readonly Socket _socket; - private NetworkStream _tcpStream; + private SNINetworkStream _tcpStream; private readonly string _hostNameInCertificate; private readonly string _serverCertificateFilename; private readonly bool _tlsFirst; @@ -274,7 +274,7 @@ public override int ProtocolVersion _sslOverTdsStream = new SslOverTdsStream(_tcpStream, _connectionId); stream = _sslOverTdsStream; } - _sslStream = new SNISslStream(stream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate)); + _sslStream = new SNISslStream(stream, true, ValidateServerCertificate); } catch (SocketException se) { @@ -740,6 +740,7 @@ public override uint EnableSsl(uint options) } _stream = _sslStream; + _tcpStream.SynchronizeIO = false; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, SSL enabled successfully.", args0: _connectionId); return TdsEnums.SNI_SUCCESS; } @@ -750,6 +751,7 @@ public override uint EnableSsl(uint options) /// public override void DisableSsl() { + _tcpStream.SynchronizeIO = true; _sslStream.Flush(); _sslStream.Dispose(); _sslStream = null; From 2254e7ad769084204b4528970b5518ac31cbd8ea Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Thu, 28 Mar 2024 07:42:42 +0000 Subject: [PATCH 07/20] Improvements to MARS handling. * The SNI SMUX header is now a ref struct * Now using intrinsics to write out the message header Also updated the reading of header lengths to use intrinsics. --- .../Microsoft/Data/SqlClient/SNI/SNICommon.cs | 73 +++++++++---------- .../Data/SqlClient/SNI/SNIMarsConnection.cs | 45 ++++++------ .../Data/SqlClient/SNI/SNIMarsHandle.cs | 47 ++++++------ .../SNI/SslOverTdsStream.NetCoreApp.cs | 5 +- .../SqlClient/TdsParserStateObjectManaged.cs | 2 +- 5 files changed, 80 insertions(+), 92 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index caf4af1717..a967a902fe 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -45,55 +45,50 @@ internal enum SNIProviders /// /// SMUX packet header /// - internal sealed class SNISMUXHeader + internal ref struct SNISMUXHeader { public const int HEADER_LENGTH = 16; - public byte SMID; - public byte flags; - public ushort sessionId; - public uint length; - public uint sequenceNumber; - public uint highwater; + public readonly byte Flags; + public readonly ushort SessionId; + public readonly uint Length; + public readonly uint SequenceNumber; + public readonly uint Highwater; - public void Read(byte[] bytes) + public SNISMUXHeader(byte flags, ushort sessionId, uint length, uint sequenceNumber, uint highwater) { - SMID = bytes[0]; - flags = bytes[1]; - Span span = bytes.AsSpan(); - sessionId = BinaryPrimitives.ReadUInt16LittleEndian(span.Slice(2)); - length = BinaryPrimitives.ReadUInt32LittleEndian(span.Slice(4)) - SNISMUXHeader.HEADER_LENGTH; - sequenceNumber = BinaryPrimitives.ReadUInt32LittleEndian(span.Slice(8)); - highwater = BinaryPrimitives.ReadUInt32LittleEndian(span.Slice(12)); + Flags = flags; + SessionId = sessionId; + Length = length; + SequenceNumber = sequenceNumber; + Highwater = highwater; + } + + public SNISMUXHeader(Span bytes) + { + // As per the MC-SMP spec, the first byte of the header will always be 0x53 + Debug.Assert(bytes[0] == 0x53, "First byte of the SNI SMUX header was not 0x53"); + + Flags = bytes[1]; + SessionId = BinaryPrimitives.ReadUInt16LittleEndian(bytes.Slice(2)); + Length = BinaryPrimitives.ReadUInt32LittleEndian(bytes.Slice(4)) - SNISMUXHeader.HEADER_LENGTH; + SequenceNumber = BinaryPrimitives.ReadUInt32LittleEndian(bytes.Slice(8)); + Highwater = BinaryPrimitives.ReadUInt32LittleEndian(bytes.Slice(12)); } public void Write(Span bytes) { - uint value = highwater; // access the highest element first to cause the largest range check in the jit, then fill in the rest of the value and carry on as normal - bytes[15] = (byte)((value >> 24) & 0xff); - bytes[12] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.highwater).CopyTo(headerBytes, 12); - bytes[13] = (byte)((value >> 8) & 0xff); - bytes[14] = (byte)((value >> 16) & 0xff); - - bytes[0] = SMID; // BitConverter.GetBytes(_currentHeader.SMID).CopyTo(headerBytes, 0); - bytes[1] = flags; // BitConverter.GetBytes(_currentHeader.flags).CopyTo(headerBytes, 1); - - value = sessionId; - bytes[2] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.sessionId).CopyTo(headerBytes, 2); - bytes[3] = (byte)((value >> 8) & 0xff); - - value = length; - bytes[4] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.length).CopyTo(headerBytes, 4); - bytes[5] = (byte)((value >> 8) & 0xff); - bytes[6] = (byte)((value >> 16) & 0xff); - bytes[7] = (byte)((value >> 24) & 0xff); - - value = sequenceNumber; - bytes[8] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.sequenceNumber).CopyTo(headerBytes, 8); - bytes[9] = (byte)((value >> 8) & 0xff); - bytes[10] = (byte)((value >> 16) & 0xff); - bytes[11] = (byte)((value >> 24) & 0xff); + BinaryPrimitives.WriteUInt32LittleEndian(bytes.Slice(12), Highwater); + + bytes[0] = 0x53; // BitConverter.GetBytes(_currentHeader.SMID).CopyTo(headerBytes, 0); + bytes[1] = Flags; // BitConverter.GetBytes(_currentHeader.flags).CopyTo(headerBytes, 1); + + BinaryPrimitives.WriteUInt16LittleEndian(bytes.Slice(2), SessionId); + + BinaryPrimitives.WriteUInt32LittleEndian(bytes.Slice(4), Length); + + BinaryPrimitives.WriteUInt32LittleEndian(bytes.Slice(8), SequenceNumber); } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs index f929a1ba32..b1bc76e09d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -17,7 +17,6 @@ internal class SNIMarsConnection private readonly Guid _connectionId; private readonly Dictionary _sessions; private readonly byte[] _headerBytes; - private readonly SNISMUXHeader _currentHeader; private readonly object _sync; private SNIHandle _lowerHandle; private ushort _nextSessionId; @@ -44,7 +43,6 @@ public SNIMarsConnection(SNIHandle lowerHandle) _connectionId = Guid.NewGuid(); _sessions = new Dictionary(); _headerBytes = new byte[SNISMUXHeader.HEADER_LENGTH]; - _currentHeader = new SNISMUXHeader(); _nextSessionId = 0; _currentHeaderByteCount = 0; _dataBytesLeft = 0; @@ -53,7 +51,7 @@ public SNIMarsConnection(SNIHandle lowerHandle) _lowerHandle.SetAsyncCallbacks(HandleReceiveComplete, HandleSendComplete); } - public SNIMarsHandle CreateMarsSession(object callbackObject, bool async) + public SNIMarsHandle CreateMarsSession(TdsParserStateObject callbackObject, bool async) { lock (DemuxerSync) { @@ -204,7 +202,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) { using (TrySNIEventScope.Create(nameof(SNIMarsConnection))) { - SNISMUXHeader currentHeader = null; + SNISMUXHeader currentHeader = default; SNIPacket currentPacket = null; SNIMarsHandle currentSession = null; @@ -224,7 +222,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) { if (_currentHeaderByteCount != SNISMUXHeader.HEADER_LENGTH) { - currentHeader = null; + currentHeader = default; currentPacket = null; currentSession = null; @@ -249,18 +247,17 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) } } - _currentHeader.Read(_headerBytes); - _dataBytesLeft = (int)_currentHeader.length; - _currentPacket = _lowerHandle.RentPacket(headerSize: 0, dataSize: (int)_currentHeader.length); + currentHeader = new SNISMUXHeader(_headerBytes); + _dataBytesLeft = (int)currentHeader.Length; + _currentPacket = _lowerHandle.RentPacket(headerSize: 0, dataSize: (int)currentHeader.Length); #if DEBUG SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, _dataBytesLeft {1}, _currentPacket {2}, Reading data of length: _currentHeader.length {3}", args0: _lowerHandle?.ConnectionId, args1: _dataBytesLeft, args2: currentPacket?._id, args3: _currentHeader?.length); #endif } - currentHeader = _currentHeader; currentPacket = _currentPacket; - if (_currentHeader.flags == (byte)SNISMUXFlags.SMUX_DATA) + if (currentHeader.Flags == (byte)SNISMUXFlags.SMUX_DATA) { if (_dataBytesLeft > 0) { @@ -286,44 +283,44 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) _currentHeaderByteCount = 0; - if (!_sessions.ContainsKey(_currentHeader.sessionId)) + if (!_sessions.ContainsKey(currentHeader.SessionId)) { SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.SMUX_PROV, 0, SNICommon.InvalidParameterError, Strings.SNI_ERROR_5); HandleReceiveError(packet); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "Current Header Session Id {0} not found, MARS Session Id {1} will be destroyed, New SNI error created: {2}", args0: _currentHeader?.sessionId, args1: _lowerHandle?.ConnectionId, args2: sniErrorCode); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "Current Header Session Id {0} not found, MARS Session Id {1} will be destroyed, New SNI error created: {2}", args0: currentHeader.SessionId, args1: _lowerHandle?.ConnectionId, args2: sniErrorCode); _lowerHandle.Dispose(); _lowerHandle = null; return; } - if (_currentHeader.flags == (byte)SNISMUXFlags.SMUX_FIN) + if (currentHeader.Flags == (byte)SNISMUXFlags.SMUX_FIN) { - _sessions.Remove(_currentHeader.sessionId); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_FIN | MARS Session Id {0}, SMUX_FIN flag received, Current Header Session Id {1} removed", args0: _lowerHandle?.ConnectionId, args1: _currentHeader?.sessionId); + _sessions.Remove(currentHeader.SessionId); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_FIN | MARS Session Id {0}, SMUX_FIN flag received, Current Header Session Id {1} removed", args0: _lowerHandle?.ConnectionId, args1: currentHeader.SessionId); } else { - currentSession = _sessions[_currentHeader.sessionId]; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, Current Session assigned to Session Id {1}", args0: _lowerHandle?.ConnectionId, args1: _currentHeader?.sessionId); + currentSession = _sessions[currentHeader.SessionId]; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, Current Session assigned to Session Id {1}", args0: _lowerHandle?.ConnectionId, args1: currentHeader.SessionId); } } - if (currentHeader.flags == (byte)SNISMUXFlags.SMUX_DATA) + if (currentHeader.Flags == (byte)SNISMUXFlags.SMUX_DATA) { - currentSession.HandleReceiveComplete(currentPacket, currentHeader); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_DATA | MARS Session Id {0}, Current Session {1} completed receiving Data", args0: _lowerHandle?.ConnectionId, args1: _currentHeader?.sessionId); + currentSession.HandleReceiveComplete(currentPacket, in currentHeader); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_DATA | MARS Session Id {0}, Current Session {1} completed receiving Data", args0: _lowerHandle?.ConnectionId, args1: currentHeader.SessionId); } - if (_currentHeader.flags == (byte)SNISMUXFlags.SMUX_ACK) + if (currentHeader.Flags == (byte)SNISMUXFlags.SMUX_ACK) { try { - currentSession.HandleAck(currentHeader.highwater); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_ACK | MARS Session Id {0}, Current Session {1} handled ack", args0: _lowerHandle?.ConnectionId, args1: _currentHeader?.sessionId); + currentSession.HandleAck(currentHeader.Highwater); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_ACK | MARS Session Id {0}, Current Session {1} handled ack", args0: _lowerHandle?.ConnectionId, args1: currentHeader.SessionId); } catch (Exception e) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "SMUX_ACK | MARS Session Id {0}, Exception occurred: {2}", args0: _currentHeader?.sessionId, args1: e?.Message); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "SMUX_ACK | MARS Session Id {0}, Exception occurred: {2}", args0: currentHeader.SessionId, args1: e?.Message); SNICommon.ReportSNIError(SNIProviders.SMUX_PROV, SNICommon.InternalExceptionError, e); } #if DEBUG diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs index 8246ce3d6f..5997af6c0d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs @@ -20,12 +20,11 @@ internal sealed class SNIMarsHandle : SNIHandle private readonly uint _status = TdsEnums.SNI_UNINITIALIZED; private readonly Queue _receivedPacketQueue = new Queue(); private readonly Queue _sendPacketQueue = new Queue(); - private readonly object _callbackObject; + private readonly TdsParserStateObject _callbackObject; private readonly Guid _connectionId; private readonly ushort _sessionId; private readonly ManualResetEventSlim _packetEvent = new ManualResetEventSlim(false); private readonly ManualResetEventSlim _ackEvent = new ManualResetEventSlim(false); - private readonly SNISMUXHeader _currentHeader = new SNISMUXHeader(); private readonly SNIAsyncCallback _handleSendCompleteCallback; private uint _sendHighwater = 4; @@ -76,7 +75,7 @@ public override void Dispose() /// MARS session ID /// Callback object /// true if connection is asynchronous - public SNIMarsHandle(SNIMarsConnection connection, ushort sessionId, object callbackObject, bool async) + public SNIMarsHandle(SNIMarsConnection connection, ushort sessionId, TdsParserStateObject callbackObject, bool async) { _sessionId = sessionId; _connection = connection; @@ -102,9 +101,7 @@ private void SendControlPacket(SNISMUXFlags flags) #endif lock (this) { - SetupSMUXHeader(0, flags); - _currentHeader.Write(packet.GetHeaderBuffer(SNISMUXHeader.HEADER_LENGTH)); - packet.SetHeaderActive(); + SetupSMUXHeader(packet, flags); } _connection.Send(packet); @@ -116,17 +113,19 @@ private void SendControlPacket(SNISMUXFlags flags) } } - private void SetupSMUXHeader(int length, SNISMUXFlags flags) + private void SetupSMUXHeader(SNIPacket packet, SNISMUXFlags flags) { Debug.Assert(Monitor.IsEntered(this), "must take lock on self before updating smux header"); + Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to smux packet without smux reservation"); + + SNISMUXHeader header = new((byte)flags, _sessionId, SNISMUXHeader.HEADER_LENGTH + (uint)packet.Length, + sequenceNumber: ((flags == SNISMUXFlags.SMUX_FIN) || (flags == SNISMUXFlags.SMUX_ACK)) ? _sequenceNumber - 1 : _sequenceNumber++, + _receiveHighwater); - _currentHeader.SMID = 83; - _currentHeader.flags = (byte)flags; - _currentHeader.sessionId = _sessionId; - _currentHeader.length = (uint)SNISMUXHeader.HEADER_LENGTH + (uint)length; - _currentHeader.sequenceNumber = ((flags == SNISMUXFlags.SMUX_FIN) || (flags == SNISMUXFlags.SMUX_ACK)) ? _sequenceNumber - 1 : _sequenceNumber++; - _currentHeader.highwater = _receiveHighwater; - _receiveHighwaterLastAck = _currentHeader.highwater; + _receiveHighwaterLastAck = _receiveHighwater; + + header.Write(packet.GetHeaderBuffer(SNISMUXHeader.HEADER_LENGTH)); + packet.SetHeaderActive(); } /// @@ -136,11 +135,7 @@ private void SetupSMUXHeader(int length, SNISMUXFlags flags) /// The packet with the SMUx header set. private SNIPacket SetPacketSMUXHeader(SNIPacket packet) { - Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to smux packet without smux reservation"); - - SetupSMUXHeader(packet.Length, SNISMUXFlags.SMUX_DATA); - _currentHeader.Write(packet.GetHeaderBuffer(SNISMUXHeader.HEADER_LENGTH)); - packet.SetHeaderActive(); + SetupSMUXHeader(packet, SNISMUXFlags.SMUX_DATA); #if DEBUG SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsHandle), EventType.INFO, "MARS Session Id {0}, Setting SMUX_DATA header in current header for packet {1}", args0: ConnectionId, args1: packet?._id); #endif @@ -348,7 +343,7 @@ public void HandleReceiveError(SNIPacket packet) _packetEvent.Set(); } - ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 1); + _callbackObject.ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 1); } } @@ -365,7 +360,7 @@ public void HandleSendComplete(SNIPacket packet, uint sniErrorCode) { Debug.Assert(_callbackObject != null); - ((TdsParserStateObject)_callbackObject).WriteAsyncCallback(PacketHandle.FromManagedPacket(packet), sniErrorCode); + _callbackObject.WriteAsyncCallback(PacketHandle.FromManagedPacket(packet), sniErrorCode); } _connection.ReturnPacket(packet); #if DEBUG @@ -399,16 +394,16 @@ public void HandleAck(uint highwater) /// /// SNI packet /// SMUX header - public void HandleReceiveComplete(SNIPacket packet, SNISMUXHeader header) + public void HandleReceiveComplete(SNIPacket packet, in SNISMUXHeader header) { using (TrySNIEventScope.Create(nameof(SNIMarsHandle))) { lock (this) { - if (_sendHighwater != header.highwater) + if (_sendHighwater != header.Highwater) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsHandle), EventType.INFO, "MARS Session Id {0}, header.highwater {1}, _sendHighwater {2}, Handle Ack with header.highwater", args0: ConnectionId, args1: header?.highwater, args2: _sendHighwater); - HandleAck(header.highwater); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsHandle), EventType.INFO, "MARS Session Id {0}, header.highwater {1}, _sendHighwater {2}, Handle Ack with header.highwater", args0: ConnectionId, args1: header.Highwater, args2: _sendHighwater); + HandleAck(header.Highwater); } lock (_receivedPacketQueue) @@ -425,7 +420,7 @@ public void HandleReceiveComplete(SNIPacket packet, SNISMUXHeader header) Debug.Assert(_callbackObject != null); SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsHandle), EventType.INFO, "MARS Session Id {0}, _sequenceNumber {1}, _sendHighwater {2}, _asyncReceives {3}", args0: ConnectionId, args1: _sequenceNumber, args2: _sendHighwater, args3: _asyncReceives); - ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 0); + _callbackObject.ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 0); } _connection.ReturnPacket(packet); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs index be8d1a0160..3e1b9e7019 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; +using System.Buffers.Binary; using System.Threading; using System.Threading.Tasks; @@ -63,7 +64,7 @@ public override int Read(Span buffer) } while (headerBytesRead < TdsEnums.HEADER_LEN); // read the packet data size from the header and store it in case it is needed for a subsequent call - _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + _packetBytes = BinaryPrimitives.ReadUInt16BigEndian(headerBytes.Slice(TdsEnums.HEADER_LEN_FIELD_OFFSET)) - TdsEnums.HEADER_LEN; // read as much from the packet as the caller can accept int packetBytesRead = _stream.Read(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes))); @@ -149,7 +150,7 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation } while (headerBytesRead < TdsEnums.HEADER_LEN); // read the packet data size from the header and store it in case it is needed for a subsequent call - _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + _packetBytes = BinaryPrimitives.ReadUInt16BigEndian(headerBytes.AsSpan(TdsEnums.HEADER_LEN_FIELD_OFFSET)) - TdsEnums.HEADER_LEN; ArrayPool.Shared.Return(headerBytes, clearArray: true); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 71d36bf586..22c64395f0 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -57,7 +57,7 @@ protected override void CreateSessionHandle(TdsParserStateObject physicalConnect } } - internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async) + internal SNIMarsHandle CreateMarsSession(TdsParserStateObject callbackObject, bool async) { SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.CreateMarsSession | Info | State Object Id {0}, Session Id {1}, Async = {2}", _objectID, _sessionHandle?.ConnectionId, async); if (_marsConnection is null) From 87b6755f064077086e42d5c0400cf8acd6f298ea Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Thu, 28 Mar 2024 19:19:18 +0000 Subject: [PATCH 08/20] Bugfix to SNINativeMethodWrapper. Also corrected runnerconfig to allow it to bypass SSL verification by default (for default localhost usage.) --- .../netcore/src/Interop/SNINativeMethodWrapper.Windows.cs | 2 +- .../tests/PerformanceTests/runnerconfig.json | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs index 568c9727d0..b74096fc66 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs @@ -434,7 +434,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4?.ToString(); clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6?.ToString(); - clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo.Port.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port == null ? null : cachedDNSInfo.Port.ToString(); if (spnBuffer != null) { diff --git a/src/Microsoft.Data.SqlClient/tests/PerformanceTests/runnerconfig.json b/src/Microsoft.Data.SqlClient/tests/PerformanceTests/runnerconfig.json index 47d1736127..76f073cb67 100644 --- a/src/Microsoft.Data.SqlClient/tests/PerformanceTests/runnerconfig.json +++ b/src/Microsoft.Data.SqlClient/tests/PerformanceTests/runnerconfig.json @@ -1,12 +1,12 @@ { - "ConnectionString": "Server=tcp:localhost; Integrated Security=true; Initial Catalog=sqlclient-perf-db;", + "ConnectionString": "Server=tcp:localhost; Integrated Security=true; Initial Catalog=sqlclient-perf-db; Trust Server Certificate=Yes;", "UseManagedSniOnWindows": false, "Benchmarks": { "SqlConnectionRunnerConfig": { "Enabled": true, "LaunchCount": 1, "IterationCount": 50, - "InvocationCount":30, + "InvocationCount": 30, "WarmupCount": 5, "RowCount": 0 }, From d4e6a4f2fef6a8ba339e56d616df758f87cf4fa0 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:45:04 +0000 Subject: [PATCH 09/20] Slight performance regression within MARS caused by instantiating an SNIMarsHeader struct rather than reading into it. Also caching UdpClient instances on a per-AddressFamily basis to prevent reallocating them for the same family. Removed the instantiation of one delegate instance to reduce memory usage slightly. --- .../Microsoft/Data/SqlClient/SNI/SNICommon.cs | 15 +- .../Data/SqlClient/SNI/SNIMarsConnection.cs | 2 +- .../src/Microsoft/Data/SqlClient/SNI/SSRP.cs | 181 ++++++++++-------- .../SqlClient/TdsParserSafeHandles.Windows.cs | 4 +- 4 files changed, 113 insertions(+), 89 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index a967a902fe..a465505cf8 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -49,11 +49,11 @@ internal enum SNIProviders { public const int HEADER_LENGTH = 16; - public readonly byte Flags; - public readonly ushort SessionId; - public readonly uint Length; - public readonly uint SequenceNumber; - public readonly uint Highwater; + public byte Flags; + public ushort SessionId; + public uint Length; + public uint SequenceNumber; + public uint Highwater; public SNISMUXHeader(byte flags, ushort sessionId, uint length, uint sequenceNumber, uint highwater) { @@ -64,7 +64,7 @@ public SNISMUXHeader(byte flags, ushort sessionId, uint length, uint sequenceNum Highwater = highwater; } - public SNISMUXHeader(Span bytes) + public void Read(Span bytes) { // As per the MC-SMP spec, the first byte of the header will always be 0x53 Debug.Assert(bytes[0] == 0x53, "First byte of the SNI SMUX header was not 0x53"); @@ -96,8 +96,7 @@ public void Write(Span bytes) /// /// SMUX packet flags /// - [Flags] - internal enum SNISMUXFlags + internal enum SNISMUXFlags : byte { SMUX_SYN = 1, // Begin SMUX connection SMUX_ACK = 2, // Acknowledge SMUX packets diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs index b1bc76e09d..d1eba96d81 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -247,7 +247,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) } } - currentHeader = new SNISMUXHeader(_headerBytes); + currentHeader.Read(_headerBytes); _dataBytesLeft = (int)currentHeader.Length; _currentPacket = _lowerHandle.RentPacket(headerSize: 0, dataSize: (int)currentHeader.Length); #if DEBUG diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index 8122906ae3..ab6ff09539 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -108,12 +108,12 @@ private static byte[] CreateInstanceInfoRequest(string instanceName) using (TrySNIEventScope.Create(nameof(SSRP))) { const byte ClntUcastInst = 0x04; - instanceName += char.MinValue; int byteCount = Encoding.ASCII.GetByteCount(instanceName); - byte[] requestPacket = new byte[byteCount + 1]; + byte[] requestPacket = new byte[byteCount + 1 + 1]; requestPacket[0] = ClntUcastInst; Encoding.ASCII.GetBytes(instanceName, 0, instanceName.Length, requestPacket, 1); + requestPacket[byteCount + 1] = 0; return requestPacket; } @@ -175,13 +175,13 @@ private static byte[] CreateDacPortInfoRequest(string instanceName) const byte ClntUcastDac = 0x0F; const byte ProtocolVersion = 0x01; - instanceName += char.MinValue; int byteCount = Encoding.ASCII.GetByteCount(instanceName); - byte[] requestPacket = new byte[byteCount + 2]; + byte[] requestPacket = new byte[byteCount + 2 + 1]; requestPacket[0] = ClntUcastDac; requestPacket[1] = ProtocolVersion; Encoding.ASCII.GetBytes(instanceName, 0, instanceName.Length, requestPacket, 2); + requestPacket[2 + byteCount] = 0; return requestPacket; } @@ -319,134 +319,159 @@ private static async ValueTask SendUDPRequest(IList ipAddress if (ipAddresses.Count == 0) return null; - if (allIPsInParallel) // Used for MultiSubnetFailover + IPEndPoint endPoint = new IPEndPoint(ipAddresses[0], port); + + if (allIPsInParallel && ipAddresses.Count > 1) // Used for MultiSubnetFailover { List> tasks = new(ipAddresses.Count); Task firstFailedTask = null; CancellationTokenSource cts = new CancellationTokenSource(); + // Cache the UdpClients for each of the address families to save disposing them + UdpClient ipv4UdpClient = null; + UdpClient ipv6UdpClient = null; for (int i = 0; i < ipAddresses.Count; i++) { - IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port); - tasks.Add(SendUDPRequest(endPoint, requestPacket, async, cts.Token)); - } + if (i > 0) + { + endPoint.Address = ipAddresses[i]; + } - while (tasks.Count > 0) - { - Task completedTask; + if (endPoint.AddressFamily == AddressFamily.InterNetwork) + { + ipv4UdpClient ??= new UdpClient(AddressFamily.InterNetwork); - if (async) + tasks.Add(SendUDPRequest(endPoint, ipv4UdpClient, requestPacket, async, cts.Token)); + } + else if (endPoint.AddressFamily == AddressFamily.InterNetworkV6) { - completedTask = await Task.WhenAny(tasks).ConfigureAwait(false); + ipv6UdpClient ??= new UdpClient(AddressFamily.InterNetworkV6); - if (completedTask.Status == TaskStatus.RanToCompletion) - { - cts.Cancel(); - return completedTask.Result; - } + tasks.Add(SendUDPRequest(endPoint, ipv4UdpClient, requestPacket, async, cts.Token)); } - else + } + + using (ipv4UdpClient) + using (ipv6UdpClient) + { + while (tasks.Count > 0) { - int completedTaskIndex = Task.WaitAny(tasks.ToArray()); + Task completedTask; + + if (async) + { + completedTask = await Task.WhenAny(tasks).ConfigureAwait(false); - completedTask = tasks[completedTaskIndex]; - if (completedTask.Status == TaskStatus.RanToCompletion) + if (completedTask.Status == TaskStatus.RanToCompletion) + { + cts.Cancel(); + return completedTask.Result; + } + } + else { - cts.Cancel(); - return completedTask.Result; + int completedTaskIndex = Task.WaitAny(tasks.ToArray()); + + completedTask = tasks[completedTaskIndex]; + if (completedTask.Status == TaskStatus.RanToCompletion) + { + cts.Cancel(); + return completedTask.Result; + } } - } - if (completedTask.Status == TaskStatus.Faulted) - { - tasks.Remove(completedTask); - firstFailedTask ??= completedTask; + if (completedTask.Status == TaskStatus.Faulted) + { + tasks.Remove(completedTask); + firstFailedTask ??= completedTask; + } } - } - Debug.Assert(firstFailedTask != null, "firstFailedTask should never be null"); + Debug.Assert(firstFailedTask != null, "firstFailedTask should never be null"); - // All tasks failed. Return the error from the first failure. - throw firstFailedTask.Exception; + // All tasks failed. Return the error from the first failure. + throw firstFailedTask.Exception; + } } else { - // If not parallel, use the first IP address provided - IPEndPoint endPoint = new IPEndPoint(ipAddresses[0], port); - return await SendUDPRequest(endPoint, requestPacket, async, CancellationToken.None); + using (UdpClient oneShotUdpClient = new UdpClient(endPoint.AddressFamily)) + { + // If not parallel, use the first IP address provided + return await SendUDPRequest(endPoint, oneShotUdpClient, requestPacket, async, CancellationToken.None); + } } } #if NET6_0_OR_GREATER - private static async Task SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket, bool async, CancellationToken token) + private static async Task SendUDPRequest(IPEndPoint endPoint, UdpClient client, byte[] requestPacket, bool async, CancellationToken token) #else - private static Task SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket, bool async, CancellationToken token) + private static Task SendUDPRequest(IPEndPoint endPoint, UdpClient client, byte[] requestPacket, bool async, CancellationToken token) #endif { byte[] responsePacket = null; try { - using (UdpClient client = new UdpClient(endPoint.AddressFamily)) - { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch Port info."); - using (CancellationTokenSource sendTimeoutCancellationTokenSource = new CancellationTokenSource(s_sendTimeout)) - using (CancellationTokenSource sendCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, sendTimeoutCancellationTokenSource.Token)) - { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch Port info."); + + using (CancellationTokenSource sendTimeoutCancellationTokenSource = new CancellationTokenSource(s_sendTimeout)) + using (CancellationTokenSource sendCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, sendTimeoutCancellationTokenSource.Token)) + { #if NET6_0_OR_GREATER - ValueTask sendTask = client.SendAsync(requestPacket.AsMemory(), endPoint, sendCancellationTokenSource.Token); + ValueTask sendTask = client.SendAsync(requestPacket.AsMemory(), endPoint, sendCancellationTokenSource.Token); - if (async) - { - await sendTask.ConfigureAwait(false); - } - else + if (async) + { + await sendTask.ConfigureAwait(false); + } + else + { + if (!sendTask.IsCompleted) { - if (!sendTask.IsCompleted) - { - sendTask.AsTask().Wait(); - } + sendTask.AsTask().Wait(); } + } #else - Task sendTask = client.SendAsync(requestPacket, requestPacket.Length, endPoint); + Task sendTask = client.SendAsync(requestPacket, requestPacket.Length, endPoint); - sendTask.Wait(sendCancellationTokenSource.Token); + sendTask.Wait(sendCancellationTokenSource.Token); #endif - } + } - UdpReceiveResult receiveResult; + UdpReceiveResult receiveResult; - using (CancellationTokenSource receiveTimeoutCancellationTokenSource = new CancellationTokenSource(s_receiveTimeout)) - using (CancellationTokenSource receiveCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, receiveTimeoutCancellationTokenSource.Token)) - { + using (CancellationTokenSource receiveTimeoutCancellationTokenSource = new CancellationTokenSource(s_receiveTimeout)) + using (CancellationTokenSource receiveCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, receiveTimeoutCancellationTokenSource.Token)) + { #if NET6_0_OR_GREATER - ValueTask receiveTask = client.ReceiveAsync(receiveCancellationTokenSource.Token); + ValueTask receiveTask = client.ReceiveAsync(receiveCancellationTokenSource.Token); - if (async) + if (async) + { + receiveResult = await receiveTask.ConfigureAwait(false); + } + else + { + if (!receiveTask.IsCompleted) { - receiveResult = await receiveTask.ConfigureAwait(false); + receiveTask.AsTask().Wait(); } - else - { - if (!receiveTask.IsCompleted) - { - receiveTask.AsTask().Wait(); - } - receiveResult = receiveTask.Result; - } + receiveResult = receiveTask.Result; + } #else - Task receiveTask = client.ReceiveAsync(); + Task receiveTask = client.ReceiveAsync(); - receiveTask.Wait(receiveCancellationTokenSource.Token); - receiveResult = receiveTask.Result; + receiveTask.Wait(receiveCancellationTokenSource.Token); + receiveResult = receiveTask.Result; #endif - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received Port info from UDP Client."); - responsePacket = receiveResult.Buffer; - } + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received Port info from UDP Client."); + responsePacket = receiveResult.Buffer; } + } catch (OperationCanceledException) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs index 9ca1ec4def..facb214a72 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs @@ -18,8 +18,8 @@ internal sealed partial class SNILoadHandle : SafeHandle { internal static readonly SNILoadHandle SingletonInstance = new SNILoadHandle(); - internal readonly SNINativeMethodWrapper.SqlAsyncCallbackDelegate ReadAsyncCallbackDispatcher = new SNINativeMethodWrapper.SqlAsyncCallbackDelegate(ReadDispatcher); - internal readonly SNINativeMethodWrapper.SqlAsyncCallbackDelegate WriteAsyncCallbackDispatcher = new SNINativeMethodWrapper.SqlAsyncCallbackDelegate(WriteDispatcher); + internal readonly SNINativeMethodWrapper.SqlAsyncCallbackDelegate ReadAsyncCallbackDispatcher = ReadDispatcher; + internal readonly SNINativeMethodWrapper.SqlAsyncCallbackDelegate WriteAsyncCallbackDispatcher = WriteDispatcher; private readonly uint _sniStatus = TdsEnums.SNI_UNINITIALIZED; private readonly EncryptionOptions _encryptionOption = EncryptionOptions.OFF; From ac3532935c28f5048544f82fa95f2717862d5c91 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Thu, 28 Mar 2024 23:23:20 +0000 Subject: [PATCH 10/20] Applied .NET Core changes to .NET Framework. One exception to this is that we're writing a byte array rather than a span to the packet - TdsParserStateObject doesn't support it there --- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 184 ++++++++---------- 1 file changed, 81 insertions(+), 103 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 2170d02592..fce3d925ff 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -25,6 +25,7 @@ using Microsoft.Data.SqlTypes; using Microsoft.SqlServer.Server; using Microsoft.Data.ProviderBase; +using System.Buffers.Binary; namespace Microsoft.Data.SqlClient { @@ -1059,6 +1060,27 @@ internal void BestEffortCleanup() // prelogin so that we don't try to negotiate encryption again during ConsumePreLoginHandshake. _encryptionOption = EncryptionOptions.NOT_SUP; } + else + { + if (encrypt == SqlConnectionEncryptOption.Mandatory) + { + _encryptionOption = EncryptionOptions.ON; + } + else + { + _encryptionOption = EncryptionOptions.OFF; + } + + if (clientCertificate) + { + _encryptionOption |= EncryptionOptions.CLIENT_CERT; + } + } + + if (useCtaip) + { + _encryptionOption |= EncryptionOptions.CTAIP; + } // PreLoginHandshake buffer consists of: // 1) Standard header, with type = MT_PRELOGIN @@ -1072,147 +1094,100 @@ internal void BestEffortCleanup() // Initialize option offset into payload buffer // 5 bytes for each option (1 byte length, 2 byte offset, 2 byte payload length) - int offset = (int)PreLoginOptions.NUMOPT * 5 + 1; + ushort headerOffset = 0; + ushort headerLength = (ushort)PreLoginOptions.NUMOPT * 5; - byte[] payload = new byte[(int)PreLoginOptions.NUMOPT * 5 + TdsEnums.MAX_PRELOGIN_PAYLOAD_LENGTH]; - int payloadLength = 0; + ushort payloadStart = (ushort)(headerLength + 1); + ushort payloadOffset = payloadStart; + // The payload length is static for each connection string. The lengths of each option are well-known + // Version: 6 bytes; Encryption: 1 byte; Instance: (instanceName.Length + 1) bytes; + // Thread ID: 4 bytes; MARS enablement: 1 byte; Trace: (2 * GUID + 1 * uint); Federated Authentication: 1 byte + // End-of-payload marker: 1 byte + // .NET Core uses a static zero-length instance name, which suggests a 51-byte payload. + // .NET Framework allows a variable-length instance name (up to 254 bytes). This is up to 305 bytes at most. + int payloadLength = 6 + 1 + (instanceName.Length + 1) + 4 + 1 + (GUID_SIZE + GUID_SIZE + 4) + 1 + 1; - // UNDONE - need to do some length verification to ensure packet does not - // get too big!!! Not beyond it's max length! + int totalBufferLength = headerLength + payloadLength; + byte[] preLoginPacketBufferArray = new byte[totalBufferLength]; + Span preLoginPacketBuffer = preLoginPacketBufferArray; - for (int option = (int)PreLoginOptions.VERSION; option < (int)PreLoginOptions.NUMOPT; option++) + for (byte option = (byte)PreLoginOptions.VERSION; option < (byte)PreLoginOptions.NUMOPT; option++) { - int optionDataSize = 0; + ushort optionDataSize = 0; + // Structure header: // Fill in the option - _physicalStateObj.WriteByte((byte)option); + preLoginPacketBuffer[headerOffset] = option; // Fill in the offset of the option data - _physicalStateObj.WriteByte((byte)((offset & 0xff00) >> 8)); // send upper order byte - _physicalStateObj.WriteByte((byte)(offset & 0x00ff)); // send lower order byte + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(headerOffset + 1), payloadOffset); switch (option) { - case (int)PreLoginOptions.VERSION: + case (byte)PreLoginOptions.VERSION: Version systemDataVersion = ADP.GetAssemblyVersion(); // Major and minor - payload[payloadLength++] = (byte)(systemDataVersion.Major & 0xff); - payload[payloadLength++] = (byte)(systemDataVersion.Minor & 0xff); + preLoginPacketBuffer[payloadOffset] = (byte)(systemDataVersion.Major & 0xff); + preLoginPacketBuffer[payloadOffset + 1] = (byte)(systemDataVersion.Minor & 0xff); // Build (Big Endian) - payload[payloadLength++] = (byte)((systemDataVersion.Build & 0xff00) >> 8); - payload[payloadLength++] = (byte)(systemDataVersion.Build & 0xff); + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(payloadOffset + 2), (ushort)(systemDataVersion.Build & 0xFFFF)); // Sub-build (Little Endian) - payload[payloadLength++] = (byte)(systemDataVersion.Revision & 0xff); - payload[payloadLength++] = (byte)((systemDataVersion.Revision & 0xff00) >> 8); - offset += 6; + BinaryPrimitives.WriteUInt16LittleEndian(preLoginPacketBuffer.Slice(payloadOffset + 2 + sizeof(ushort)), (ushort)(systemDataVersion.Revision & 0xFFFF)); + optionDataSize = 6; break; - case (int)PreLoginOptions.ENCRYPT: - if (_encryptionOption == EncryptionOptions.NOT_SUP) - { - //If OS doesn't support encryption and encryption is not required, inform server "not supported" by client. - payload[payloadLength] = (byte)EncryptionOptions.NOT_SUP; - } - else - { - // Else, inform server of user request. - if (encrypt == SqlConnectionEncryptOption.Mandatory) - { - payload[payloadLength] = (byte)EncryptionOptions.ON; - _encryptionOption = EncryptionOptions.ON; - } - else - { - payload[payloadLength] = (byte)EncryptionOptions.OFF; - _encryptionOption = EncryptionOptions.OFF; - } - - // Inform server of user request. - if (clientCertificate) - { - payload[payloadLength] |= (byte)EncryptionOptions.CLIENT_CERT; - _encryptionOption |= EncryptionOptions.CLIENT_CERT; - } - } - - // Add CTAIP if requested. - if (useCtaip) - { - payload[payloadLength] |= (byte)EncryptionOptions.CTAIP; - _encryptionOption |= EncryptionOptions.CTAIP; - } + case (byte)PreLoginOptions.ENCRYPT: + preLoginPacketBuffer[payloadOffset] = (byte)_encryptionOption; - payloadLength += 1; - offset += 1; optionDataSize = 1; break; - case (int)PreLoginOptions.INSTANCE: - int i = 0; - - while (instanceName[i] != 0) - { - payload[payloadLength] = instanceName[i]; - payloadLength++; - i++; - } + case (byte)PreLoginOptions.INSTANCE: + instanceName.CopyTo(preLoginPacketBuffer.Slice(payloadOffset)); + preLoginPacketBuffer[payloadOffset + instanceName.Length] = 0; - payload[payloadLength] = 0; // null terminate - payloadLength++; - i++; - - offset += i; - optionDataSize = i; + optionDataSize = (ushort)(instanceName.Length + 1); break; - case (int)PreLoginOptions.THREADID: - Int32 threadID = TdsParserStaticMethods.GetCurrentThreadIdForTdsLoginOnly(); + case (byte)PreLoginOptions.THREADID: + int threadID = TdsParserStaticMethods.GetCurrentThreadIdForTdsLoginOnly(); + + BinaryPrimitives.WriteInt32BigEndian(preLoginPacketBuffer.Slice(payloadOffset), threadID); - payload[payloadLength++] = (byte)((0xff000000 & threadID) >> 24); - payload[payloadLength++] = (byte)((0x00ff0000 & threadID) >> 16); - payload[payloadLength++] = (byte)((0x0000ff00 & threadID) >> 8); - payload[payloadLength++] = (byte)(0x000000ff & threadID); - offset += 4; optionDataSize = 4; break; - case (int)PreLoginOptions.MARS: - payload[payloadLength++] = (byte)(_fMARS ? 1 : 0); - offset += 1; - optionDataSize += 1; + case (byte)PreLoginOptions.MARS: + preLoginPacketBuffer[payloadOffset] = (byte)(_fMARS ? 1 : 0); + + optionDataSize = 1; break; - case (int)PreLoginOptions.TRACEID: + case (byte)PreLoginOptions.TRACEID: byte[] connectionIdBytes = _connHandler._clientConnectionId.ToByteArray(); + Debug.Assert(GUID_SIZE == connectionIdBytes.Length); - Buffer.BlockCopy(connectionIdBytes, 0, payload, payloadLength, GUID_SIZE); - payloadLength += GUID_SIZE; - offset += GUID_SIZE; - optionDataSize = GUID_SIZE; + connectionIdBytes.CopyTo(preLoginPacketBuffer.Slice(payloadOffset)); ActivityCorrelator.ActivityId actId = ActivityCorrelator.Next(); + connectionIdBytes = actId.Id.ToByteArray(); - Buffer.BlockCopy(connectionIdBytes, 0, payload, payloadLength, GUID_SIZE); - payloadLength += GUID_SIZE; - payload[payloadLength++] = (byte)(0x000000ff & actId.Sequence); - payload[payloadLength++] = (byte)((0x0000ff00 & actId.Sequence) >> 8); - payload[payloadLength++] = (byte)((0x00ff0000 & actId.Sequence) >> 16); - payload[payloadLength++] = (byte)((0xff000000 & actId.Sequence) >> 24); - int actIdSize = GUID_SIZE + sizeof(UInt32); - offset += actIdSize; - optionDataSize += actIdSize; + connectionIdBytes.CopyTo(preLoginPacketBuffer.Slice(payloadOffset + GUID_SIZE)); + BinaryPrimitives.WriteUInt32LittleEndian(preLoginPacketBuffer.Slice(payloadOffset + GUID_SIZE + GUID_SIZE), actId.Sequence); + + optionDataSize = GUID_SIZE + GUID_SIZE + sizeof(uint); SqlClientEventSource.Log.TryTraceEvent(" ClientConnectionID {0}, ActivityID {1}", _connHandler._clientConnectionId, actId); break; - case (int)PreLoginOptions.FEDAUTHREQUIRED: - payload[payloadLength++] = 0x01; - offset += 1; - optionDataSize += 1; + case (byte)PreLoginOptions.FEDAUTHREQUIRED: + preLoginPacketBuffer[payloadOffset] = 0x01; + + optionDataSize = 1; break; default: @@ -1220,19 +1195,22 @@ internal void BestEffortCleanup() break; } + payloadOffset += optionDataSize; + // Write data length - _physicalStateObj.WriteByte((byte)((optionDataSize & 0xff00) >> 8)); - _physicalStateObj.WriteByte((byte)(optionDataSize & 0x00ff)); + BinaryPrimitives.WriteUInt16BigEndian(preLoginPacketBuffer.Slice(headerOffset + 1 + sizeof(ushort)), optionDataSize); + + headerOffset += 1 + sizeof(ushort) + sizeof(ushort); } // Write out last option - to let server know the second part of packet completed - _physicalStateObj.WriteByte((byte)PreLoginOptions.LASTOPT); + preLoginPacketBuffer[headerOffset] = (byte)PreLoginOptions.LASTOPT; // Write out payload - _physicalStateObj.WriteByteArray(payload, payloadLength, 0); + _physicalStateObj.WriteByteArray(preLoginPacketBufferArray, preLoginPacketBufferArray.Length, 0); // Flush packet - _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH); + _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH)?.Wait(); } private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integratedSecurity, string serverCertificate, ServerCertificateValidationCallback serverCallback, ClientCertificateRetrievalCallback clientCallback) From 462d3e83a1ac186ece517e7eb135c48db42a4f02 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Fri, 29 Mar 2024 13:21:30 +0000 Subject: [PATCH 11/20] Minor code cleanup in SSRP. Removing one implicit string allocation if a UDP request fails and tracing is disabled. --- .../netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index ab6ff09539..fd9b4ec87a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -492,7 +492,7 @@ private static Task SendUDPRequest(IPEndPoint endPoint, UdpClient client firstSocketException = e; } SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, - "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint.ToString(), args1: e.Message); + "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint, args1: e.Message); } // Throw first error if we didn't find a SocketException @@ -501,14 +501,14 @@ private static Task SendUDPRequest(IPEndPoint endPoint, UdpClient client else { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, - "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint.ToString(), args1: ae.Message); + "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint, args1: ae.Message); throw; } } catch (Exception e) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, - "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint.ToString(), args1: e.Message); + "SendUDPRequest ({0}) resulted in exception: {1}", args0: endPoint, args1: e.Message); throw; } From a7e4f88ceef2f6be1dcf2b480b37eddc948c1dd0 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Fri, 5 Apr 2024 06:13:55 +0100 Subject: [PATCH 12/20] Adjusting DNS caching. This commit lays the groundwork to eliminate the second DNS lookup required after SSRP. --- .../Interop/SNINativeMethodWrapper.Windows.cs | 8 +- .../Microsoft/Data/SqlClient/SNI/SNICommon.cs | 94 ++++++++----- .../Data/SqlClient/SNI/SNITcpHandle.cs | 75 ++--------- .../src/Microsoft/Data/SqlClient/SNI/SSRP.cs | 123 ++++++------------ .../SqlClient/TdsParserStateObjectNative.cs | 4 +- .../Data/Interop/SNINativeMethodWrapper.cs | 8 +- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 4 +- .../Data/SqlClient/SQLFallbackDNSCache.cs | 23 ++-- 8 files changed, 140 insertions(+), 199 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs index b74096fc66..e680f2853a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs @@ -387,8 +387,8 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info(); native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4?.ToString(); - native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6?.ToString(); + native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.CachedIPv4Address?.ToString(); + native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.CachedIPv6Address?.ToString(); native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo?.Port.ToString(); return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo); @@ -432,8 +432,8 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan clientConsumerInfo.ipAddressPreference = ipPreference; clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4?.ToString(); - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.CachedIPv4Address?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.CachedIPv6Address?.ToString(); clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port == null ? null : cachedDNSInfo.Port.ToString(); if (spnBuffer != null) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index a465505cf8..d5859731fa 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -13,6 +13,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.ProviderBase; +using System.Net.Sockets; namespace Microsoft.Data.SqlClient.SNI { @@ -326,28 +327,14 @@ internal static bool ValidateSslServerCertificate(X509Certificate clientCert, X5 } } - internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer timeout) - { - using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) - { - int remainingTimeout = timeout.MillisecondsRemainingInt; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, - "Getting DNS host entries for serverName {0} within {1} milliseconds.", - args0: serverName, - args1: remainingTimeout); - using CancellationTokenSource cts = new CancellationTokenSource(remainingTimeout); - // using this overload to support netstandard - Task task = Dns.GetHostAddressesAsync(serverName); - task.ConfigureAwait(false); - task.Wait(cts.Token); - return task.Result; - } - } - + /// + /// Returns array of IP addresses for the given server name, sorted according to the given preference. + /// + /// Thrown when ipPreference is not supported #if NET6_0_OR_GREATER - internal static async ValueTask GetDnsIpAddresses(string serverName, TimeoutTimer timeout, bool async) + internal static async ValueTask GetDnsIpAddresses(string serverName, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, bool async) #else - internal static ValueTask GetDnsIpAddresses(string serverName, TimeoutTimer timeout, bool async) + internal static ValueTask GetDnsIpAddresses(string serverName, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, bool async) #endif { using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) @@ -364,41 +351,86 @@ internal static ValueTask GetDnsIpAddresses(string serverName, Time if (async) { - return await task.ConfigureAwait(false); + return SortIpAddressesByPreference(await task.ConfigureAwait(false), ipPreference); } else { task.Wait(); - return task.Result; + return SortIpAddressesByPreference(task.Result, ipPreference); } #else // using this overload to support netstandard Task task = Dns.GetHostAddressesAsync(serverName); task.Wait(cts.Token); - return new ValueTask(task.Result); + return new ValueTask(SortIpAddressesByPreference(task.Result, ipPreference)); #endif } } - internal static IPAddress[] GetDnsIpAddresses(string serverName) + /// + /// Returns array of IP addresses for the given server name, sorted according to the given preference. + /// + /// Thrown when ipPreference is not supported + internal static async ValueTask GetDnsIpAddresses(string serverName, SqlConnectionIPAddressPreference ipPreference, bool async) { using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0}.", args0: serverName); - return Dns.GetHostAddresses(serverName); + + return SortIpAddressesByPreference(async + ? await Dns.GetHostAddressesAsync(serverName) + : Dns.GetHostAddresses(serverName), + ipPreference); } } - internal static async ValueTask GetDnsIpAddresses(string serverName, bool async) + private static IPAddress[] SortIpAddressesByPreference(IPAddress[] dnsIPAddresses, SqlConnectionIPAddressPreference ipPreference) { - using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) + AddressFamily? prioritiesFamily = ipPreference switch { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0}.", args0: serverName); + SqlConnectionIPAddressPreference.IPv4First => AddressFamily.InterNetwork, + SqlConnectionIPAddressPreference.IPv6First => AddressFamily.InterNetworkV6, + SqlConnectionIPAddressPreference.UsePlatformDefault => null, + _ => throw ADP.NotSupportedEnumerationValue(typeof(SqlConnectionIPAddressPreference), ipPreference.ToString(), nameof(SortIpAddressesByPreference)) + }; - return async - ? await Dns.GetHostAddressesAsync(serverName) - : Dns.GetHostAddresses(serverName); + if (prioritiesFamily == null) + { + return dnsIPAddresses; + } + else + { + int resultArrayIndex = 0; + IPAddress[] ipAddresses = new IPAddress[dnsIPAddresses.Length]; + + // Return addresses of the preferred family first + for (int i = 0; i < dnsIPAddresses.Length; i++) + { + if (dnsIPAddresses[i].AddressFamily == prioritiesFamily) + { + ipAddresses[resultArrayIndex++] = dnsIPAddresses[i]; + } + } + + // Return addresses of the other family + for (int i = 0; i < dnsIPAddresses.Length; i++) + { + if (dnsIPAddresses[i].AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6 + && dnsIPAddresses[i].AddressFamily != prioritiesFamily) + { + ipAddresses[resultArrayIndex++] = dnsIPAddresses[i]; + } + } + + // If the DNS resolution returned records of types other than A and AAAA, the original array size will be + // too large, and must thus be resized. This is very unlikely, so we only try to do this post-hoc. + if (resultArrayIndex + 1 < ipAddresses.Length) + { + Array.Resize(ref ipAddresses, resultArrayIndex + 1); + } + + return ipAddresses; } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index 796f00ade8..202e0cc0c4 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -188,20 +188,20 @@ public override int ProtocolVersion else { int portRetry = cachedDNSInfo.Port == 0 ? port : cachedDNSInfo.Port; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Retrying with cached DNS IP Address {1} and port {2}", args0: _connectionId, args1: cachedDNSInfo.AddrIPv4, args2: cachedDNSInfo.Port); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Retrying with cached DNS IP Address {1} and port {2}", args0: _connectionId, args1: cachedDNSInfo.CachedIPv4Address, args2: cachedDNSInfo.Port); IPAddress[] firstCachedIP; IPAddress[] secondCachedIP; if (SqlConnectionIPAddressPreference.IPv6First == ipPreference) { - firstCachedIP = new[] { cachedDNSInfo.AddrIPv6 }; - secondCachedIP = new[] { cachedDNSInfo.AddrIPv4 }; + firstCachedIP = new[] { cachedDNSInfo.CachedIPv6Address }; + secondCachedIP = new[] { cachedDNSInfo.CachedIPv4Address }; } else { - firstCachedIP = new[] { cachedDNSInfo.AddrIPv4 }; - secondCachedIP = new[] { cachedDNSInfo.AddrIPv6 }; + firstCachedIP = new[] { cachedDNSInfo.CachedIPv4Address }; + secondCachedIP = new[] { cachedDNSInfo.CachedIPv6Address }; } try @@ -300,9 +300,10 @@ public override int ProtocolVersion // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. private Socket TryConnectParallel(string hostName, int port, TimeoutTimer timeout, ref bool callerReportError, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { - IPAddress[] serverAddresses = timeout.IsInfinite - ? SNICommon.GetDnsIpAddresses(hostName) - : SNICommon.GetDnsIpAddresses(hostName, timeout); + IPAddress[] serverAddresses = pendingDNSInfo?.SpeculativeIPAddresses ?? + (timeout.IsInfinite + ? SNICommon.GetDnsIpAddresses(hostName, SqlConnectionIPAddressPreference.UsePlatformDefault, false).Result + : SNICommon.GetDnsIpAddresses(hostName, timeout, SqlConnectionIPAddressPreference.UsePlatformDefault, false).Result); if (serverAddresses.Length > MaxParallelIpAddresses) { @@ -355,68 +356,12 @@ private Socket TryConnectParallel(IPAddress[] serverAddresses, int port, Timeout return availableSocket; } - /// - /// Returns array of IP addresses for the given server name, sorted according to the given preference. - /// - /// Thrown when ipPreference is not supported - private static IPAddress[] GetHostAddressesSortedByPreference(string serverName, SqlConnectionIPAddressPreference ipPreference) - { - IPAddress[] dnsIPAddresses = Dns.GetHostAddresses(serverName); - IPAddress[] ipAddresses; - AddressFamily? prioritiesFamily = ipPreference switch - { - SqlConnectionIPAddressPreference.IPv4First => AddressFamily.InterNetwork, - SqlConnectionIPAddressPreference.IPv6First => AddressFamily.InterNetworkV6, - SqlConnectionIPAddressPreference.UsePlatformDefault => null, - _ => throw ADP.NotSupportedEnumerationValue(typeof(SqlConnectionIPAddressPreference), ipPreference.ToString(), nameof(GetHostAddressesSortedByPreference)) - }; - - if (prioritiesFamily == null) - { - ipAddresses = dnsIPAddresses; - } - else - { - int resultArrayIndex = 0; - - ipAddresses = new IPAddress[dnsIPAddresses.Length]; - - // Return addresses of the preferred family first - for (int i = 0; i < dnsIPAddresses.Length; i++) - { - if (dnsIPAddresses[i].AddressFamily == prioritiesFamily) - { - ipAddresses[resultArrayIndex++] = dnsIPAddresses[i]; - } - } - - // Return addresses of the other family - for (int i = 0; i < dnsIPAddresses.Length; i++) - { - if (dnsIPAddresses[i].AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6 - && dnsIPAddresses[i].AddressFamily != prioritiesFamily) - { - ipAddresses[resultArrayIndex++] = dnsIPAddresses[i]; - } - } - - // If the DNS resolution returned records of types other than A and AAAA, the original array size will be - // too large, and must thus be resized. This is very unlikely, so we only try to do this post-hoc. - if (resultArrayIndex + 1 < ipAddresses.Length) - { - Array.Resize(ref ipAddresses, resultArrayIndex + 1); - } - } - - return ipAddresses; - } - // Connect to server with hostName and port. // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. private static Socket Connect(string serverName, int port, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { - IPAddress[] ipAddresses = GetHostAddressesSortedByPreference(serverName, ipPreference); + IPAddress[] ipAddresses = pendingDNSInfo?.SpeculativeIPAddresses ?? SNICommon.GetDnsIpAddresses(serverName, ipPreference, false).Result; return Connect(ipAddresses, port, timeout, ipPreference, cachedFQDN, ref pendingDNSInfo); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index fd9b4ec87a..61aaf3744f 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -16,7 +16,6 @@ namespace Microsoft.Data.SqlClient.SNI { internal sealed class SSRP { - private static readonly List s_emptyList = new(0); private static readonly TimeSpan s_sendTimeout = TimeSpan.FromSeconds(1.0); private static readonly TimeSpan s_receiveTimeout = TimeSpan.FromSeconds(1.0); @@ -211,83 +210,69 @@ private static async ValueTask SendUDPRequest(string browserHostname, in } IPAddress[] ipAddresses = await (timeout.IsInfinite - ? SNICommon.GetDnsIpAddresses(browserHostname, async) - : SNICommon.GetDnsIpAddresses(browserHostname, timeout, async)); + ? SNICommon.GetDnsIpAddresses(browserHostname, ipPreference, async) + : SNICommon.GetDnsIpAddresses(browserHostname, timeout, ipPreference, async)); Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); - List ipv4Addresses = null; - byte[] response4 = null; - - List ipv6Addresses = null; - byte[] response6 = null; + byte[] response = null; Exception responseException = null; switch (ipPreference) { case SqlConnectionIPAddressPreference.IPv4First: + case SqlConnectionIPAddressPreference.IPv6First: { - SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses); - - try + // If the ipPreference has been specified and IP addresses of a certain address family are attempted first, + // then slice the array at the point where the address family changes. + AddressFamily previousAddressFamily = ipAddresses[0].AddressFamily; + int firstAddressFamilyLength = 0; + Memory primaryIpAddressList; + Memory secondaryIpAddressList = Memory.Empty; + + for (int i = 0; i < ipAddresses.Length; i++) { - response4 = await SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); - - if (response4 != null) + if (ipAddresses[i].AddressFamily != previousAddressFamily) { - return response4; + firstAddressFamilyLength = i; + break; } + + previousAddressFamily = ipAddresses[i].AddressFamily; } - catch(Exception e) - { responseException ??= e; } + primaryIpAddressList = firstAddressFamilyLength == 0 + ? ipAddresses.AsMemory() + : ipAddresses.AsMemory(0, firstAddressFamilyLength); try { - response6 = await SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); + response = await SendUDPRequest(primaryIpAddressList, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); - if (response6 != null) + if (response != null) { - return response6; + return response; } } catch (Exception e) { responseException ??= e; } - // No responses so throw first error - if (responseException != null) - { - throw responseException; - } - break; - } - case SqlConnectionIPAddressPreference.IPv6First: - { - SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses); - - try + if (firstAddressFamilyLength > 0) { - response6 = await SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); + secondaryIpAddressList = ipAddresses.AsMemory(firstAddressFamilyLength); - if (response6 != null) + try { - return response6; - } - } - catch (Exception e) - { responseException ??= e; } + response = await SendUDPRequest(secondaryIpAddressList, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); - try - { - response4 = await SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel, async).ConfigureAwait(false); - - if (response4 != null) - { - return response4; + if (response != null) + { + return response; + } } + catch (Exception e) + { responseException ??= e; } } - catch (Exception e) - { responseException ??= e; } // No responses so throw first error if (responseException != null) @@ -314,27 +299,27 @@ private static async ValueTask SendUDPRequest(string browserHostname, in /// query all resolved IP addresses in parallel /// If true, this method will be run asynchronously /// response packet from UDP server - private static async ValueTask SendUDPRequest(IList ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel, bool async) + private static async ValueTask SendUDPRequest(Memory ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel, bool async) { - if (ipAddresses.Count == 0) + if (ipAddresses.IsEmpty) return null; - IPEndPoint endPoint = new IPEndPoint(ipAddresses[0], port); + IPEndPoint endPoint = new IPEndPoint(ipAddresses.Span[0], port); - if (allIPsInParallel && ipAddresses.Count > 1) // Used for MultiSubnetFailover + if (allIPsInParallel && ipAddresses.Length > 1) // Used for MultiSubnetFailover { - List> tasks = new(ipAddresses.Count); + List> tasks = new(ipAddresses.Length); Task firstFailedTask = null; CancellationTokenSource cts = new CancellationTokenSource(); // Cache the UdpClients for each of the address families to save disposing them UdpClient ipv4UdpClient = null; UdpClient ipv6UdpClient = null; - for (int i = 0; i < ipAddresses.Count; i++) + for (int i = 0; i < ipAddresses.Length; i++) { if (i > 0) { - endPoint.Address = ipAddresses[i]; + endPoint.Address = ipAddresses.Span[i]; } if (endPoint.AddressFamily == AddressFamily.InterNetwork) @@ -518,33 +503,5 @@ private static Task SendUDPRequest(IPEndPoint endPoint, UdpClient client return Task.FromResult(responsePacket); #endif } - - private static void SplitIPv4AndIPv6(IPAddress[] input, out List ipv4Addresses, out List ipv6Addresses) - { - List v4 = null; - List v6 = null; - - if (input != null && input.Length > 0) - { - v4 = new List(1); - v6 = new List(0); - - for (int index = 0; index < input.Length; index++) - { - switch (input[index].AddressFamily) - { - case AddressFamily.InterNetwork: - v4.Add(input[index]); - break; - case AddressFamily.InterNetworkV6: - v6.Add(input[index]); - break; - } - } - } - - ipv4Addresses = v4 ?? s_emptyList; - ipv6Addresses = v6 ?? s_emptyList; - } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index 7104b22ce6..dffe5ac250 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -107,11 +107,11 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache { if (System.Net.Sockets.AddressFamily.InterNetwork == IPFromSNI.AddressFamily) { - pendingDNSInfo.AddrIPv4 = IPFromSNI; + pendingDNSInfo.CachedIPv4Address = IPFromSNI; } else if (System.Net.Sockets.AddressFamily.InterNetworkV6 == IPFromSNI.AddressFamily) { - pendingDNSInfo.AddrIPv6 = IPFromSNI; + pendingDNSInfo.CachedIPv6Address = IPFromSNI; } } } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs index 9c371692ca..1bce41ddd8 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs @@ -1095,8 +1095,8 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info(); native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4?.ToString(); - native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6?.ToString(); + native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.CachedIPv4Address?.ToString(); + native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.CachedIPv6Address?.ToString(); native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo?.Port.ToString(); return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo); @@ -1154,8 +1154,8 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan clientConsumerInfo.ipAddressPreference = ipPreference; clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4?.ToString(); - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.CachedIPv4Address?.ToString(); + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.CachedIPv6Address?.ToString(); clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo?.Port.ToString(); if (spnBuffer != null) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index fce3d925ff..504a555924 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -861,11 +861,11 @@ internal void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey) { if (System.Net.Sockets.AddressFamily.InterNetwork == IPFromSNI.AddressFamily) { - _connHandler.pendingSQLDNSObject.AddrIPv4 = IPFromSNI; + _connHandler.pendingSQLDNSObject.CachedIPv4Address = IPFromSNI; } else if (System.Net.Sockets.AddressFamily.InterNetworkV6 == IPFromSNI.AddressFamily) { - _connHandler.pendingSQLDNSObject.AddrIPv6 = IPFromSNI; + _connHandler.pendingSQLDNSObject.CachedIPv6Address = IPFromSNI; } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs index 6ce612080e..5732c83092 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs @@ -57,8 +57,8 @@ internal bool IsDuplicate(SQLDNSInfo newItem) SQLDNSInfo oldItem; if (GetDNSInfo(newItem.FQDN, out oldItem)) { - return (newItem.AddrIPv4 == oldItem.AddrIPv4 && - newItem.AddrIPv6 == oldItem.AddrIPv6 && + return (newItem.CachedIPv4Address == oldItem.CachedIPv4Address && + newItem.CachedIPv6Address == oldItem.CachedIPv6Address && newItem.Port == oldItem.Port); } } @@ -70,16 +70,23 @@ internal bool IsDuplicate(SQLDNSInfo newItem) internal sealed class SQLDNSInfo { public string FQDN { get; set; } - public IPAddress AddrIPv4 { get; set; } - public IPAddress AddrIPv6 { get; set; } + public IPAddress CachedIPv4Address { get; set; } + public IPAddress CachedIPv6Address { get; set; } public int Port { get; set; } + public IPAddress[] SpeculativeIPAddresses { get; set; } - internal SQLDNSInfo(string FQDN, IPAddress ipv4, IPAddress ipv6, int port) + internal SQLDNSInfo(string fqdn, IPAddress ipv4, IPAddress ipv6, int port) { - this.FQDN = FQDN; - AddrIPv4 = ipv4; - AddrIPv6 = ipv6; + FQDN = fqdn; + CachedIPv4Address = ipv4; + CachedIPv6Address = ipv6; Port = port; } + + internal SQLDNSInfo(string fqdn, IPAddress[] speculativeIPAddresses) + { + FQDN = fqdn; + SpeculativeIPAddresses = speculativeIPAddresses; + } } } From 2dc88399d62ea90cc676f744962b902f5d53c316 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Fri, 5 Apr 2024 21:31:52 +0100 Subject: [PATCH 13/20] Eliminating unneeded DNS lookup. If a DNS lookup was needed for SSRP, forward its results to the TCP connection creation --- .../Microsoft/Data/SqlClient/SNI/SNIProxy.cs | 6 +- .../src/Microsoft/Data/SqlClient/SNI/SSRP.cs | 78 ++++++++++---- .../Microsoft/Data/SqlClient/SQLDNSCache.cs | 101 ++++++++++++++++++ 3 files changed, 162 insertions(+), 23 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLDNSCache.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index a0e1310b94..cdab2f8eee 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -313,9 +313,13 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr { try { - details.ResolvedPort = port = isAdminConnection ? + SSRP.SSRPResult portDetails = isAdminConnection ? SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference) : SSRP.GetPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference); + + + details.ResolvedPort = port = portDetails.Port; + pendingDNSInfo = new SQLDNSInfo(hostName, portDetails.ResolvedIPAddresses); } catch (SocketException se) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index 61aaf3744f..73d975028d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers.Binary; using System.Collections.Generic; using System.Diagnostics; using System.Net; @@ -10,12 +11,28 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Azure; using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { internal sealed class SSRP { + public sealed class SSRPResult + { + internal byte[] Buffer { get; set; } + + public ushort Port { get; set; } + + public IPAddress[] ResolvedIPAddresses { get; } + + public SSRPResult(IPAddress[] resolvedIPAddresses, byte[] buffer) + { + ResolvedIPAddresses = resolvedIPAddresses; + Buffer = buffer; + } + } + private static readonly TimeSpan s_sendTimeout = TimeSpan.FromSeconds(1.0); private static readonly TimeSpan s_receiveTimeout = TimeSpan.FromSeconds(1.0); @@ -36,8 +53,8 @@ internal sealed class SSRP /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference - /// port number for given instance name - internal static int GetPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + /// port number and resolved IP addresses for given instance name + internal static SSRPResult GetPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) => GetPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, false).Result; /// @@ -48,21 +65,23 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference - /// port number for given instance name - internal static ValueTask GetPortByInstanceNameAsync(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + /// port number and resolved IP addresses for given instance name + internal static ValueTask GetPortByInstanceNameAsync(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) => GetPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, true); - private static async ValueTask GetPortByInstanceNameCore(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) + private static async ValueTask GetPortByInstanceNameCore(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); using (TrySNIEventScope.Create(nameof(SSRP))) { byte[] instanceInfoRequest = CreateInstanceInfoRequest(instanceName); + SSRPResult response = null; byte[] responsePacket = null; try { - responsePacket = await SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timeout, allIPsInParallel, ipPreference, async); + response = await SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timeout, allIPsInParallel, ipPreference, async); + responsePacket = response?.Buffer; } catch (SocketException se) { @@ -92,7 +111,10 @@ private static async ValueTask GetPortByInstanceNameCore(string browserHost throw new SocketException(); } - return ushort.Parse(elements[tcpIndex + 1]); + response.Port = ushort.Parse(elements[tcpIndex + 1]); + response.Buffer = null; + + return response; } } @@ -127,7 +149,7 @@ private static byte[] CreateInstanceInfoRequest(string instanceName) /// query all resolved IP addresses in parallel /// IP address preference /// DAC port for given instance name - internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + internal static SSRPResult GetDacPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) => GetDacPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, false).Result; /// @@ -139,16 +161,17 @@ internal static int GetDacPortByInstanceName(string browserHostName, string inst /// query all resolved IP addresses in parallel /// IP address preference /// DAC port for given instance name - internal static ValueTask GetDacPortByInstanceNameAsync(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + internal static ValueTask GetDacPortByInstanceNameAsync(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) => GetDacPortByInstanceNameCore(browserHostName, instanceName, timeout, allIPsInParallel, ipPreference, true); - private static async ValueTask GetDacPortByInstanceNameCore(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) + private static async ValueTask GetDacPortByInstanceNameCore(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); byte[] dacPortInfoRequest = CreateDacPortInfoRequest(instanceName); - byte[] responsePacket = await SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timeout, allIPsInParallel, ipPreference, async); + SSRPResult response = await SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timeout, allIPsInParallel, ipPreference, async); + byte[] responsePacket = response?.Buffer; const byte SvrResp = 0x05; const byte ProtocolVersion = 0x01; @@ -159,8 +182,9 @@ private static async ValueTask GetDacPortByInstanceNameCore(string browserH throw new SocketException(); } - int dacPort = BitConverter.ToUInt16(responsePacket, 4); - return dacPort; + response.Port = BitConverter.ToUInt16(responsePacket, 4); + response.Buffer = null; + return response; } /// @@ -196,7 +220,7 @@ private static byte[] CreateDacPortInfoRequest(string instanceName) /// IP address preference /// If true, this method will be run asynchronously /// response packet from UDP server - private static async ValueTask SendUDPRequest(string browserHostname, int port, byte[] requestPacket, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) + private static async ValueTask SendUDPRequest(string browserHostname, int port, byte[] requestPacket, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference, bool async) { using (TrySNIEventScope.Create(nameof(SSRP))) { @@ -204,14 +228,18 @@ private static async ValueTask SendUDPRequest(string browserHostname, in Debug.Assert(port >= 0 && port <= 65535, "Invalid port"); Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array"); + IPAddress[] ipAddresses; + if (IPAddress.TryParse(browserHostname, out IPAddress address)) { - return await SendUDPRequest(new IPAddress[] { address }, port, requestPacket, allIPsInParallel, async); + ipAddresses = new IPAddress[1] { address }; + } + else + { + ipAddresses = await (timeout.IsInfinite + ? SNICommon.GetDnsIpAddresses(browserHostname, ipPreference, async) + : SNICommon.GetDnsIpAddresses(browserHostname, timeout, ipPreference, async)); } - - IPAddress[] ipAddresses = await (timeout.IsInfinite - ? SNICommon.GetDnsIpAddresses(browserHostname, ipPreference, async) - : SNICommon.GetDnsIpAddresses(browserHostname, timeout, ipPreference, async)); Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); @@ -250,7 +278,7 @@ private static async ValueTask SendUDPRequest(string browserHostname, in if (response != null) { - return response; + return new SSRPResult(ipAddresses, response); } } catch (Exception e) @@ -267,7 +295,7 @@ private static async ValueTask SendUDPRequest(string browserHostname, in if (response != null) { - return response; + return new SSRPResult(ipAddresses, response); } } catch (Exception e) @@ -283,7 +311,13 @@ private static async ValueTask SendUDPRequest(string browserHostname, in break; } default: - return await SendUDPRequest(ipAddresses, port, requestPacket, true, async).ConfigureAwait(false); // allIPsInParallel); + byte[] buffer = await SendUDPRequest(ipAddresses, port, requestPacket, true, async).ConfigureAwait(false); + + if (response != null) + { + return new SSRPResult(ipAddresses, response); + } + break; } return null; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLDNSCache.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLDNSCache.cs new file mode 100644 index 0000000000..0f8695a4ce --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLDNSCache.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Concurrent; +using System.Net; + +namespace Microsoft.Data.SqlClient +{ + internal sealed class SQLDNSCache + { + // These should be prime numbers according to MSDN docs. ConcurrentDictionary will be resized if the capacity is reached. + // SqlBrowserCacheCapacity doesn't need to be very large at all. SSRP writes to this as part of resolving an instance name + // to a port, so the rest of the connection process doesn't need to + private const int FallbackCacheCapacity = 101; + private const int SqlBrowserCacheCapacity = 11; + + private static readonly SQLDNSCache s_sqlFallbackDNSCache = new(FallbackCacheCapacity); + private static readonly SQLDNSCache s_sqlBrowserDNSCache = new(SqlBrowserCacheCapacity); + + private readonly ConcurrentDictionary _dnsInfoCache; + + // singleton instance + public static SQLDNSCache Instance => s_sqlFallbackDNSCache; + + public static SQLDNSCache InterimInstance => s_sqlFallbackDNSCache; + + private SQLDNSCache(int initialCapacity) + { + int level = 4 * Environment.ProcessorCount; + _dnsInfoCache = new ConcurrentDictionary(concurrencyLevel: level, + capacity: initialCapacity, + comparer: StringComparer.OrdinalIgnoreCase); + } + + internal bool AddDNSInfo(SQLDNSInfo item) + { + if (null != item) + { +#if NET6_0_OR_GREATER || NETSTANDARD2_1 + _dnsInfoCache.AddOrUpdate(item.FQDN, static (key, state) => state, static (key, value, state) => state, item); +#else + _dnsInfoCache.AddOrUpdate(item.FQDN, item, (key, value) => item); +#endif + return true; + } + + return false; + } + + internal bool DeleteDNSInfo(string FQDN) + { + return _dnsInfoCache.TryRemove(FQDN, out _); + } + + internal bool GetDNSInfo(string FQDN, out SQLDNSInfo result) + { + return _dnsInfoCache.TryGetValue(FQDN, out result); + } + + internal bool IsDuplicate(SQLDNSInfo newItem) + { + if (null != newItem) + { + SQLDNSInfo oldItem; + if (GetDNSInfo(newItem.FQDN, out oldItem)) + { + return (newItem.CachedIPv4Address == oldItem.CachedIPv4Address && + newItem.CachedIPv6Address == oldItem.CachedIPv6Address && + newItem.Port == oldItem.Port); + } + } + + return false; + } + } + + internal sealed class SQLDNSInfo + { + public string FQDN { get; set; } + public IPAddress CachedIPv4Address { get; set; } + public IPAddress CachedIPv6Address { get; set; } + public int Port { get; set; } + public IPAddress[] SpeculativeIPAddresses { get; set; } + + internal SQLDNSInfo(string fqdn, IPAddress ipv4, IPAddress ipv6, int port) + { + FQDN = fqdn; + CachedIPv4Address = ipv4; + CachedIPv6Address = ipv6; + Port = port; + } + + internal SQLDNSInfo(string fqdn, IPAddress[] speculativeIPAddresses) + { + FQDN = fqdn; + SpeculativeIPAddresses = speculativeIPAddresses; + } + } +} From 675bdb2c67696894002e80a33249569447e6d2fc Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Sun, 7 Apr 2024 06:12:40 +0100 Subject: [PATCH 14/20] Cleaning up PR diff - removed accidentally committed file. --- .../Microsoft/Data/SqlClient/SQLDNSCache.cs | 101 ------------------ 1 file changed, 101 deletions(-) delete mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLDNSCache.cs diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLDNSCache.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLDNSCache.cs deleted file mode 100644 index 0f8695a4ce..0000000000 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLDNSCache.cs +++ /dev/null @@ -1,101 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Concurrent; -using System.Net; - -namespace Microsoft.Data.SqlClient -{ - internal sealed class SQLDNSCache - { - // These should be prime numbers according to MSDN docs. ConcurrentDictionary will be resized if the capacity is reached. - // SqlBrowserCacheCapacity doesn't need to be very large at all. SSRP writes to this as part of resolving an instance name - // to a port, so the rest of the connection process doesn't need to - private const int FallbackCacheCapacity = 101; - private const int SqlBrowserCacheCapacity = 11; - - private static readonly SQLDNSCache s_sqlFallbackDNSCache = new(FallbackCacheCapacity); - private static readonly SQLDNSCache s_sqlBrowserDNSCache = new(SqlBrowserCacheCapacity); - - private readonly ConcurrentDictionary _dnsInfoCache; - - // singleton instance - public static SQLDNSCache Instance => s_sqlFallbackDNSCache; - - public static SQLDNSCache InterimInstance => s_sqlFallbackDNSCache; - - private SQLDNSCache(int initialCapacity) - { - int level = 4 * Environment.ProcessorCount; - _dnsInfoCache = new ConcurrentDictionary(concurrencyLevel: level, - capacity: initialCapacity, - comparer: StringComparer.OrdinalIgnoreCase); - } - - internal bool AddDNSInfo(SQLDNSInfo item) - { - if (null != item) - { -#if NET6_0_OR_GREATER || NETSTANDARD2_1 - _dnsInfoCache.AddOrUpdate(item.FQDN, static (key, state) => state, static (key, value, state) => state, item); -#else - _dnsInfoCache.AddOrUpdate(item.FQDN, item, (key, value) => item); -#endif - return true; - } - - return false; - } - - internal bool DeleteDNSInfo(string FQDN) - { - return _dnsInfoCache.TryRemove(FQDN, out _); - } - - internal bool GetDNSInfo(string FQDN, out SQLDNSInfo result) - { - return _dnsInfoCache.TryGetValue(FQDN, out result); - } - - internal bool IsDuplicate(SQLDNSInfo newItem) - { - if (null != newItem) - { - SQLDNSInfo oldItem; - if (GetDNSInfo(newItem.FQDN, out oldItem)) - { - return (newItem.CachedIPv4Address == oldItem.CachedIPv4Address && - newItem.CachedIPv6Address == oldItem.CachedIPv6Address && - newItem.Port == oldItem.Port); - } - } - - return false; - } - } - - internal sealed class SQLDNSInfo - { - public string FQDN { get; set; } - public IPAddress CachedIPv4Address { get; set; } - public IPAddress CachedIPv6Address { get; set; } - public int Port { get; set; } - public IPAddress[] SpeculativeIPAddresses { get; set; } - - internal SQLDNSInfo(string fqdn, IPAddress ipv4, IPAddress ipv6, int port) - { - FQDN = fqdn; - CachedIPv4Address = ipv4; - CachedIPv6Address = ipv6; - Port = port; - } - - internal SQLDNSInfo(string fqdn, IPAddress[] speculativeIPAddresses) - { - FQDN = fqdn; - SpeculativeIPAddresses = speculativeIPAddresses; - } - } -} From 448ac2e3a6e13d7313cfd643f2554604e82183f4 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Sun, 7 Apr 2024 08:26:38 +0100 Subject: [PATCH 15/20] Update following test failure. SSRP timeout returns an OperationCanceledException inside an AggregateException --- .../netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index 73d975028d..6ad565c193 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -492,7 +492,7 @@ private static Task SendUDPRequest(IPEndPoint endPoint, UdpClient client } } - catch (OperationCanceledException) + catch (AggregateException ae) when (ae.InnerException is OperationCanceledException) { responsePacket = null; } From 491b37128002233a1a88aaa445983071ab39c6ac Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Mon, 8 Apr 2024 06:56:51 +0100 Subject: [PATCH 16/20] Changes following test failures. One pre-existing issue with a resource string missing a space. Several instances where the SQL DNS cache property names were hardcoded and accessed via reflection. --- .../Data/SqlClient/SNI/SNIMarsConnection.cs | 2 +- .../netfx/src/Microsoft.Data.SqlClient.csproj | 41 +++++++++++++++---- .../src/Resources/Strings.Designer.cs | 2 +- .../src/Resources/Strings.es.resx | 4 +- .../src/Resources/Strings.it.resx | 4 +- .../src/Resources/Strings.ko.resx | 4 +- .../src/Resources/Strings.pt-BR.resx | 4 +- .../src/Resources/Strings.resx | 4 +- .../src/Resources/Strings.ru.resx | 4 +- .../SystemDataInternals/ConnectionHelper.cs | 14 +++---- .../ConfigurableIpPreferenceTest.cs | 6 +-- 11 files changed, 58 insertions(+), 31 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs index d1eba96d81..fcdd50ec8c 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -251,7 +251,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) _dataBytesLeft = (int)currentHeader.Length; _currentPacket = _lowerHandle.RentPacket(headerSize: 0, dataSize: (int)currentHeader.Length); #if DEBUG - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, _dataBytesLeft {1}, _currentPacket {2}, Reading data of length: _currentHeader.length {3}", args0: _lowerHandle?.ConnectionId, args1: _dataBytesLeft, args2: currentPacket?._id, args3: _currentHeader?.length); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, _dataBytesLeft {1}, _currentPacket {2}, Reading data of length: _currentHeader.length {3}", args0: _lowerHandle?.ConnectionId, args1: _dataBytesLeft, args2: currentPacket?._id, args3: currentHeader.Length); #endif } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 0a0757731b..b04f92e935 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -22,14 +22,14 @@ full - + $([System.IO.Path]::Combine('$(IntermediateOutputPath)','$(GeneratedSourceFileName)')) - + True @@ -689,21 +689,48 @@ Resources\StringsHelper.cs - + Resources\Strings.Designer.cs True True Strings.resx - + Resources\Strings.resx Microsoft.Data.SqlClient.Resources.Strings.resources System ResXFileCodeGenerator Strings.Designer.cs - - Resources\%(RecursiveDir)%(Filename)%(Extension) + + Resources\Strings.de.resx + + + Resources\Strings.es.resx + + + Resources\Strings.fr.resx + + + Resources\Strings.it.resx + + + Resources\Strings.ja.resx + + + Resources\Strings.ko.resx + + + Resources\Strings.pt-BR.resx + + + Resources\Strings.ru.resx + + + Resources\Strings.zh-Hans.resx + + + Resources\Strings.zh-Hant.resx Microsoft.Data.SqlClient.SqlMetaData.xml @@ -751,4 +778,4 @@ - + \ No newline at end of file diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs index a99e5d0303..45eaf3d718 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs @@ -12025,7 +12025,7 @@ internal class Strings { } /// - /// Looks up a localized string similar to Specified type is not registered on the target server.{0}.. + /// Looks up a localized string similar to Specified type is not registered on the target server. {0}.. /// internal static string SQLUDT_InvalidSqlType { get { diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.es.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.es.resx index 821f8f1aa5..bc65877f2e 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.es.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.es.resx @@ -3559,7 +3559,7 @@ La propiedad UdtTypeName debe establecerse sólo para los parámetros UDT. - El tipo especificado no está registrado en el servidor de destino.{0}. + El tipo especificado no está registrado en el servidor de destino. {0}. No se permiten parámetros UDT en la cláusula where, salvo que formen parte de la clave principal. @@ -4740,4 +4740,4 @@ El certificado no está disponible al validar el certificado. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.it.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.it.resx index add941c9ce..9499ee5478 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.it.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.it.resx @@ -3559,7 +3559,7 @@ È necessario impostare la proprietà UdtTypeName property solo per i parametri UDT. - I tipo specificato non è registrato sul server di destinazione.{0}. + I tipo specificato non è registrato sul server di destinazione. {0}. Parametri UDT non consentiti nella clausola WHERE a meno che non facciano parte della chiave primaria. @@ -4740,4 +4740,4 @@ Certificato non disponibile durante la convalida del certificato. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.ko.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.ko.resx index d3d62df164..d38de35c97 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.ko.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.ko.resx @@ -3559,7 +3559,7 @@ UdtTypeName 속성은 UDT 매개 변수에 대해서만 설정해야 합니다. - 지정한 유형이 대상 서버에 등록되어 있지 않습니다.{0}. + 지정한 유형이 대상 서버에 등록되어 있지 않습니다. {0}. 기본 키의 일부가 아닌 경우 UDT 매개 변수는 where 절에서 허용되지 않습니다. @@ -4740,4 +4740,4 @@ 인증서의 유효성을 검사하는 동안에는 인증서를 사용할 수 없습니다. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.pt-BR.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.pt-BR.resx index 3cb1ae77ec..a807de587f 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.pt-BR.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.pt-BR.resx @@ -3559,7 +3559,7 @@ A propriedade UdtTypeName deve ser definida apenas como parâmetros UDT. - O tipo especificado não está registrado no servidor de destino.{0}. + O tipo especificado não está registrado no servidor de destino. {0}. Os parâmetros UDT não são permitidos na cláusula where, a menos que façam parte da chave primária. @@ -4740,4 +4740,4 @@ Certificado não disponível durante a validação do certificado. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx index 3712ece88f..993cbdb8d4 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx @@ -3559,7 +3559,7 @@ UdtTypeName property must be set only for UDT parameters. - Specified type is not registered on the target server.{0}. + Specified type is not registered on the target server. {0}. UDT parameters not permitted in the where clause unless part of the primary key. @@ -4740,4 +4740,4 @@ Certificate not available while validating the certificate. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.ru.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.ru.resx index 24fb60a6e4..0fff1724af 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.ru.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.ru.resx @@ -3559,7 +3559,7 @@ Свойство UdtTypeName должно быть установлено только для параметров UDT. - Указанный тип не зарегистрирован на сервере назначения.{0}. + Указанный тип не зарегистрирован на сервере назначения. {0}. Не допускается использование параметров UDT в составе конструкции Where, кроме случаев, когда они являются частью первичного ключа. @@ -4740,4 +4740,4 @@ Сертификат недоступен при проверке сертификата. - \ No newline at end of file + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs index 4f83a8aeb7..8e5d292aac 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs @@ -29,8 +29,8 @@ internal static class ConnectionHelper private static FieldInfo s_enforcedTimeoutDelayInMilliSeconds = s_tdsParserStateObject.GetField("_enforcedTimeoutDelayInMilliSeconds", BindingFlags.Instance | BindingFlags.NonPublic); private static FieldInfo s_pendingSQLDNSObject = s_sqlInternalConnectionTds.GetField("pendingSQLDNSObject", BindingFlags.Instance | BindingFlags.NonPublic); private static PropertyInfo s_pendingSQLDNS_FQDN = s_SQLDNSInfo.GetProperty("FQDN", BindingFlags.Instance | BindingFlags.Public); - private static PropertyInfo s_pendingSQLDNS_AddrIPv4 = s_SQLDNSInfo.GetProperty("AddrIPv4", BindingFlags.Instance | BindingFlags.Public); - private static PropertyInfo s_pendingSQLDNS_AddrIPv6 = s_SQLDNSInfo.GetProperty("AddrIPv6", BindingFlags.Instance | BindingFlags.Public); + private static PropertyInfo s_pendingSQLDNS_AddrIPv4 = s_SQLDNSInfo.GetProperty("CachedIPv4Address", BindingFlags.Instance | BindingFlags.Public); + private static PropertyInfo s_pendingSQLDNS_AddrIPv6 = s_SQLDNSInfo.GetProperty("CachedIPv6Address", BindingFlags.Instance | BindingFlags.Public); private static PropertyInfo s_pendingSQLDNS_Port = s_SQLDNSInfo.GetProperty("Port", BindingFlags.Instance | BindingFlags.Public); private static PropertyInfo dbConnectionInternalIsTransRoot = s_dbConnectionInternal.GetProperty("IsTransactionRoot", BindingFlags.Instance | BindingFlags.NonPublic); private static PropertyInfo dbConnectionInternalEnlistedTrans = s_sqlInternalConnection.GetProperty("EnlistedTransaction", BindingFlags.Instance | BindingFlags.NonPublic); @@ -112,16 +112,16 @@ public static void SetEnforcedTimeout(this SqlConnection connection, bool enforc /// /// Active connection to extract the requested data /// FQDN, AddrIPv4, AddrIPv6, and Port in sequence - public static Tuple GetSQLDNSInfo(this SqlConnection connection) + public static Tuple GetSQLDNSInfo(this SqlConnection connection) { object internalConnection = GetInternalConnection(connection); VerifyObjectIsInternalConnection(internalConnection); object pendingSQLDNSInfo = s_pendingSQLDNSObject.GetValue(internalConnection); string fqdn = s_pendingSQLDNS_FQDN.GetValue(pendingSQLDNSInfo) as string; - string ipv4 = s_pendingSQLDNS_AddrIPv4.GetValue(pendingSQLDNSInfo) as string; - string ipv6 = s_pendingSQLDNS_AddrIPv6.GetValue(pendingSQLDNSInfo) as string; - string port = s_pendingSQLDNS_Port.GetValue(pendingSQLDNSInfo) as string; - return new Tuple(fqdn, ipv4, ipv6, port); + System.Net.IPAddress ipv4 = s_pendingSQLDNS_AddrIPv4.GetValue(pendingSQLDNSInfo) as System.Net.IPAddress; + System.Net.IPAddress ipv6 = s_pendingSQLDNS_AddrIPv6.GetValue(pendingSQLDNSInfo) as System.Net.IPAddress; + int port = (int)s_pendingSQLDNS_Port.GetValue(pendingSQLDNSInfo); + return new Tuple(fqdn, ipv4, ipv6, port); } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConfigurableIpPreferenceTest/ConfigurableIpPreferenceTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConfigurableIpPreferenceTest/ConfigurableIpPreferenceTest.cs index 8003660889..d32d90066b 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConfigurableIpPreferenceTest/ConfigurableIpPreferenceTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConfigurableIpPreferenceTest/ConfigurableIpPreferenceTest.cs @@ -57,7 +57,7 @@ public void ConfigurableIpPreference(string ipPreference) { connection.Open(); Assert.Equal(ConnectionState.Open, connection.State); - Tuple DNSInfo = connection.GetSQLDNSInfo(); + Tuple DNSInfo = connection.GetSQLDNSInfo(); if(ipPreference == CnnPrefIPv4) { Assert.NotNull(DNSInfo.Item2); //IPv4 @@ -95,8 +95,8 @@ private void TestCachedConfigurableIpPreference(string ipPreference, string cnnS SQLFallbackDNSCacheGetDNSInfo.Invoke(SQLFallbackDNSCacheInstance, parameters); var dnsCacheEntry = parameters[1]; - const string AddrIPv4Property = "AddrIPv4"; - const string AddrIPv6Property = "AddrIPv6"; + const string AddrIPv4Property = "CachedIPv6Address"; + const string AddrIPv6Property = "CachedIPv6Address"; const string FQDNProperty = "FQDN"; Assert.NotNull(dnsCacheEntry); From 42caf30197d7f82ff79337f1b92bd0f3704797b9 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Mon, 8 Apr 2024 07:32:26 +0100 Subject: [PATCH 17/20] Reverted one change to SNIMarsConnection. The current header is now an instance-level variable once again. Not sure why this would make a difference, but without this, some of the data bytes are left in the packet buffer and interpreted as the header bytes of the next packet. --- .../Microsoft/Data/SqlClient/SNI/SNICommon.cs | 4 +- .../Data/SqlClient/SNI/SNIMarsConnection.cs | 40 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index d5859731fa..d8ab44f5ce 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -46,7 +46,7 @@ internal enum SNIProviders /// /// SMUX packet header /// - internal ref struct SNISMUXHeader + internal struct SNISMUXHeader { public const int HEADER_LENGTH = 16; @@ -77,7 +77,7 @@ public void Read(Span bytes) Highwater = BinaryPrimitives.ReadUInt32LittleEndian(bytes.Slice(12)); } - public void Write(Span bytes) + public readonly void Write(Span bytes) { // access the highest element first to cause the largest range check in the jit, then fill in the rest of the value and carry on as normal BinaryPrimitives.WriteUInt32LittleEndian(bytes.Slice(12), Highwater); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs index fcdd50ec8c..ce345cc16a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -23,6 +23,7 @@ internal class SNIMarsConnection private int _currentHeaderByteCount; private int _dataBytesLeft; private SNIPacket _currentPacket; + private SNISMUXHeader _currentHeader; /// /// Connection ID @@ -43,6 +44,7 @@ public SNIMarsConnection(SNIHandle lowerHandle) _connectionId = Guid.NewGuid(); _sessions = new Dictionary(); _headerBytes = new byte[SNISMUXHeader.HEADER_LENGTH]; + _currentHeader = new SNISMUXHeader(); _nextSessionId = 0; _currentHeaderByteCount = 0; _dataBytesLeft = 0; @@ -202,7 +204,6 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) { using (TrySNIEventScope.Create(nameof(SNIMarsConnection))) { - SNISMUXHeader currentHeader = default; SNIPacket currentPacket = null; SNIMarsHandle currentSession = null; @@ -222,7 +223,6 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) { if (_currentHeaderByteCount != SNISMUXHeader.HEADER_LENGTH) { - currentHeader = default; currentPacket = null; currentSession = null; @@ -247,9 +247,9 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) } } - currentHeader.Read(_headerBytes); - _dataBytesLeft = (int)currentHeader.Length; - _currentPacket = _lowerHandle.RentPacket(headerSize: 0, dataSize: (int)currentHeader.Length); + _currentHeader.Read(_headerBytes); + _dataBytesLeft = (int)_currentHeader.Length; + _currentPacket = _lowerHandle.RentPacket(headerSize: 0, dataSize: (int)_currentHeader.Length); #if DEBUG SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, _dataBytesLeft {1}, _currentPacket {2}, Reading data of length: _currentHeader.length {3}", args0: _lowerHandle?.ConnectionId, args1: _dataBytesLeft, args2: currentPacket?._id, args3: currentHeader.Length); #endif @@ -257,7 +257,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) currentPacket = _currentPacket; - if (currentHeader.Flags == (byte)SNISMUXFlags.SMUX_DATA) + if (_currentHeader.Flags == (byte)SNISMUXFlags.SMUX_DATA) { if (_dataBytesLeft > 0) { @@ -283,44 +283,44 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) _currentHeaderByteCount = 0; - if (!_sessions.ContainsKey(currentHeader.SessionId)) + if (!_sessions.ContainsKey(_currentHeader.SessionId)) { SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.SMUX_PROV, 0, SNICommon.InvalidParameterError, Strings.SNI_ERROR_5); HandleReceiveError(packet); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "Current Header Session Id {0} not found, MARS Session Id {1} will be destroyed, New SNI error created: {2}", args0: currentHeader.SessionId, args1: _lowerHandle?.ConnectionId, args2: sniErrorCode); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "Current Header Session Id {0} not found, MARS Session Id {1} will be destroyed, New SNI error created: {2}", args0: _currentHeader.SessionId, args1: _lowerHandle?.ConnectionId, args2: sniErrorCode); _lowerHandle.Dispose(); _lowerHandle = null; return; } - if (currentHeader.Flags == (byte)SNISMUXFlags.SMUX_FIN) + if (_currentHeader.Flags == (byte)SNISMUXFlags.SMUX_FIN) { - _sessions.Remove(currentHeader.SessionId); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_FIN | MARS Session Id {0}, SMUX_FIN flag received, Current Header Session Id {1} removed", args0: _lowerHandle?.ConnectionId, args1: currentHeader.SessionId); + _sessions.Remove(_currentHeader.SessionId); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_FIN | MARS Session Id {0}, SMUX_FIN flag received, Current Header Session Id {1} removed", args0: _lowerHandle?.ConnectionId, args1: _currentHeader.SessionId); } else { - currentSession = _sessions[currentHeader.SessionId]; - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, Current Session assigned to Session Id {1}", args0: _lowerHandle?.ConnectionId, args1: currentHeader.SessionId); + currentSession = _sessions[_currentHeader.SessionId]; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, Current Session assigned to Session Id {1}", args0: _lowerHandle?.ConnectionId, args1: _currentHeader.SessionId); } } - if (currentHeader.Flags == (byte)SNISMUXFlags.SMUX_DATA) + if (_currentHeader.Flags == (byte)SNISMUXFlags.SMUX_DATA) { - currentSession.HandleReceiveComplete(currentPacket, in currentHeader); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_DATA | MARS Session Id {0}, Current Session {1} completed receiving Data", args0: _lowerHandle?.ConnectionId, args1: currentHeader.SessionId); + currentSession.HandleReceiveComplete(currentPacket, in _currentHeader); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_DATA | MARS Session Id {0}, Current Session {1} completed receiving Data", args0: _lowerHandle?.ConnectionId, args1: _currentHeader.SessionId); } - if (currentHeader.Flags == (byte)SNISMUXFlags.SMUX_ACK) + if (_currentHeader.Flags == (byte)SNISMUXFlags.SMUX_ACK) { try { - currentSession.HandleAck(currentHeader.Highwater); - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_ACK | MARS Session Id {0}, Current Session {1} handled ack", args0: _lowerHandle?.ConnectionId, args1: currentHeader.SessionId); + currentSession.HandleAck(_currentHeader.Highwater); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "SMUX_ACK | MARS Session Id {0}, Current Session {1} handled ack", args0: _lowerHandle?.ConnectionId, args1: _currentHeader.SessionId); } catch (Exception e) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "SMUX_ACK | MARS Session Id {0}, Exception occurred: {2}", args0: currentHeader.SessionId, args1: e?.Message); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "SMUX_ACK | MARS Session Id {0}, Exception occurred: {2}", args0: _currentHeader.SessionId, args1: e?.Message); SNICommon.ReportSNIError(SNIProviders.SMUX_PROV, SNICommon.InternalExceptionError, e); } #if DEBUG From ea8f917dec35bf5b5be4d977eb58badc4fe6f0c4 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Mon, 8 Apr 2024 07:35:06 +0100 Subject: [PATCH 18/20] Small diff cleanup --- .../src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs index ce345cc16a..05dee402af 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -17,13 +17,13 @@ internal class SNIMarsConnection private readonly Guid _connectionId; private readonly Dictionary _sessions; private readonly byte[] _headerBytes; + private readonly SNISMUXHeader _currentHeader; private readonly object _sync; private SNIHandle _lowerHandle; private ushort _nextSessionId; private int _currentHeaderByteCount; private int _dataBytesLeft; private SNIPacket _currentPacket; - private SNISMUXHeader _currentHeader; /// /// Connection ID From 21cf8340e05d3db3c5c344c025fa44daaa448ae4 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Thu, 11 Apr 2024 05:46:26 +0100 Subject: [PATCH 19/20] Fixing resource string paths --- .../netfx/src/Microsoft.Data.SqlClient.csproj | 37 +++---------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index b04f92e935..16dc8713be 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -689,48 +689,21 @@ Resources\StringsHelper.cs - + Resources\Strings.Designer.cs True True Strings.resx - + Resources\Strings.resx Microsoft.Data.SqlClient.Resources.Strings.resources System ResXFileCodeGenerator Strings.Designer.cs - - Resources\Strings.de.resx - - - Resources\Strings.es.resx - - - Resources\Strings.fr.resx - - - Resources\Strings.it.resx - - - Resources\Strings.ja.resx - - - Resources\Strings.ko.resx - - - Resources\Strings.pt-BR.resx - - - Resources\Strings.ru.resx - - - Resources\Strings.zh-Hans.resx - - - Resources\Strings.zh-Hant.resx + + Resources\%(RecursiveDir)%(Filename)%(Extension) Microsoft.Data.SqlClient.SqlMetaData.xml @@ -778,4 +751,4 @@ - \ No newline at end of file + From fb464a01e299d2ba963956723e11c03a62f13860 Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Fri, 12 Apr 2024 07:49:56 +0100 Subject: [PATCH 20/20] Reducing GC impact in ParallelConnectAsync for SQL Servers with a single IP address. Correcting a deadlock in MARS connection tests. --- .../Data/SqlClient/SNI/SNIMarsConnection.cs | 2 +- .../Data/SqlClient/SNI/SNITcpHandle.cs | 34 +++++++++++++++++-- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs index 05dee402af..4e30c0d687 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -17,7 +17,7 @@ internal class SNIMarsConnection private readonly Guid _connectionId; private readonly Dictionary _sessions; private readonly byte[] _headerBytes; - private readonly SNISMUXHeader _currentHeader; + private SNISMUXHeader _currentHeader; private readonly object _sync; private SNIHandle _lowerHandle; private ushort _nextSessionId; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index 202e0cc0c4..fb7249b390 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -552,6 +552,8 @@ private static async Task ParallelConnectAsync(IPAddress[] serverAddress Exception lastException = null; IPEndPoint ipEndPoint = null; + Task lastTask = null; + List emptySocketList = new List(); List socketErrorCheckList = new List(1); Dictionary socketConnectionTasks = new(serverAddresses.Length); @@ -575,10 +577,11 @@ private static async Task ParallelConnectAsync(IPAddress[] serverAddress try { #if NET6_0_OR_GREATER - socketConnectionTasks.Add(socket.ConnectAsync(ipEndPoint, connectCancellationTokenSource.Token).AsTask(), socket); + lastTask = socket.ConnectAsync(ipEndPoint, connectCancellationTokenSource.Token).AsTask(); #else - socketConnectionTasks.Add(socket.ConnectAsync(ipEndPoint), socket); + lastTask = socket.ConnectAsync(ipEndPoint); #endif + socketConnectionTasks.Add(lastTask, socket); } catch (Exception e) { @@ -590,7 +593,32 @@ private static async Task ParallelConnectAsync(IPAddress[] serverAddress { while (socketConnectionTasks.Count > 0) { - Task completedTask = await Task.WhenAny(socketConnectionTasks.Keys).ConfigureAwait(false); + Task completedTask; + + // If there's only one IP address, we can avoid the implicit Task allocation of Task.WhenAny + if (socketConnectionTasks.Count == 1) + { + completedTask = lastTask; + + try + { + if (completedTask.Status != TaskStatus.Faulted) + { + await completedTask.ConfigureAwait(false); + } + } + catch (Exception connectException) + { + // This exception is silently swallowed here, but is thrown later in the method + SqlClientEventSource.Log.TryAdvancedTraceEvent( + $"{nameof(SNITCPHandle)}.{nameof(ParallelConnectAsync)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {connectException}"); + } + } + else + { + completedTask = await Task.WhenAny(socketConnectionTasks.Keys).ConfigureAwait(false); + } + Socket taskSocket = socketConnectionTasks[completedTask]; if (completedTask.Status == TaskStatus.RanToCompletion)