From 4d4ad15a6112cebcd1b9f445bddaaca94029f71d Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Sun, 26 Apr 2020 21:09:18 +0100 Subject: [PATCH 1/6] update SslOverTdsStream --- .../src/Microsoft.Data.SqlClient.csproj | 2 + .../SNI/SslOverTdsStream.NetCoreApp.cs | 289 ++++++++++++++++++ .../SNI/SslOverTdsStream.NetStandard.cs | 218 +++++++++++++ .../Data/SqlClient/SNI/SslOverTdsStream.cs | 244 ++------------- 4 files changed, 531 insertions(+), 222 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs create mode 100644 src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 29a0f61505..af266c927d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -227,6 +227,7 @@ + @@ -247,6 +248,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs new file mode 100644 index 0000000000..9e6ce6fef3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs @@ -0,0 +1,289 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Data.SqlClient.SNI +{ + internal sealed partial class SslOverTdsStream + { + public override int Read(byte[] buffer, int offset, int count) + { + return Read(buffer.AsSpan(offset, count)); + } + + public override void Write(byte[] buffer, int offset, int count) + { + Write(buffer.AsSpan(offset, count)); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); + } + + public override int Read(Span buffer) + { + if (_encapsulate) + { + if (_packetBytes > 0) + { + // there are queued bytes from a previous packet available + // work out how many of the remaining bytes we can consume + int wantedCount = Math.Min(buffer.Length, _packetBytes); + int readCount = _stream.Read(buffer.Slice(0, wantedCount)); + if (readCount == 0) + { + // 0 means the connection was closed, tell the caller + return 0; + } + _packetBytes -= readCount; + return readCount; + } + else + { + Span headerBytes = stackalloc byte[TdsEnums.HEADER_LEN]; + + // fetch the packet header to determine how long the packet is + int headerBytesRead = 0; + do + { + int headerBytesReadIteration = _stream.Read(headerBytes.Slice(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead))); + if (headerBytesReadIteration == 0) + { + // 0 means the connection was closed, tell the caller + return 0; + } + headerBytesRead += headerBytesReadIteration; + } while (headerBytesRead < TdsEnums.HEADER_LEN); + + // read the packet data size from the header and store it in case it is needed for a subsequent call + _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + + // read as much from the packet as the caller can accept + int packetBytesRead = _stream.Read(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes))); + _packetBytes -= packetBytesRead; + return packetBytesRead; + } + } + else + { + return _stream.Read(buffer); + } + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (_encapsulate) + { + if (_packetBytes > 0) + { + // there are queued bytes from a previous packet available + // work out how many of the remaining bytes we can consume + int wantedCount = Math.Min(buffer.Length, _packetBytes); + + int readCount; + { + ValueTask remainderReadValueTask = _stream.ReadAsync(buffer.Slice(0, wantedCount), cancellationToken); + if (remainderReadValueTask.IsCompletedSuccessfully) + { + readCount = remainderReadValueTask.Result; + } + else + { + readCount = await remainderReadValueTask.AsTask().ConfigureAwait(false); + } + } + if (readCount == 0) + { + // 0 means the connection was closed, tell the caller + return 0; + } + _packetBytes -= readCount; + return readCount; + } + else + { + byte[] headerBytes = ArrayPool.Shared.Rent(TdsEnums.HEADER_LEN); + Array.Clear(headerBytes, 0, headerBytes.Length); + + // fetch the packet header to determine how long the packet is + int headerBytesRead = 0; + do + { + int headerBytesReadIteration; + { + ValueTask headerReadValueTask = _stream.ReadAsync(headerBytes.AsMemory(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead)), cancellationToken); + if (headerReadValueTask.IsCompletedSuccessfully) + { + headerBytesReadIteration = headerReadValueTask.Result; + } + else + { + headerBytesReadIteration = await headerReadValueTask.AsTask().ConfigureAwait(false); + } + } + if (headerBytesReadIteration == 0) + { + // 0 means the connection was closed, cleanup the rented array and then tell the caller + ArrayPool.Shared.Return(headerBytes, clearArray: true); + return 0; + } + headerBytesRead += headerBytesReadIteration; + } while (headerBytesRead < TdsEnums.HEADER_LEN); + + // read the packet data size from the header and store it in case it is needed for a subsequent call + _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + + ArrayPool.Shared.Return(headerBytes, clearArray: true); + + // read as much from the packet as the caller can accept + int packetBytesRead; + { + ValueTask packetReadValueTask = _stream.ReadAsync(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes)), cancellationToken); + if (packetReadValueTask.IsCompletedSuccessfully) + { + packetBytesRead = packetReadValueTask.Result; + } + else + { + packetBytesRead = await packetReadValueTask.AsTask().ConfigureAwait(false); + } + } + _packetBytes -= packetBytesRead; + return packetBytesRead; + } + } + else + { + int read; + { + ValueTask readValueTask = _stream.ReadAsync(buffer, cancellationToken); + if (readValueTask.IsCompletedSuccessfully) + { + read = readValueTask.Result; + } + else + { + read = await readValueTask.AsTask().ConfigureAwait(false); + } + } + return read; + } + } + + public override void Write(ReadOnlySpan buffer) + { + // During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After + // negotiation, the underlying socket only sees SSL frames. + if (_encapsulate) + { + ReadOnlySpan remaining = buffer; + byte[] packetBuffer = null; + while (remaining.Length > 0) + { + int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length); + int packetLength = TdsEnums.HEADER_LEN + dataLength; + + if (packetBuffer == null) + { + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } + else if (packetBuffer.Length < packetLength) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } + + SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength); + + Span data = packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength); + remaining.Slice(0, dataLength).CopyTo(data); + + _stream.Write(packetBuffer.AsSpan(0, packetLength)); + _stream.Flush(); + + remaining = remaining.Slice(dataLength); + } + if (packetBuffer != null) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + } + } + else + { + _stream.Write(buffer); + _stream.Flush(); + } + } + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (_encapsulate) + { + ReadOnlyMemory remaining = buffer; + byte[] packetBuffer = null; + while (remaining.Length > 0) + { + int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length); + int packetLength = TdsEnums.HEADER_LEN + dataLength; + + if (packetBuffer == null) + { + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } + else if (packetBuffer.Length < packetLength) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } + + SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength); + + remaining.Span.Slice(0, dataLength).CopyTo(packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength)); + + { + ValueTask packetWriteValueTask = _stream.WriteAsync(new ReadOnlyMemory(packetBuffer, 0, packetLength), cancellationToken); + if (!packetWriteValueTask.IsCompletedSuccessfully) + { + await packetWriteValueTask.AsTask().ConfigureAwait(false); + } + } + + await _stream.FlushAsync().ConfigureAwait(false); + + + remaining = remaining.Slice(dataLength); + } + if (packetBuffer != null) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + } + } + else + { + { + ValueTask valueTask = _stream.WriteAsync(buffer, cancellationToken); + if (!valueTask.IsCompletedSuccessfully) + { + await valueTask.AsTask().ConfigureAwait(false); + } + } + Task flushTask = _stream.FlushAsync(); + if (flushTask.IsCompletedSuccessfully) + { + await flushTask.ConfigureAwait(false); + } + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs new file mode 100644 index 0000000000..0a5b3ffee3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs @@ -0,0 +1,218 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Data.SqlClient.SNI +{ + internal sealed partial class SslOverTdsStream : Stream + { + public override int Read(byte[] buffer, int offset, int count) + { + if (_encapsulate) + { + if (_packetBytes > 0) + { + // there are queued bytes from a previous packet available + // work out how many of the remaining bytes we can consume + int wantedCount = Math.Min(count, _packetBytes); + int readCount = _stream.Read(buffer, offset, wantedCount); + if (readCount == 0) + { + // 0 means the connection was closed, tell the caller + return 0; + } + _packetBytes -= readCount; + return readCount; + } + else + { + byte[] headerBytes = ArrayPool.Shared.Rent(TdsEnums.HEADER_LEN); + Array.Clear(headerBytes, 0, headerBytes.Length); + + // fetch the packet header to determine how long the packet is + int headerBytesRead = 0; + do + { + int headerBytesReadIteration = _stream.Read(headerBytes, headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead)); + if (headerBytesReadIteration == 0) + { + // 0 means the connection was closed, cleanup the rented array and then tell the caller + ArrayPool.Shared.Return(headerBytes, clearArray: true); + return 0; + } + headerBytesRead += headerBytesReadIteration; + } while (headerBytesRead < TdsEnums.HEADER_LEN); + + // read the packet data size from the header and store it in case it is needed for a subsequent call + _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + + ArrayPool.Shared.Return(headerBytes, clearArray: true); + + // read as much from the packet as the caller can accept + int packetBytesRead = _stream.Read(buffer, offset, Math.Min(count, _packetBytes)); + _packetBytes -= packetBytesRead; + return packetBytesRead; + } + } + else + { + return _stream.Read(buffer, offset, count); + } + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (_encapsulate) + { + if (_packetBytes > 0) + { + // there are queued bytes from a previous packet available + // work out how many of the remaining bytes we can consume + int wantedCount = Math.Min(count, _packetBytes); + int readCount = await _stream.ReadAsync(buffer, offset, wantedCount, cancellationToken); + if (readCount == 0) + { + // 0 means the connection was closed, tell the caller + return 0; + } + _packetBytes -= readCount; + return readCount; + } + else + { + byte[] headerBytes = ArrayPool.Shared.Rent(TdsEnums.HEADER_LEN); + Array.Clear(headerBytes, 0, headerBytes.Length); + + // fetch the packet header to determine how long the packet is + int headerBytesRead = 0; + do + { + int headerBytesReadIteration = await _stream.ReadAsync(headerBytes, headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead), cancellationToken); + if (headerBytesReadIteration == 0) + { + // 0 means the connection was closed, cleanup the rented array and then tell the caller + ArrayPool.Shared.Return(headerBytes, clearArray: true); + return 0; + } + headerBytesRead += headerBytesReadIteration; + } while (headerBytesRead < TdsEnums.HEADER_LEN); + + // read the packet data size from the header and store it in case it is needed for a subsequent call + _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + + ArrayPool.Shared.Return(headerBytes, clearArray: true); + + // read as much from the packet as the caller can accept + int packetBytesRead = await _stream.ReadAsync(buffer, offset, Math.Min(count, _packetBytes), cancellationToken); + _packetBytes -= packetBytesRead; + return packetBytesRead; + } + } + else + { + return await _stream.ReadAsync(buffer, offset, count, cancellationToken); + } + } + + public override void Write(byte[] buffer, int offset, int count) + { + // During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After + // negotiation, the underlying socket only sees SSL frames. + if (_encapsulate) + { + int remainingBytes = count; + int dataOffset = offset; + byte[] packetBuffer = null; + while (remainingBytes > 0) + { + int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remainingBytes); + int packetLength = TdsEnums.HEADER_LEN + dataLength; + remainingBytes -= dataLength; + + if (packetBuffer == null) + { + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } + else if (packetBuffer.Length < packetLength) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } + + SetupPreLoginPacketHeader(packetBuffer, dataLength, remainingBytes); + + Array.Copy(buffer, dataOffset, packetBuffer, TdsEnums.HEADER_LEN, dataLength); + + _stream.Write(packetBuffer, 0, packetLength); + _stream.Flush(); + + dataOffset += dataLength; + } + if (packetBuffer != null) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + } + } + else + { + _stream.Write(buffer, offset, count); + _stream.Flush(); + } + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (_encapsulate) + { + int remainingBytes = count; + int dataOffset = offset; + byte[] packetBuffer = null; + while (remainingBytes > 0) + { + int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remainingBytes); + int packetLength = TdsEnums.HEADER_LEN + dataLength; + remainingBytes -= dataLength; + + if (packetBuffer == null) + { + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } + else if (packetBuffer.Length < packetLength) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } + + SetupPreLoginPacketHeader(packetBuffer, dataLength, remainingBytes); + + Array.Copy(buffer, dataOffset, packetBuffer, TdsEnums.HEADER_LEN, dataLength); + + await _stream.WriteAsync(packetBuffer, 0, packetLength, cancellationToken).ConfigureAwait(false); + await _stream.FlushAsync().ConfigureAwait(false); + + dataOffset += dataLength; + } + if (packetBuffer != null) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + } + } + else + { + await _stream.WriteAsync(buffer, offset, count).ConfigureAwait(false); + Task flushTask = _stream.FlushAsync(); + if (flushTask.Status == TaskStatus.RanToCompletion) + { + await flushTask.ConfigureAwait(false); + } + } + } + + } +} diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs index 74a2e6226d..09ed5770fc 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs @@ -3,11 +3,8 @@ // See the LICENSE file in the project root for more information. using System; -using System.Buffers; using System.IO; using System.IO.Pipes; -using System.Threading; -using System.Threading.Tasks; namespace Microsoft.Data.SqlClient.SNI { @@ -16,7 +13,7 @@ namespace Microsoft.Data.SqlClient.SNI /// transported in TDS packet type 0x12. Once SSL handshake has completed, SSL /// packets are sent transparently. /// - internal sealed class SslOverTdsStream : Stream + internal sealed partial class SslOverTdsStream : Stream { private readonly Stream _stream; @@ -39,200 +36,13 @@ public SslOverTdsStream(Stream stream) /// /// Finish SSL handshake. Stop encapsulating in TDS. /// - public void FinishHandshake() - { - _encapsulate = false; - } - - /// - /// Read buffer - /// - /// Buffer - /// Offset - /// Byte count - /// Bytes read - public override int Read(byte[] buffer, int offset, int count) => - ReadInternal(buffer, offset, count, CancellationToken.None, async: false).GetAwaiter().GetResult(); - - /// - /// Write Buffer - /// - /// - /// - /// - public override void Write(byte[] buffer, int offset, int count) - => WriteInternal(buffer, offset, count, CancellationToken.None, async: false).Wait(); - - /// - /// Write Buffer Asynchronously - /// - /// - /// - /// - /// - /// - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken token) - => WriteInternal(buffer, offset, count, token, async: true); - - /// - /// Read Buffer Asynchronously - /// - /// - /// - /// - /// - /// - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken token) - => ReadInternal(buffer, offset, count, token, async: true); - - /// - /// Read Internal is called synchronously when async is false - /// - private async Task ReadInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async) - { - int readBytes = 0; - byte[] packetData = null; - byte[] readTarget = buffer; - int readOffset = offset; - if (_encapsulate) - { - packetData = ArrayPool.Shared.Rent(count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count); - readTarget = packetData; - readOffset = 0; - if (_packetBytes == 0) - { - // Account for split packets - while (readBytes < TdsEnums.HEADER_LEN) - { - var readBytesForHeader = async ? - await _stream.ReadAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, token).ConfigureAwait(false) : - _stream.Read(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes); - - if (readBytesForHeader == 0) - { - throw new EndOfStreamException("End of stream reached"); - } - - readBytes += readBytesForHeader; - } - - _packetBytes = (packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]; - _packetBytes -= TdsEnums.HEADER_LEN; - } - - if (count > _packetBytes) - { - count = _packetBytes; - } - } - - readBytes = async ? - await _stream.ReadAsync(readTarget, readOffset, count, token).ConfigureAwait(false) : - _stream.Read(readTarget, readOffset, count); - - if (_encapsulate) - { - _packetBytes -= readBytes; - } - if (packetData != null) - { - Buffer.BlockCopy(packetData, 0, buffer, offset, readBytes); - ArrayPool.Shared.Return(packetData, clearArray: true); - } - return readBytes; - } - - /// - /// The internal write method calls Sync APIs when Async flag is false - /// - private async Task WriteInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async) - { - int currentCount = 0; - int currentOffset = offset; - - while (count > 0) - { - // During the SSL negotiation phase, SSL is tunneled over TDS packet type 0x12. After - // negotiation, the underlying socket only sees SSL frames. - // - if (_encapsulate) - { - if (count > PACKET_SIZE_WITHOUT_HEADER) - { - currentCount = PACKET_SIZE_WITHOUT_HEADER; - } - else - { - currentCount = count; - } - - count -= currentCount; - - // Prepend buffer data with TDS prelogin header - int combinedLength = TdsEnums.HEADER_LEN + currentCount; - byte[] combinedBuffer = ArrayPool.Shared.Rent(combinedLength); - - // We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a - // partial packet (whether or not count != 0). - combinedBuffer[7] = 0; // touch this first for the jit bounds check - combinedBuffer[0] = PRELOGIN_PACKET_TYPE; - combinedBuffer[1] = (byte)(count > 0 ? 0 : 1); - combinedBuffer[2] = (byte)((currentCount + TdsEnums.HEADER_LEN) / 0x100); - combinedBuffer[3] = (byte)((currentCount + TdsEnums.HEADER_LEN) % 0x100); - combinedBuffer[4] = 0; - combinedBuffer[5] = 0; - combinedBuffer[6] = 0; - - Array.Copy(buffer, currentOffset, combinedBuffer, TdsEnums.HEADER_LEN, (combinedLength - TdsEnums.HEADER_LEN)); - - if (async) - { - await _stream.WriteAsync(combinedBuffer, 0, combinedLength, token).ConfigureAwait(false); - } - else - { - _stream.Write(combinedBuffer, 0, combinedLength); - } - - Array.Clear(combinedBuffer, 0, combinedLength); - ArrayPool.Shared.Return(combinedBuffer); - } - else - { - currentCount = count; - count = 0; - - if (async) - { - await _stream.WriteAsync(buffer, currentOffset, currentCount, token).ConfigureAwait(false); - } - else - { - _stream.Write(buffer, currentOffset, currentCount); - } - } - - if (async) - { - await _stream.FlushAsync().ConfigureAwait(false); - } - else - { - _stream.Flush(); - } - - currentOffset += currentCount; - } - } + public void FinishHandshake() => _encapsulate = false; /// /// Set stream length. /// /// Length - public override void SetLength(long value) - { - throw new NotSupportedException(); - } + public override void SetLength(long value) => throw new NotSupportedException(); /// /// Flush stream @@ -252,14 +62,8 @@ public override void Flush() /// public override long Position { - get - { - throw new NotSupportedException(); - } - set - { - throw new NotSupportedException(); - } + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); } /// @@ -268,44 +72,40 @@ public override long Position /// Offset /// Origin /// Position - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException(); - } + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); /// /// Check if stream can be read from /// - public override bool CanRead - { - get { return _stream.CanRead; } - } + public override bool CanRead => _stream.CanRead; /// /// Check if stream can be written to /// - public override bool CanWrite - { - get { return _stream.CanWrite; } - } + public override bool CanWrite => _stream.CanWrite; /// /// Check if stream can be seeked /// - public override bool CanSeek - { - get { return false; } // Seek not supported - } + public override bool CanSeek => false; /// /// Get stream length /// - public override long Length + public override long Length => throw new NotSupportedException(); + + private static void SetupPreLoginPacketHeader(byte[] buffer, int dataLength, int remainingLength) { - get - { - throw new NotSupportedException(); - } + // We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a + // partial packet (whether or not count != 0). + buffer[7] = 0; // touch this first for the jit bounds check + buffer[0] = PRELOGIN_PACKET_TYPE; + buffer[1] = (byte)(remainingLength > 0 ? 0 : 1); + buffer[2] = (byte)((dataLength + TdsEnums.HEADER_LEN) / 0x100); + buffer[3] = (byte)((dataLength + TdsEnums.HEADER_LEN) % 0x100); + buffer[4] = 0; + buffer[5] = 0; + buffer[6] = 0; } } } From 8f501da5fab0c155c2a15c186b6edf47993ab569 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Mon, 27 Apr 2020 11:25:39 +0100 Subject: [PATCH 2/6] address feedback --- .../Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs index 9e6ce6fef3..7a8522d235 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs @@ -114,7 +114,6 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation else { byte[] headerBytes = ArrayPool.Shared.Rent(TdsEnums.HEADER_LEN); - Array.Clear(headerBytes, 0, headerBytes.Length); // fetch the packet header to determine how long the packet is int headerBytesRead = 0; From 60052709b08bc63825e959a47219f58bc999512d Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Sat, 11 Jul 2020 01:03:23 +0100 Subject: [PATCH 3/6] await valuetasks directly, use finally blocks to ensure return of rental buffers and change to early exit code flow --- .../SNI/SslOverTdsStream.NetCoreApp.cs | 271 +++++++++--------- 1 file changed, 138 insertions(+), 133 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs index 7a8522d235..12f1d9150e 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs @@ -33,161 +33,165 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override int Read(Span buffer) { - if (_encapsulate) + if (!_encapsulate) { - if (_packetBytes > 0) + return _stream.Read(buffer); + } + + if (_packetBytes > 0) + { + // there are queued bytes from a previous packet available + // work out how many of the remaining bytes we can consume + int wantedCount = Math.Min(buffer.Length, _packetBytes); + int readCount = _stream.Read(buffer.Slice(0, wantedCount)); + if (readCount == 0) { - // there are queued bytes from a previous packet available - // work out how many of the remaining bytes we can consume - int wantedCount = Math.Min(buffer.Length, _packetBytes); - int readCount = _stream.Read(buffer.Slice(0, wantedCount)); - if (readCount == 0) + // 0 means the connection was closed, tell the caller + return 0; + } + _packetBytes -= readCount; + return readCount; + } + else + { + Span headerBytes = stackalloc byte[TdsEnums.HEADER_LEN]; + + // fetch the packet header to determine how long the packet is + int headerBytesRead = 0; + do + { + int headerBytesReadIteration = _stream.Read(headerBytes.Slice(headerBytesRead, TdsEnums.HEADER_LEN - headerBytesRead)); + if (headerBytesReadIteration == 0) { // 0 means the connection was closed, tell the caller return 0; } - _packetBytes -= readCount; - return readCount; - } - else - { - Span headerBytes = stackalloc byte[TdsEnums.HEADER_LEN]; + headerBytesRead += headerBytesReadIteration; + } while (headerBytesRead < TdsEnums.HEADER_LEN); - // fetch the packet header to determine how long the packet is - int headerBytesRead = 0; - do - { - int headerBytesReadIteration = _stream.Read(headerBytes.Slice(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead))); - if (headerBytesReadIteration == 0) - { - // 0 means the connection was closed, tell the caller - return 0; - } - headerBytesRead += headerBytesReadIteration; - } while (headerBytesRead < TdsEnums.HEADER_LEN); + // read the packet data size from the header and store it in case it is needed for a subsequent call + _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; - // read the packet data size from the header and store it in case it is needed for a subsequent call - _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; - - // read as much from the packet as the caller can accept - int packetBytesRead = _stream.Read(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes))); - _packetBytes -= packetBytesRead; - return packetBytesRead; - } - } - else - { - return _stream.Read(buffer); + // read as much from the packet as the caller can accept + int packetBytesRead = _stream.Read(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes))); + _packetBytes -= packetBytesRead; + return packetBytesRead; } + } public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - if (_encapsulate) + if (!_encapsulate) { - if (_packetBytes > 0) + int read; { - // there are queued bytes from a previous packet available - // work out how many of the remaining bytes we can consume - int wantedCount = Math.Min(buffer.Length, _packetBytes); + ValueTask readValueTask = _stream.ReadAsync(buffer, cancellationToken); + if (readValueTask.IsCompletedSuccessfully) + { + read = readValueTask.Result; + } + else + { + read = await readValueTask.ConfigureAwait(false); + } + } + return read; + } + + if (_packetBytes > 0) + { + // there are queued bytes from a previous packet available + // work out how many of the remaining bytes we can consume + int wantedCount = Math.Min(buffer.Length, _packetBytes); - int readCount; + int readCount; + { + ValueTask remainderReadValueTask = _stream.ReadAsync(buffer.Slice(0, wantedCount), cancellationToken); + if (remainderReadValueTask.IsCompletedSuccessfully) { - ValueTask remainderReadValueTask = _stream.ReadAsync(buffer.Slice(0, wantedCount), cancellationToken); - if (remainderReadValueTask.IsCompletedSuccessfully) - { - readCount = remainderReadValueTask.Result; - } - else - { - readCount = await remainderReadValueTask.AsTask().ConfigureAwait(false); - } + readCount = remainderReadValueTask.Result; } - if (readCount == 0) + else { - // 0 means the connection was closed, tell the caller - return 0; + readCount = await remainderReadValueTask.ConfigureAwait(false); } - _packetBytes -= readCount; - return readCount; } - else + if (readCount == 0) { - byte[] headerBytes = ArrayPool.Shared.Rent(TdsEnums.HEADER_LEN); + // 0 means the connection was closed, tell the caller + return 0; + } + _packetBytes -= readCount; + return readCount; + } + else + { + byte[] headerBytes = ArrayPool.Shared.Rent(TdsEnums.HEADER_LEN); - // fetch the packet header to determine how long the packet is - int headerBytesRead = 0; - do + // fetch the packet header to determine how long the packet is + int headerBytesRead = 0; + do + { + int headerBytesReadIteration; { - int headerBytesReadIteration; + ValueTask headerReadValueTask = _stream.ReadAsync(headerBytes.AsMemory(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead)), cancellationToken); + if (headerReadValueTask.IsCompletedSuccessfully) { - ValueTask headerReadValueTask = _stream.ReadAsync(headerBytes.AsMemory(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead)), cancellationToken); - if (headerReadValueTask.IsCompletedSuccessfully) - { - headerBytesReadIteration = headerReadValueTask.Result; - } - else - { - headerBytesReadIteration = await headerReadValueTask.AsTask().ConfigureAwait(false); - } + headerBytesReadIteration = headerReadValueTask.Result; } - if (headerBytesReadIteration == 0) + else { - // 0 means the connection was closed, cleanup the rented array and then tell the caller - ArrayPool.Shared.Return(headerBytes, clearArray: true); - return 0; + headerBytesReadIteration = await headerReadValueTask.ConfigureAwait(false); } - headerBytesRead += headerBytesReadIteration; - } while (headerBytesRead < TdsEnums.HEADER_LEN); + } + if (headerBytesReadIteration == 0) + { + // 0 means the connection was closed, cleanup the rented array and then tell the caller + ArrayPool.Shared.Return(headerBytes, clearArray: true); + return 0; + } + headerBytesRead += headerBytesReadIteration; + } while (headerBytesRead < TdsEnums.HEADER_LEN); - // read the packet data size from the header and store it in case it is needed for a subsequent call - _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + // read the packet data size from the header and store it in case it is needed for a subsequent call + _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; - ArrayPool.Shared.Return(headerBytes, clearArray: true); + ArrayPool.Shared.Return(headerBytes, clearArray: true); - // read as much from the packet as the caller can accept - int packetBytesRead; - { - ValueTask packetReadValueTask = _stream.ReadAsync(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes)), cancellationToken); - if (packetReadValueTask.IsCompletedSuccessfully) - { - packetBytesRead = packetReadValueTask.Result; - } - else - { - packetBytesRead = await packetReadValueTask.AsTask().ConfigureAwait(false); - } - } - _packetBytes -= packetBytesRead; - return packetBytesRead; - } - } - else - { - int read; + // read as much from the packet as the caller can accept + int packetBytesRead; { - ValueTask readValueTask = _stream.ReadAsync(buffer, cancellationToken); - if (readValueTask.IsCompletedSuccessfully) + ValueTask packetReadValueTask = _stream.ReadAsync(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes)), cancellationToken); + if (packetReadValueTask.IsCompletedSuccessfully) { - read = readValueTask.Result; + packetBytesRead = packetReadValueTask.Result; } else { - read = await readValueTask.AsTask().ConfigureAwait(false); + packetBytesRead = await packetReadValueTask.ConfigureAwait(false); } } - return read; + _packetBytes -= packetBytesRead; + return packetBytesRead; } + } public override void Write(ReadOnlySpan buffer) { // During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After // negotiation, the underlying socket only sees SSL frames. - if (_encapsulate) + if (!_encapsulate) + { + _stream.Write(buffer); + _stream.Flush(); + } + + ReadOnlySpan remaining = buffer; + byte[] packetBuffer = null; + try { - ReadOnlySpan remaining = buffer; - byte[] packetBuffer = null; while (remaining.Length > 0) { int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length); @@ -213,24 +217,37 @@ public override void Write(ReadOnlySpan buffer) remaining = remaining.Slice(dataLength); } + } + finally + { if (packetBuffer != null) { ArrayPool.Shared.Return(packetBuffer, clearArray: true); } } - else - { - _stream.Write(buffer); - _stream.Flush(); - } } public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { - if (_encapsulate) + if (!_encapsulate) + { + { + ValueTask valueTask = _stream.WriteAsync(buffer, cancellationToken); + if (!valueTask.IsCompletedSuccessfully) + { + await valueTask.ConfigureAwait(false); + } + } + Task flushTask = _stream.FlushAsync(); + if (flushTask.IsCompletedSuccessfully) + { + await flushTask.ConfigureAwait(false); + } + } + ReadOnlyMemory remaining = buffer; + byte[] packetBuffer = null; + try { - ReadOnlyMemory remaining = buffer; - byte[] packetBuffer = null; while (remaining.Length > 0) { int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length); @@ -254,7 +271,7 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella ValueTask packetWriteValueTask = _stream.WriteAsync(new ReadOnlyMemory(packetBuffer, 0, packetLength), cancellationToken); if (!packetWriteValueTask.IsCompletedSuccessfully) { - await packetWriteValueTask.AsTask().ConfigureAwait(false); + await packetWriteValueTask.ConfigureAwait(false); } } @@ -263,24 +280,12 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella remaining = remaining.Slice(dataLength); } - if (packetBuffer != null) - { - ArrayPool.Shared.Return(packetBuffer, clearArray: true); - } } - else + finally { + if (packetBuffer != null) { - ValueTask valueTask = _stream.WriteAsync(buffer, cancellationToken); - if (!valueTask.IsCompletedSuccessfully) - { - await valueTask.AsTask().ConfigureAwait(false); - } - } - Task flushTask = _stream.FlushAsync(); - if (flushTask.IsCompletedSuccessfully) - { - await flushTask.ConfigureAwait(false); + ArrayPool.Shared.Return(packetBuffer, clearArray: true); } } } From e8977c6e06e0d797fd96889ffb86ee7fe45d3244 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Thu, 16 Jul 2020 11:19:42 +0100 Subject: [PATCH 4/6] make sure to return from early encapsulate negative check blocks --- .../Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs index 12f1d9150e..dc0e304931 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs @@ -186,6 +186,7 @@ public override void Write(ReadOnlySpan buffer) { _stream.Write(buffer); _stream.Flush(); + return; } ReadOnlySpan remaining = buffer; @@ -243,7 +244,9 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella { await flushTask.ConfigureAwait(false); } + return; } + ReadOnlyMemory remaining = buffer; byte[] packetBuffer = null; try From ef3267a5eff6efc2c9242363d514f8742170c9c2 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Thu, 16 Jul 2020 20:05:18 +0100 Subject: [PATCH 5/6] add sni trace scope class and add trace events to encapsulation mode calls in SslOverTdsStream --- .../SNI/SslOverTdsStream.NetCoreApp.cs | 295 +++++++++--------- .../SNI/SslOverTdsStream.NetStandard.cs | 60 ++-- .../Data/SqlClient/SNI/SslOverTdsStream.cs | 6 +- .../Data/SqlClient/SqlClientEventSource.cs | 21 ++ 4 files changed, 211 insertions(+), 171 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs index dc0e304931..cb634cb6af 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs @@ -38,46 +38,48 @@ public override int Read(Span buffer) return _stream.Read(buffer); } - if (_packetBytes > 0) + using (SNIEventScope.Create(" reading encapsulated bytes")) { - // there are queued bytes from a previous packet available - // work out how many of the remaining bytes we can consume - int wantedCount = Math.Min(buffer.Length, _packetBytes); - int readCount = _stream.Read(buffer.Slice(0, wantedCount)); - if (readCount == 0) + if (_packetBytes > 0) { - // 0 means the connection was closed, tell the caller - return 0; - } - _packetBytes -= readCount; - return readCount; - } - else - { - Span headerBytes = stackalloc byte[TdsEnums.HEADER_LEN]; - - // fetch the packet header to determine how long the packet is - int headerBytesRead = 0; - do - { - int headerBytesReadIteration = _stream.Read(headerBytes.Slice(headerBytesRead, TdsEnums.HEADER_LEN - headerBytesRead)); - if (headerBytesReadIteration == 0) + // there are queued bytes from a previous packet available + // work out how many of the remaining bytes we can consume + int wantedCount = Math.Min(buffer.Length, _packetBytes); + int readCount = _stream.Read(buffer.Slice(0, wantedCount)); + if (readCount == 0) { // 0 means the connection was closed, tell the caller return 0; } - headerBytesRead += headerBytesReadIteration; - } while (headerBytesRead < TdsEnums.HEADER_LEN); + _packetBytes -= readCount; + return readCount; + } + else + { + Span headerBytes = stackalloc byte[TdsEnums.HEADER_LEN]; + + // fetch the packet header to determine how long the packet is + int headerBytesRead = 0; + do + { + int headerBytesReadIteration = _stream.Read(headerBytes.Slice(headerBytesRead, TdsEnums.HEADER_LEN - headerBytesRead)); + if (headerBytesReadIteration == 0) + { + // 0 means the connection was closed, tell the caller + return 0; + } + headerBytesRead += headerBytesReadIteration; + } while (headerBytesRead < TdsEnums.HEADER_LEN); - // read the packet data size from the header and store it in case it is needed for a subsequent call - _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + // read the packet data size from the header and store it in case it is needed for a subsequent call + _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; - // read as much from the packet as the caller can accept - int packetBytesRead = _stream.Read(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes))); - _packetBytes -= packetBytesRead; - return packetBytesRead; + // read as much from the packet as the caller can accept + int packetBytesRead = _stream.Read(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes))); + _packetBytes -= packetBytesRead; + return packetBytesRead; + } } - } public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) @@ -98,84 +100,85 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation } return read; } - - if (_packetBytes > 0) + using (SNIEventScope.Create(" reading encapsulated bytes")) { - // there are queued bytes from a previous packet available - // work out how many of the remaining bytes we can consume - int wantedCount = Math.Min(buffer.Length, _packetBytes); - - int readCount; + if (_packetBytes > 0) { - ValueTask remainderReadValueTask = _stream.ReadAsync(buffer.Slice(0, wantedCount), cancellationToken); - if (remainderReadValueTask.IsCompletedSuccessfully) + // there are queued bytes from a previous packet available + // work out how many of the remaining bytes we can consume + int wantedCount = Math.Min(buffer.Length, _packetBytes); + + int readCount; { - readCount = remainderReadValueTask.Result; + ValueTask remainderReadValueTask = _stream.ReadAsync(buffer.Slice(0, wantedCount), cancellationToken); + if (remainderReadValueTask.IsCompletedSuccessfully) + { + readCount = remainderReadValueTask.Result; + } + else + { + readCount = await remainderReadValueTask.ConfigureAwait(false); + } } - else + if (readCount == 0) { - readCount = await remainderReadValueTask.ConfigureAwait(false); + // 0 means the connection was closed, tell the caller + return 0; } + _packetBytes -= readCount; + return readCount; } - if (readCount == 0) + else { - // 0 means the connection was closed, tell the caller - return 0; - } - _packetBytes -= readCount; - return readCount; - } - else - { - byte[] headerBytes = ArrayPool.Shared.Rent(TdsEnums.HEADER_LEN); + byte[] headerBytes = ArrayPool.Shared.Rent(TdsEnums.HEADER_LEN); - // fetch the packet header to determine how long the packet is - int headerBytesRead = 0; - do - { - int headerBytesReadIteration; + // fetch the packet header to determine how long the packet is + int headerBytesRead = 0; + do { - ValueTask headerReadValueTask = _stream.ReadAsync(headerBytes.AsMemory(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead)), cancellationToken); - if (headerReadValueTask.IsCompletedSuccessfully) + int headerBytesReadIteration; { - headerBytesReadIteration = headerReadValueTask.Result; + ValueTask headerReadValueTask = _stream.ReadAsync(headerBytes.AsMemory(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead)), cancellationToken); + if (headerReadValueTask.IsCompletedSuccessfully) + { + headerBytesReadIteration = headerReadValueTask.Result; + } + else + { + headerBytesReadIteration = await headerReadValueTask.ConfigureAwait(false); + } } - else + if (headerBytesReadIteration == 0) { - headerBytesReadIteration = await headerReadValueTask.ConfigureAwait(false); + // 0 means the connection was closed, cleanup the rented array and then tell the caller + ArrayPool.Shared.Return(headerBytes, clearArray: true); + return 0; } - } - if (headerBytesReadIteration == 0) - { - // 0 means the connection was closed, cleanup the rented array and then tell the caller - ArrayPool.Shared.Return(headerBytes, clearArray: true); - return 0; - } - headerBytesRead += headerBytesReadIteration; - } while (headerBytesRead < TdsEnums.HEADER_LEN); + headerBytesRead += headerBytesReadIteration; + } while (headerBytesRead < TdsEnums.HEADER_LEN); - // read the packet data size from the header and store it in case it is needed for a subsequent call - _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + // read the packet data size from the header and store it in case it is needed for a subsequent call + _packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; - ArrayPool.Shared.Return(headerBytes, clearArray: true); + ArrayPool.Shared.Return(headerBytes, clearArray: true); - // read as much from the packet as the caller can accept - int packetBytesRead; - { - ValueTask packetReadValueTask = _stream.ReadAsync(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes)), cancellationToken); - if (packetReadValueTask.IsCompletedSuccessfully) + // read as much from the packet as the caller can accept + int packetBytesRead; { - packetBytesRead = packetReadValueTask.Result; - } - else - { - packetBytesRead = await packetReadValueTask.ConfigureAwait(false); + ValueTask packetReadValueTask = _stream.ReadAsync(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes)), cancellationToken); + if (packetReadValueTask.IsCompletedSuccessfully) + { + packetBytesRead = packetReadValueTask.Result; + } + else + { + packetBytesRead = await packetReadValueTask.ConfigureAwait(false); + } } + _packetBytes -= packetBytesRead; + return packetBytesRead; } - _packetBytes -= packetBytesRead; - return packetBytesRead; } - } public override void Write(ReadOnlySpan buffer) @@ -189,41 +192,44 @@ public override void Write(ReadOnlySpan buffer) return; } - ReadOnlySpan remaining = buffer; - byte[] packetBuffer = null; - try + using (SNIEventScope.Create(" writing encapsulated bytes")) { - while (remaining.Length > 0) + ReadOnlySpan remaining = buffer; + byte[] packetBuffer = null; + try { - int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length); - int packetLength = TdsEnums.HEADER_LEN + dataLength; - - if (packetBuffer == null) + while (remaining.Length > 0) { - packetBuffer = ArrayPool.Shared.Rent(packetLength); - } - else if (packetBuffer.Length < packetLength) - { - ArrayPool.Shared.Return(packetBuffer, clearArray: true); - packetBuffer = ArrayPool.Shared.Rent(packetLength); - } + int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length); + int packetLength = TdsEnums.HEADER_LEN + dataLength; - SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength); + if (packetBuffer == null) + { + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } + else if (packetBuffer.Length < packetLength) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } - Span data = packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength); - remaining.Slice(0, dataLength).CopyTo(data); + SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength); - _stream.Write(packetBuffer.AsSpan(0, packetLength)); - _stream.Flush(); + Span data = packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength); + remaining.Slice(0, dataLength).CopyTo(data); - remaining = remaining.Slice(dataLength); + _stream.Write(packetBuffer.AsSpan(0, packetLength)); + _stream.Flush(); + + remaining = remaining.Slice(dataLength); + } } - } - finally - { - if (packetBuffer != null) + finally { - ArrayPool.Shared.Return(packetBuffer, clearArray: true); + if (packetBuffer != null) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + } } } } @@ -247,48 +253,51 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella return; } - ReadOnlyMemory remaining = buffer; - byte[] packetBuffer = null; - try + using (SNIEventScope.Create(" writing encapsulated bytes")) { - while (remaining.Length > 0) + ReadOnlyMemory remaining = buffer; + byte[] packetBuffer = null; + try { - int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length); - int packetLength = TdsEnums.HEADER_LEN + dataLength; - - if (packetBuffer == null) + while (remaining.Length > 0) { - packetBuffer = ArrayPool.Shared.Rent(packetLength); - } - else if (packetBuffer.Length < packetLength) - { - ArrayPool.Shared.Return(packetBuffer, clearArray: true); - packetBuffer = ArrayPool.Shared.Rent(packetLength); - } + int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length); + int packetLength = TdsEnums.HEADER_LEN + dataLength; - SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength); + if (packetBuffer == null) + { + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } + else if (packetBuffer.Length < packetLength) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + packetBuffer = ArrayPool.Shared.Rent(packetLength); + } - remaining.Span.Slice(0, dataLength).CopyTo(packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength)); + SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength); + + remaining.Span.Slice(0, dataLength).CopyTo(packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength)); - { - ValueTask packetWriteValueTask = _stream.WriteAsync(new ReadOnlyMemory(packetBuffer, 0, packetLength), cancellationToken); - if (!packetWriteValueTask.IsCompletedSuccessfully) { - await packetWriteValueTask.ConfigureAwait(false); + ValueTask packetWriteValueTask = _stream.WriteAsync(new ReadOnlyMemory(packetBuffer, 0, packetLength), cancellationToken); + if (!packetWriteValueTask.IsCompletedSuccessfully) + { + await packetWriteValueTask.ConfigureAwait(false); + } } - } - await _stream.FlushAsync().ConfigureAwait(false); + await _stream.FlushAsync().ConfigureAwait(false); - remaining = remaining.Slice(dataLength); + remaining = remaining.Slice(dataLength); + } } - } - finally - { - if (packetBuffer != null) + finally { - ArrayPool.Shared.Return(packetBuffer, clearArray: true); + if (packetBuffer != null) + { + ArrayPool.Shared.Return(packetBuffer, clearArray: true); + } } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs index 0a5b3ffee3..7798b25d06 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs @@ -14,7 +14,12 @@ internal sealed partial class SslOverTdsStream : Stream { public override int Read(byte[] buffer, int offset, int count) { - if (_encapsulate) + if (!_encapsulate) + { + return _stream.Read(buffer, offset, count); + } + + using (SNIEventScope.Create(" reading encapsulated bytes")) { if (_packetBytes > 0) { @@ -60,15 +65,17 @@ public override int Read(byte[] buffer, int offset, int count) return packetBytesRead; } } - else - { - return _stream.Read(buffer, offset, count); - } + } public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - if (_encapsulate) + if (!_encapsulate) + { + return await _stream.ReadAsync(buffer, offset, count, cancellationToken); + } + + using (SNIEventScope.Create(" reading encapsulated bytes")) { if (_packetBytes > 0) { @@ -114,17 +121,20 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, return packetBytesRead; } } - else - { - return await _stream.ReadAsync(buffer, offset, count, cancellationToken); - } } public override void Write(byte[] buffer, int offset, int count) { // During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After // negotiation, the underlying socket only sees SSL frames. - if (_encapsulate) + if (!_encapsulate) + { + _stream.Write(buffer, offset, count); + _stream.Flush(); + return; + } + + using (SNIEventScope.Create(" writing encapsulated bytes")) { int remainingBytes = count; int dataOffset = offset; @@ -159,16 +169,22 @@ public override void Write(byte[] buffer, int offset, int count) ArrayPool.Shared.Return(packetBuffer, clearArray: true); } } - else - { - _stream.Write(buffer, offset, count); - _stream.Flush(); - } } public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - if (_encapsulate) + if (!_encapsulate) + { + await _stream.WriteAsync(buffer, offset, count).ConfigureAwait(false); + Task flushTask = _stream.FlushAsync(); + if (flushTask.Status == TaskStatus.RanToCompletion) + { + await flushTask.ConfigureAwait(false); + } + return; + } + + using (SNIEventScope.Create(" writing encapsulated bytes")) { int remainingBytes = count; int dataOffset = offset; @@ -203,16 +219,6 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc ArrayPool.Shared.Return(packetBuffer, clearArray: true); } } - else - { - await _stream.WriteAsync(buffer, offset, count).ConfigureAwait(false); - Task flushTask = _stream.FlushAsync(); - if (flushTask.Status == TaskStatus.RanToCompletion) - { - await flushTask.ConfigureAwait(false); - } - } } - } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs index 09ed5770fc..905349e819 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs @@ -36,7 +36,11 @@ public SslOverTdsStream(Stream stream) /// /// Finish SSL handshake. Stop encapsulating in TDS. /// - public void FinishHandshake() => _encapsulate = false; + public void FinishHandshake() + { + _encapsulate = false; + SqlClientEventSource.Log.SNITraceEvent(" switched from encapsulation to passthrough mode"); + } /// /// Set stream length. diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs index 93e544a713..9ce089bac8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Diagnostics.Tracing; using System.Threading; @@ -946,4 +947,24 @@ internal void SNIScopeLeave(long scopeId) } #endregion } + + internal readonly struct SNIEventScope : IDisposable + { + private readonly long _scopeID; + + public SNIEventScope(long scopeID) + { + _scopeID = scopeID; + } + + public void Dispose() + { + SqlClientEventSource.Log.SNIScopeLeaveEvent(_scopeID); + } + + public static SNIEventScope Create(string message) + { + return new SNIEventScope(SqlClientEventSource.Log.SNIScopeEnterEvent(message)); + } + } } From c8940ae1c6310af9dd0eadd4acb4c3800d748067 Mon Sep 17 00:00:00 2001 From: Cheena Malhotra Date: Mon, 19 Oct 2020 14:05:03 -0700 Subject: [PATCH 6/6] Adjustments to recent changes. --- .../netcore/src/Microsoft.Data.SqlClient.csproj | 2 +- .../src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs | 2 +- .../src/Microsoft/Data/SqlClient/SqlClientEventSource.cs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 1766cefe46..5357354fb4 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -264,13 +264,13 @@ - + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs index 905349e819..58384dfd58 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs @@ -39,7 +39,7 @@ public SslOverTdsStream(Stream stream) public void FinishHandshake() { _encapsulate = false; - SqlClientEventSource.Log.SNITraceEvent(" switched from encapsulation to passthrough mode"); + SqlClientEventSource.Log.TrySNITraceEvent(" switched from encapsulation to passthrough mode"); } /// diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs index a0a17ed982..d7f0decbf5 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs @@ -1024,12 +1024,12 @@ public SNIEventScope(long scopeID) public void Dispose() { - SqlClientEventSource.Log.SNIScopeLeaveEvent(_scopeID); + SqlClientEventSource.Log.SNIScopeLeave(_scopeID); } public static SNIEventScope Create(string message) { - return new SNIEventScope(SqlClientEventSource.Log.SNIScopeEnterEvent(message)); + return new SNIEventScope(SqlClientEventSource.Log.SNIScopeEnter(message)); } } }