diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj index f575d34900..94cb627d09 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj @@ -59,6 +59,9 @@ + + + Address diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SslOverTdsStreamTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SslOverTdsStreamTest.cs new file mode 100644 index 0000000000..56cc653744 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SslOverTdsStreamTest.cs @@ -0,0 +1,361 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.IO; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Data.SqlClient.Tests +{ + public static class SslOverTdsStreamTest + { + public static TheoryData PacketSizes + { + get + { + const int EncapsulatedPacketCount = 4; + const int PassThroughPacketCount = 5; + + TheoryData data = new TheoryData(); + + data.Add(EncapsulatedPacketCount, PassThroughPacketCount, 0); + data.Add(EncapsulatedPacketCount, PassThroughPacketCount, 2); + data.Add(EncapsulatedPacketCount, PassThroughPacketCount, 128); + data.Add(EncapsulatedPacketCount, PassThroughPacketCount, 2048); + data.Add(EncapsulatedPacketCount, PassThroughPacketCount, 8192); + + return data; + } + } + + + [Theory] + [MemberData(nameof(PacketSizes))] + public static void SyncTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength) + { + byte[] input; + byte[] output; + SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output); + + byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount, + (Stream stream, int index) => + { + stream.Write(input, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE); + } + ); + + ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output, + (Stream stream, byte[] bytes, int offset, int count) => + { + return stream.Read(bytes, offset, count); + } + ); + + Validate(input, output); + } + + [Theory] + [MemberData(nameof(PacketSizes))] + public static void AsyncTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength) + { + byte[] input; + byte[] output; + SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output); + byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount, + async (Stream stream, int index) => + { + await stream.WriteAsync(input, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE); + } + ); + + ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output, + async (Stream stream, byte[] bytes, int offset, int count) => + { + return await stream.ReadAsync(bytes, offset, count); + } + ); + + Validate(input, output); + } + + [Theory] + [MemberData(nameof(PacketSizes))] + public static void SyncCoreTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength) + { + byte[] input; + byte[] output; + SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output); + + byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount, + (Stream stream, int index) => + { + stream.Write(input.AsSpan(TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE)); + } + ); + + ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output, + (Stream stream, byte[] bytes, int offset, int count) => + { + return stream.Read(bytes.AsSpan(offset, count)); + } + ); + + Validate(input, output); + } + + [Theory] + [MemberData(nameof(PacketSizes))] + public static void AsyncCoreTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength) + { + byte[] input; + byte[] output; + SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output); + + byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount, + async (Stream stream, int index) => + { + await stream.WriteAsync( + new ReadOnlyMemory(input, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE) + ); + } + ); + + ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output, + async (Stream stream, byte[] bytes, int offset, int count) => + { + return await stream.ReadAsync( + new Memory(bytes, offset, count) + ); + } + ); + + Validate(input, output); + } + + + private static void ReadPackets(byte[] buffer, int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength, byte[] output, Func> action) + { + using (LimitedMemoryStream stream = new LimitedMemoryStream(buffer, maxPacketReadLength)) + using (Stream tdsStream = CreateSslOverTdsStream(stream)) + { + int offset = 0; + byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE]; + for (int index = 0; index < encapsulatedPacketCount; index++) + { + Array.Clear(bytes, 0, bytes.Length); + int packetBytes = ReadPacket(tdsStream, action, bytes).GetAwaiter().GetResult(); + Array.Copy(bytes, 0, output, offset, packetBytes); + offset += packetBytes; + } + InvokeFinishHandshake(tdsStream); + for (int index = 0; index < passthroughPacketCount; index++) + { + Array.Clear(bytes, 0, bytes.Length); + int packetBytes = ReadPacket(tdsStream, action, bytes).GetAwaiter().GetResult(); + Array.Copy(bytes, 0, output, offset, packetBytes); + offset += packetBytes; + } + } + } + + private static void InvokeFinishHandshake(Stream stream) + { + MethodInfo method = stream.GetType().GetMethod("FinishHandshake", BindingFlags.Public | BindingFlags.Instance); + method.Invoke(stream, null); + } + + private static Stream CreateSslOverTdsStream(Stream stream) + { + Type type = typeof(SqlClientFactory).Assembly.GetType("Microsoft.Data.SqlClient.SNI.SslOverTdsStream"); + ConstructorInfo ctor = type.GetConstructor(new Type[] { typeof(Stream) }); + Stream instance = (Stream)ctor.Invoke(new object[] { stream }); + return instance; + } + + private static void ReadPackets(byte[] buffer, int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength, byte[] output, Func action) + { + using (LimitedMemoryStream stream = new LimitedMemoryStream(buffer, maxPacketReadLength)) + using (Stream tdsStream = CreateSslOverTdsStream(stream)) + { + int offset = 0; + byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE]; + for (int index = 0; index < encapsulatedPacketCount; index++) + { + Array.Clear(bytes, 0, bytes.Length); + int packetBytes = ReadPacket(tdsStream, action, bytes); + Array.Copy(bytes, 0, output, offset, packetBytes); + offset += packetBytes; + } + InvokeFinishHandshake(tdsStream); + for (int index = 0; index < passthroughPacketCount; index++) + { + Array.Clear(bytes, 0, bytes.Length); + int packetBytes = ReadPacket(tdsStream, action, bytes); + Array.Copy(bytes, 0, output, offset, packetBytes); + offset += packetBytes; + } + } + } + + private static int ReadPacket(Stream tdsStream, Func action, byte[] output) + { + int readCount; + int offset = 0; + byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE]; + do + { + readCount = action(tdsStream, bytes, offset, bytes.Length - offset); + if (readCount > 0) + { + offset += readCount; + } + } + while (readCount > 0 && offset < bytes.Length); + Array.Copy(bytes, 0, output, 0, offset); + return offset; + } + + private static async Task ReadPacket(Stream tdsStream, Func> action, byte[] output) + { + int readCount; + int offset = 0; + byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE]; + do + { + readCount = await action(tdsStream, bytes, offset, bytes.Length - offset); + if (readCount > 0) + { + offset += readCount; + } + } + while (readCount > 0 && offset < bytes.Length); + Array.Copy(bytes, 0, output, 0, offset); + return offset; + } + + private static byte[] WritePackets(int encapsulatedPacketCount, int passthroughPacketCount, Action action) + { + byte[] buffer = null; + using (LimitedMemoryStream stream = new LimitedMemoryStream()) + { + using (Stream tdsStream = CreateSslOverTdsStream(stream)) + { + for (int index = 0; index < encapsulatedPacketCount; index++) + { + action(tdsStream, index); + } + InvokeFinishHandshake(tdsStream);//tdsStream.FinishHandshake(); + for (int index = 0; index < passthroughPacketCount; index++) + { + action(tdsStream, encapsulatedPacketCount + index); + } + } + buffer = stream.ToArray(); + } + return buffer; + } + + private static void SetupArrays(int packetCount, out byte[] input, out byte[] output) + { + byte[] pattern = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13 }; + input = new byte[packetCount * TdsEnums.DEFAULT_LOGIN_PACKET_SIZE]; + output = new byte[input.Length]; + for (int index = 0; index < packetCount; index++) + { + int position = 0; + while (position < TdsEnums.DEFAULT_LOGIN_PACKET_SIZE) + { + int copyCount = Math.Min(pattern.Length, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE - position); + Array.Copy(pattern, 0, input, (TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index) + position, copyCount); + position += copyCount; + } + } + } + + private static void Validate(byte[] input, byte[] output) + { + Assert.True(input.AsSpan().SequenceEqual(output.AsSpan())); + } + + internal static class TdsEnums + { + public const int DEFAULT_LOGIN_PACKET_SIZE = 4096; + } + } + + [DebuggerStepThrough] + public sealed partial class LimitedMemoryStream : MemoryStream + { + private readonly int _readLimit; + private readonly int _delay; + + public LimitedMemoryStream(int readLimit = 0, int delay = 0) + { + _readLimit = readLimit; + _delay = delay; + } + + public LimitedMemoryStream(byte[] buffer, int readLimit = 0, int delay = 0) + : base(buffer) + { + _readLimit = readLimit; + _delay = delay; + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (_readLimit > 0) + { + return base.Read(buffer, offset, Math.Min(_readLimit, count)); + } + else + { + return base.Read(buffer, offset, count); + } + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (_delay > 0) + { + await Task.Delay(_delay, cancellationToken); + } + if (_readLimit > 0) + { + return await base.ReadAsync(buffer, offset, Math.Min(_readLimit, count), cancellationToken).ConfigureAwait(false); + } + else + { + return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + } + } + public override int Read(Span destination) + { + if (_readLimit > 0) + { + return base.Read(destination.Slice(0, Math.Min(_readLimit, destination.Length))); + } + else + { + return base.Read(destination); + } + } + + public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + { + if (_readLimit > 0) + { + return base.ReadAsync(destination.Slice(0, Math.Min(_readLimit, destination.Length)), cancellationToken); + } + else + { + return base.ReadAsync(destination, cancellationToken); + } + } + } +}