Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance and async improvements in SSRP, connection establishment and MARS connections #2451

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -387,9 +387,9 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan

SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info();
native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN;
native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4;
native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port;
native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.CachedIPv4Address?.ToString();
native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.CachedIPv6Address?.ToString();
native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port == 0 ? null : cachedDNSInfo?.Port.ToString();

return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo);
}
Expand All @@ -399,7 +399,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan
string constring,
ref IntPtr pConn,
byte[] spnBuffer,
byte[] instanceName,
Span<byte> instanceName,
bool fOverrideCache,
bool fSync,
int timeout,
Expand All @@ -409,7 +409,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan
string hostNameInCertificate)
{

fixed (byte* pin_instanceName = &instanceName[0])
fixed (byte* pin_instanceName = instanceName)
{
SNI_CLIENT_CONSUMER_INFO clientConsumerInfo = new SNI_CLIENT_CONSUMER_INFO();

Expand All @@ -432,9 +432,9 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan

clientConsumerInfo.ipAddressPreference = ipPreference;
clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.CachedIPv4Address?.ToString();
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.CachedIPv6Address?.ToString();
clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port == null ? null : cachedDNSInfo.Port.ToString();

if (spnBuffer != null)
{
Expand Down
Expand Up @@ -13,6 +13,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.ProviderBase;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClient.SNI
{
Expand Down Expand Up @@ -45,64 +46,58 @@ internal enum SNIProviders
/// <summary>
/// SMUX packet header
/// </summary>
internal sealed class SNISMUXHeader
internal struct SNISMUXHeader
{
public const int HEADER_LENGTH = 16;

public byte SMID;
public byte flags;
public ushort sessionId;
public uint length;
public uint sequenceNumber;
public uint highwater;
public byte Flags;
public ushort SessionId;
public uint Length;
public uint SequenceNumber;
public uint Highwater;

public void Read(byte[] bytes)
public SNISMUXHeader(byte flags, ushort sessionId, uint length, uint sequenceNumber, uint highwater)
{
SMID = bytes[0];
flags = bytes[1];
Span<byte> span = bytes.AsSpan();
sessionId = BinaryPrimitives.ReadUInt16LittleEndian(span.Slice(2));
length = BinaryPrimitives.ReadUInt32LittleEndian(span.Slice(4)) - SNISMUXHeader.HEADER_LENGTH;
sequenceNumber = BinaryPrimitives.ReadUInt32LittleEndian(span.Slice(8));
highwater = BinaryPrimitives.ReadUInt32LittleEndian(span.Slice(12));
Flags = flags;
SessionId = sessionId;
Length = length;
SequenceNumber = sequenceNumber;
Highwater = highwater;
}

public void Write(Span<byte> bytes)
public void Read(Span<byte> bytes)
{
// As per the MC-SMP spec, the first byte of the header will always be 0x53
Debug.Assert(bytes[0] == 0x53, "First byte of the SNI SMUX header was not 0x53");

Flags = bytes[1];
SessionId = BinaryPrimitives.ReadUInt16LittleEndian(bytes.Slice(2));
Length = BinaryPrimitives.ReadUInt32LittleEndian(bytes.Slice(4)) - SNISMUXHeader.HEADER_LENGTH;
SequenceNumber = BinaryPrimitives.ReadUInt32LittleEndian(bytes.Slice(8));
Highwater = BinaryPrimitives.ReadUInt32LittleEndian(bytes.Slice(12));
}

public readonly void Write(Span<byte> bytes)
{
uint value = highwater;
// access the highest element first to cause the largest range check in the jit, then fill in the rest of the value and carry on as normal
bytes[15] = (byte)((value >> 24) & 0xff);
bytes[12] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.highwater).CopyTo(headerBytes, 12);
bytes[13] = (byte)((value >> 8) & 0xff);
bytes[14] = (byte)((value >> 16) & 0xff);

bytes[0] = SMID; // BitConverter.GetBytes(_currentHeader.SMID).CopyTo(headerBytes, 0);
bytes[1] = flags; // BitConverter.GetBytes(_currentHeader.flags).CopyTo(headerBytes, 1);

value = sessionId;
bytes[2] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.sessionId).CopyTo(headerBytes, 2);
bytes[3] = (byte)((value >> 8) & 0xff);

value = length;
bytes[4] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.length).CopyTo(headerBytes, 4);
bytes[5] = (byte)((value >> 8) & 0xff);
bytes[6] = (byte)((value >> 16) & 0xff);
bytes[7] = (byte)((value >> 24) & 0xff);

value = sequenceNumber;
bytes[8] = (byte)(value & 0xff); // BitConverter.GetBytes(_currentHeader.sequenceNumber).CopyTo(headerBytes, 8);
bytes[9] = (byte)((value >> 8) & 0xff);
bytes[10] = (byte)((value >> 16) & 0xff);
bytes[11] = (byte)((value >> 24) & 0xff);
BinaryPrimitives.WriteUInt32LittleEndian(bytes.Slice(12), Highwater);

bytes[0] = 0x53; // BitConverter.GetBytes(_currentHeader.SMID).CopyTo(headerBytes, 0);
bytes[1] = Flags; // BitConverter.GetBytes(_currentHeader.flags).CopyTo(headerBytes, 1);

BinaryPrimitives.WriteUInt16LittleEndian(bytes.Slice(2), SessionId);

BinaryPrimitives.WriteUInt32LittleEndian(bytes.Slice(4), Length);

BinaryPrimitives.WriteUInt32LittleEndian(bytes.Slice(8), SequenceNumber);

}
}

/// <summary>
/// SMUX packet flags
/// </summary>
[Flags]
internal enum SNISMUXFlags
internal enum SNISMUXFlags : byte
{
SMUX_SYN = 1, // Begin SMUX connection
SMUX_ACK = 2, // Acknowledge SMUX packets
Expand Down Expand Up @@ -332,7 +327,15 @@ internal static bool ValidateSslServerCertificate(X509Certificate clientCert, X5
}
}

internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer timeout)
/// <summary>
/// Returns array of IP addresses for the given server name, sorted according to the given preference.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">Thrown when ipPreference is not supported</exception>
#if NET6_0_OR_GREATER
internal static async ValueTask<IPAddress[]> GetDnsIpAddresses(string serverName, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, bool async)
#else
internal static ValueTask<IPAddress[]> GetDnsIpAddresses(string serverName, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, bool async)
#endif
{
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))
{
Expand All @@ -342,20 +345,92 @@ internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer ti
args0: serverName,
args1: remainingTimeout);
using CancellationTokenSource cts = new CancellationTokenSource(remainingTimeout);

#if NET6_0_OR_GREATER
Task<IPAddress[]> task = Dns.GetHostAddressesAsync(serverName, cts.Token);

if (async)
{
return SortIpAddressesByPreference(await task.ConfigureAwait(false), ipPreference);
}
else
{
task.Wait();
return SortIpAddressesByPreference(task.Result, ipPreference);
}
#else
// using this overload to support netstandard
Task<IPAddress[]> task = Dns.GetHostAddressesAsync(serverName);
task.ConfigureAwait(false);

task.Wait(cts.Token);
return task.Result;
return new ValueTask<IPAddress[]>(SortIpAddressesByPreference(task.Result, ipPreference));
#endif
}
}

internal static IPAddress[] GetDnsIpAddresses(string serverName)
/// <summary>
/// Returns array of IP addresses for the given server name, sorted according to the given preference.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">Thrown when ipPreference is not supported</exception>
internal static async ValueTask<IPAddress[]> GetDnsIpAddresses(string serverName, SqlConnectionIPAddressPreference ipPreference, bool async)
{
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0}.", args0: serverName);
return Dns.GetHostAddresses(serverName);

return SortIpAddressesByPreference(async
? await Dns.GetHostAddressesAsync(serverName)
: Dns.GetHostAddresses(serverName),
ipPreference);
}
}

private static IPAddress[] SortIpAddressesByPreference(IPAddress[] dnsIPAddresses, SqlConnectionIPAddressPreference ipPreference)
{
AddressFamily? prioritiesFamily = ipPreference switch
{
SqlConnectionIPAddressPreference.IPv4First => AddressFamily.InterNetwork,
SqlConnectionIPAddressPreference.IPv6First => AddressFamily.InterNetworkV6,
SqlConnectionIPAddressPreference.UsePlatformDefault => null,
_ => throw ADP.NotSupportedEnumerationValue(typeof(SqlConnectionIPAddressPreference), ipPreference.ToString(), nameof(SortIpAddressesByPreference))
};

if (prioritiesFamily == null)
{
return dnsIPAddresses;
}
else
{
int resultArrayIndex = 0;
IPAddress[] ipAddresses = new IPAddress[dnsIPAddresses.Length];

// Return addresses of the preferred family first
for (int i = 0; i < dnsIPAddresses.Length; i++)
{
if (dnsIPAddresses[i].AddressFamily == prioritiesFamily)
{
ipAddresses[resultArrayIndex++] = dnsIPAddresses[i];
}
}

// Return addresses of the other family
for (int i = 0; i < dnsIPAddresses.Length; i++)
{
if (dnsIPAddresses[i].AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6
&& dnsIPAddresses[i].AddressFamily != prioritiesFamily)
{
ipAddresses[resultArrayIndex++] = dnsIPAddresses[i];
}
}

// If the DNS resolution returned records of types other than A and AAAA, the original array size will be
// too large, and must thus be resized. This is very unlikely, so we only try to do this post-hoc.
if (resultArrayIndex + 1 < ipAddresses.Length)
{
Array.Resize(ref ipAddresses, resultArrayIndex + 1);
}

return ipAddresses;
}
}

Expand Down