diff --git a/csharp/src/Google.Protobuf.Test/WellKnownTypes/WrappersTest.cs b/csharp/src/Google.Protobuf.Test/WellKnownTypes/WrappersTest.cs
index 8a9c3d0c685f..4a425f79aa84 100644
--- a/csharp/src/Google.Protobuf.Test/WellKnownTypes/WrappersTest.cs
+++ b/csharp/src/Google.Protobuf.Test/WellKnownTypes/WrappersTest.cs
@@ -386,7 +386,7 @@ public void MergingMessageWithZero(int? originValue, int? mergingValue, int? exp
         }
 
         [Test]
-        public void UnknownFieldInWrapper()
+        public void UnknownFieldInWrapperInt32FastPath()
         {
             var stream = new MemoryStream();
             var output = new CodedOutputStream(stream);
@@ -395,19 +395,96 @@ public void UnknownFieldInWrapper()
             var valueTag = WireFormat.MakeTag(Int32Value.ValueFieldNumber, WireFormat.WireType.Varint);
 
             output.WriteTag(wrapperTag);
-            output.WriteLength(4); // unknownTag + value 5 + valueType + value 6, each 1 byte
+            // Wrapper message is long enough - 11 bytes or more - to enter the wrapper fast-path, which then falls back on the unknown tag.
+            output.WriteLength(13); // unknownTag + value 8 + valueTag, each 1 byte, + value -1, 10 bytes (negative int32 is sign-extended)
             output.WriteTag(unknownTag);
             output.WriteInt32((int) valueTag); // Sneakily "pretend" it's a tag when it's really a value
             output.WriteTag(valueTag);
+            output.WriteInt32(-1);
+
+            output.Flush();
+            Assert.AreEqual(15, stream.Length); // tag (1 byte) + length (1 byte) + message (13 bytes)
+            stream.Position = 0;
+
+            var message = TestWellKnownTypes.Parser.ParseFrom(stream);
+            Assert.AreEqual(-1, message.Int32Field);
+        }
+
+        [Test]
+        public void UnknownFieldInWrapperInt32SlowPath()
+        {
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            var wrapperTag = WireFormat.MakeTag(TestWellKnownTypes.Int32FieldFieldNumber, WireFormat.WireType.LengthDelimited);
+            var unknownTag = WireFormat.MakeTag(15, WireFormat.WireType.Varint);
+            var valueTag = WireFormat.MakeTag(Int32Value.ValueFieldNumber, WireFormat.WireType.Varint);
+
+            output.WriteTag(wrapperTag);
+            // Wrapper message is too short to be used on the wrapper fast-path.
+            output.WriteLength(4); // unknownTag + value 8 + valueTag + value 6, each 1 byte
+            output.WriteTag(unknownTag);
+            output.WriteInt32((int)valueTag); // Sneakily "pretend" it's a tag when it's really a value
+            output.WriteTag(valueTag);
             output.WriteInt32(6);
 
             output.Flush();
+            Assert.Less(stream.Length, 13); // tag (1 byte) + length (1 byte) + message; the fast-path needs 13 bytes in total
             stream.Position = 0;
 
             var message = TestWellKnownTypes.Parser.ParseFrom(stream);
             Assert.AreEqual(6, message.Int32Field);
         }
 
+        [Test]
+        public void UnknownFieldInWrapperInt64FastPath()
+        {
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            var wrapperTag = WireFormat.MakeTag(TestWellKnownTypes.Int64FieldFieldNumber, WireFormat.WireType.LengthDelimited);
+            var unknownTag = WireFormat.MakeTag(15, WireFormat.WireType.Varint);
+            var valueTag = WireFormat.MakeTag(Int64Value.ValueFieldNumber, WireFormat.WireType.Varint);
+
+            output.WriteTag(wrapperTag);
+            // Wrapper message is just long enough - 11 bytes - to use the wrapper fast-path.
+            output.WriteLength(11); // unknownTag + value 8 + valueTag, each 1 byte, + value 0xfffffffffffff, 8 bytes
+            output.WriteTag(unknownTag);
+            output.WriteInt64((int)valueTag); // Sneakily "pretend" it's a tag when it's really a value
+            output.WriteTag(valueTag);
+            output.WriteInt64(0xfffffffffffffL);
+
+            output.Flush();
+            Assert.AreEqual(13, stream.Length); // tag (1 byte) + length (1 byte) + message (11 bytes)
+            stream.Position = 0;
+
+            var message = TestWellKnownTypes.Parser.ParseFrom(stream);
+            Assert.AreEqual(0xfffffffffffffL, message.Int64Field);
+        }
+
+        [Test]
+        public void UnknownFieldInWrapperInt64SlowPath()
+        {
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            var wrapperTag = WireFormat.MakeTag(TestWellKnownTypes.Int64FieldFieldNumber, WireFormat.WireType.LengthDelimited);
+            var unknownTag = WireFormat.MakeTag(15, WireFormat.WireType.Varint);
+            var valueTag = WireFormat.MakeTag(Int64Value.ValueFieldNumber, WireFormat.WireType.Varint);
+
+            output.WriteTag(wrapperTag);
+            // Wrapper message is too short to be used on the wrapper fast-path.
+            output.WriteLength(4); // unknownTag + value 8 + valueTag + value 6, each 1 byte
+            output.WriteTag(unknownTag);
+            output.WriteInt64((int)valueTag); // Sneakily "pretend" it's a tag when it's really a value
+            output.WriteTag(valueTag);
+            output.WriteInt64(6);
+
+            output.Flush();
+            Assert.Less(stream.Length, 13); // tag (1 byte) + length (1 byte) + message; the fast-path needs 13 bytes in total
+            stream.Position = 0;
+
+            var message = TestWellKnownTypes.Parser.ParseFrom(stream);
+            Assert.AreEqual(6L, message.Int64Field);
+        }
+
         [Test]
         public void ClearWithReflection()
         {
diff --git a/csharp/src/Google.Protobuf/CodedInputStream.cs b/csharp/src/Google.Protobuf/CodedInputStream.cs
index 4d13f88ab95d..44934f341b25 100644
--- a/csharp/src/Google.Protobuf/CodedInputStream.cs
+++ b/csharp/src/Google.Protobuf/CodedInputStream.cs
@@ -481,7 +481,33 @@ internal void SkipGroup(uint startGroupTag)
         /// </summary>
         public double ReadDouble()
         {
-            return BitConverter.Int64BitsToDouble((long) ReadRawLittleEndian64());
+            if (bufferPos + 8 <= bufferSize)
+            {
+                if (BitConverter.IsLittleEndian)
+                {
+                    var result = BitConverter.ToDouble(buffer, bufferPos);
+                    bufferPos += 8;
+                    return result;
+                }
+                else
+                {
+                    var bytes = new byte[8];
+                    bytes[0] = buffer[bufferPos + 7];
+                    bytes[1] = buffer[bufferPos + 6];
+                    bytes[2] = buffer[bufferPos + 5];
+                    bytes[3] = buffer[bufferPos + 4];
+                    bytes[4] = buffer[bufferPos + 3];
+                    bytes[5] = buffer[bufferPos + 2];
+                    bytes[6] = buffer[bufferPos + 1];
+                    bytes[7] = buffer[bufferPos];
+                    bufferPos += 8;
+                    return BitConverter.ToDouble(bytes, 0);
+                }
+            }
+            else
+            {
+                return BitConverter.Int64BitsToDouble((long)ReadRawLittleEndian64());
+            }
         }
 
         /// <summary>
@@ -711,7 +737,260 @@ public bool MaybeConsumeTag(uint tag)
             return false;
         }
 
-        #endregion
+        internal static float? ReadFloatWrapperLittleEndian(CodedInputStream input)
+        {
+            // length:1 + tag:1 + value:4 = 6 bytes
+            if (input.bufferPos + 6 <= input.bufferSize)
+            {
+                // The entire wrapper message is already contained in `buffer`.
+                int length = input.buffer[input.bufferPos];
+                if (length == 0)
+                {
+                    input.bufferPos++;
+                    return 0F;
+                }
+                // tag:1 + value:4 = length of 5 bytes
+                // field=1, type=32-bit = tag of 13
+                if (length != 5 || input.buffer[input.bufferPos + 1] != 13)
+                {
+                    return ReadFloatWrapperSlow(input);
+                }
+                var result = BitConverter.ToSingle(input.buffer, input.bufferPos + 2);
+                input.bufferPos += 6;
+                return result;
+            }
+            else
+            {
+                return ReadFloatWrapperSlow(input);
+            }
+        }
+
+        internal static float? ReadFloatWrapperSlow(CodedInputStream input)
+        {
+            int length = input.ReadLength();
+            if (length == 0)
+            {
+                return 0F;
+            }
+            int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
+            float result = 0F;
+            do
+            {
+                // field=1, type=32-bit = tag of 13
+                if (input.ReadTag() == 13)
+                {
+                    result = input.ReadFloat();
+                }
+                else
+                {
+                    input.SkipLastField();
+                }
+            }
+            while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
+            return result;
+        }
+
+        internal static double? ReadDoubleWrapperLittleEndian(CodedInputStream input)
+        {
+            // length:1 + tag:1 + value:8 = 10 bytes
+            if (input.bufferPos + 10 <= input.bufferSize)
+            {
+                // The entire wrapper message is already contained in `buffer`.
+                int length = input.buffer[input.bufferPos];
+                if (length == 0)
+                {
+                    input.bufferPos++;
+                    return 0D;
+                }
+                // tag:1 + value:8 = length of 9 bytes
+                // field=1, type=64-bit = tag of 9
+                if (length != 9 || input.buffer[input.bufferPos + 1] != 9)
+                {
+                    return ReadDoubleWrapperSlow(input);
+                }
+                var result = BitConverter.ToDouble(input.buffer, input.bufferPos + 2);
+                input.bufferPos += 10;
+                return result;
+            }
+            else
+            {
+                return ReadDoubleWrapperSlow(input);
+            }
+        }
+
+        internal static double? ReadDoubleWrapperSlow(CodedInputStream input)
+        {
+            int length = input.ReadLength();
+            if (length == 0)
+            {
+                return 0D;
+            }
+            int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
+            double result = 0D;
+            do
+            {
+                // field=1, type=64-bit = tag of 9
+                if (input.ReadTag() == 9)
+                {
+                    result = input.ReadDouble();
+                }
+                else
+                {
+                    input.SkipLastField();
+                }
+            }
+            while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
+            return result;
+        }
+
+        internal static bool? ReadBoolWrapper(CodedInputStream input)
+        {
+            return ReadUInt32Wrapper(input) != 0;
+        }
+
+        internal static uint? ReadUInt32Wrapper(CodedInputStream input)
+        {
+            // length:1 + tag:1 + value:10(varint64-max) = 12 bytes; ReadInt32Wrapper shares this path and a negative int32 is a 10-byte varint
+            if (input.bufferPos + 12 <= input.bufferSize)
+            {
+                // The entire wrapper message is already contained in `buffer`.
+                int pos0 = input.bufferPos;
+                int length = input.buffer[input.bufferPos++];
+                if (length == 0)
+                {
+                    return 0;
+                }
+                // A single-field wrapper always has a one-byte length; anything else is left to the slow path.
+                if (length >= 128)
+                {
+                    input.bufferPos = pos0;
+                    return ReadUInt32WrapperSlow(input);
+                }
+                int finalBufferPos = input.bufferPos + length;
+                // field=1, type=varint = tag of 8
+                if (input.buffer[input.bufferPos++] != 8)
+                {
+                    input.bufferPos = pos0;
+                    return ReadUInt32WrapperSlow(input);
+                }
+                var result = input.ReadUInt32();
+                // Verify this message only contained a single field.
+                if (input.bufferPos != finalBufferPos)
+                {
+                    input.bufferPos = pos0;
+                    return ReadUInt32WrapperSlow(input);
+                }
+                return result;
+            }
+            else
+            {
+                return ReadUInt32WrapperSlow(input);
+            }
+        }
+
+        private static uint? ReadUInt32WrapperSlow(CodedInputStream input)
+        {
+            int length = input.ReadLength();
+            if (length == 0)
+            {
+                return 0;
+            }
+            int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
+            uint result = 0;
+            do
+            {
+                // field=1, type=varint = tag of 8
+                if (input.ReadTag() == 8)
+                {
+                    result = input.ReadUInt32();
+                }
+                else
+                {
+                    input.SkipLastField();
+                }
+            }
+            while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
+            return result;
+        }
+
+        internal static int? ReadInt32Wrapper(CodedInputStream input)
+        {
+            return (int?)ReadUInt32Wrapper(input);
+        }
+
+        internal static ulong? ReadUInt64Wrapper(CodedInputStream input)
+        {
+            // field=1, type=varint = tag of 8
+            const int expectedTag = 8;
+            // length:1 + tag:1 + value:10(varint64-max) = 12 bytes
+            if (input.bufferPos + 12 <= input.bufferSize)
+            {
+                // The entire wrapper message is already contained in `buffer`.
+                int pos0 = input.bufferPos;
+                int length = input.buffer[input.bufferPos++];
+                if (length == 0)
+                {
+                    return 0L;
+                }
+                // A single-field wrapper always has a one-byte length; anything else is left to the slow path.
+                if (length >= 128)
+                {
+                    input.bufferPos = pos0;
+                    return ReadUInt64WrapperSlow(input);
+                }
+                int finalBufferPos = input.bufferPos + length;
+                if (input.buffer[input.bufferPos++] != expectedTag)
+                {
+                    input.bufferPos = pos0;
+                    return ReadUInt64WrapperSlow(input);
+                }
+                var result = input.ReadUInt64();
+                // Verify this message only contained a single field.
+                if (input.bufferPos != finalBufferPos)
+                {
+                    input.bufferPos = pos0;
+                    return ReadUInt64WrapperSlow(input);
+                }
+                return result;
+            }
+            else
+            {
+                return ReadUInt64WrapperSlow(input);
+            }
+        }
+
+        internal static ulong? ReadUInt64WrapperSlow(CodedInputStream input)
+        {
+            // field=1, type=varint = tag of 8
+            const int expectedTag = 8;
+            int length = input.ReadLength();
+            if (length == 0)
+            {
+                return 0L;
+            }
+            int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
+            ulong result = 0L;
+            do
+            {
+                if (input.ReadTag() == expectedTag)
+                {
+                    result = input.ReadUInt64();
+                }
+                else
+                {
+                    input.SkipLastField();
+                }
+            }
+            while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
+            return result;
+        }
+
+        internal static long? ReadInt64Wrapper(CodedInputStream input)
+        {
+            return (long?)ReadUInt64Wrapper(input);
+        }
+
+#endregion
 
         #region Underlying reading primitives
 
@@ -876,17 +1155,42 @@ internal static uint ReadRawVarint32(Stream input)
         /// </summary>
         internal ulong ReadRawVarint64()
         {
-            int shift = 0;
-            ulong result = 0;
-            while (shift < 64)
+            if (bufferPos + 10 <= bufferSize)
             {
-                byte b = ReadRawByte();
-                result |= (ulong) (b & 0x7F) << shift;
-                if ((b & 0x80) == 0)
+                ulong result = buffer[bufferPos++];
+                if (result < 128)
                 {
                     return result;
                 }
-                shift += 7;
+                result &= 0x7f;
+                int shift = 7;
+                do
+                {
+                    byte b = buffer[bufferPos++];
+                    result |= (ulong)(b & 0x7F) << shift;
+                    if (b < 0x80)
+                    {
+                        return result;
+                    }
+                    shift += 7;
+                }
+                while (shift < 64);
+            }
+            else
+            {
+                int shift = 0;
+                ulong result = 0;
+                do
+                {
+                    byte b = ReadRawByte();
+                    result |= (ulong)(b & 0x7F) << shift;
+                    if (b < 0x80)
+                    {
+                        return result;
+                    }
+                    shift += 7;
+                }
+                while (shift < 64);
             }
             throw InvalidProtocolBufferException.MalformedVarint();
         }
@@ -896,11 +1200,32 @@ internal ulong ReadRawVarint64()
         /// </summary>
         internal uint ReadRawLittleEndian32()
         {
-            uint b1 = ReadRawByte();
-            uint b2 = ReadRawByte();
-            uint b3 = ReadRawByte();
-            uint b4 = ReadRawByte();
-            return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24);
+            if (bufferPos + 4 <= bufferSize)
+            {
+                if (BitConverter.IsLittleEndian)
+                {
+                    var result = BitConverter.ToUInt32(buffer, bufferPos);
+                    bufferPos += 4;
+                    return result;
+                }
+                else
+                {
+                    uint b1 = buffer[bufferPos];
+                    uint b2 = buffer[bufferPos + 1];
+                    uint b3 = buffer[bufferPos + 2];
+                    uint b4 = buffer[bufferPos + 3];
+                    bufferPos += 4;
+                    return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24);
+                }
+            }
+            else
+            {
+                uint b1 = ReadRawByte();
+                uint b2 = ReadRawByte();
+                uint b3 = ReadRawByte();
+                uint b4 = ReadRawByte();
+                return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24);
+            }
         }
 
         /// <summary>
@@ -908,16 +1233,42 @@ internal uint ReadRawLittleEndian32()
         /// </summary>
         internal ulong ReadRawLittleEndian64()
         {
-            ulong b1 = ReadRawByte();
-            ulong b2 = ReadRawByte();
-            ulong b3 = ReadRawByte();
-            ulong b4 = ReadRawByte();
-            ulong b5 = ReadRawByte();
-            ulong b6 = ReadRawByte();
-            ulong b7 = ReadRawByte();
-            ulong b8 = ReadRawByte();
-            return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24)
-                   | (b5 << 32) | (b6 << 40) | (b7 << 48) | (b8 << 56);
+            if (bufferPos + 8 <= bufferSize)
+            {
+                if (BitConverter.IsLittleEndian)
+                {
+                    var result = BitConverter.ToUInt64(buffer, bufferPos);
+                    bufferPos += 8;
+                    return result;
+                }
+                else
+                {
+                    ulong b1 = buffer[bufferPos];
+                    ulong b2 = buffer[bufferPos + 1];
+                    ulong b3 = buffer[bufferPos + 2];
+                    ulong b4 = buffer[bufferPos + 3];
+                    ulong b5 = buffer[bufferPos + 4];
+                    ulong b6 = buffer[bufferPos + 5];
+                    ulong b7 = buffer[bufferPos + 6];
+                    ulong b8 = buffer[bufferPos + 7];
+                    bufferPos += 8;
+                    return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24)
+                           | (b5 << 32) | (b6 << 40) | (b7 << 48) | (b8 << 56);
+                }
+            }
+            else
+            {
+                ulong b1 = ReadRawByte();
+                ulong b2 = ReadRawByte();
+                ulong b3 = ReadRawByte();
+                ulong b4 = ReadRawByte();
+                ulong b5 = ReadRawByte();
+                ulong b6 = ReadRawByte();
+                ulong b7 = ReadRawByte();
+                ulong b8 = ReadRawByte();
+                return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24)
+                       | (b5 << 32) | (b6 << 40) | (b7 << 48) | (b8 << 56);
+            }
         }
 
         /// <summary>
@@ -1301,6 +1652,6 @@ private void SkipImpl(int amountToSkip)
                 }
             }
         }
-        #endregion
+#endregion
     }
 }
\ No newline at end of file
diff --git a/csharp/src/Google.Protobuf/FieldCodec.cs b/csharp/src/Google.Protobuf/FieldCodec.cs
index ebd00b702206..1971261649a0 100644
--- a/csharp/src/Google.Protobuf/FieldCodec.cs
+++ b/csharp/src/Google.Protobuf/FieldCodec.cs
@@ -507,7 +507,7 @@ public static FieldCodec<T> ForEnum<T>(uint tag, Func<T, int> toInt32, Func<int,
         {
             var nestedCodec = WrapperCodecs.GetCodec<T>();
             return new FieldCodec<T?>(
-                input => WrapperCodecs.Read<T>(input, nestedCodec),
+                WrapperCodecs.GetReader<T>(),
                 (output, value) => WrapperCodecs.Write<T>(output, value.Value, nestedCodec),
                 (CodedInputStream i, ref T? v) => v = WrapperCodecs.Read<T>(i, nestedCodec),
                 (ref T? v, T? v2) => { if (v2.HasValue) { v = v2; } return v.HasValue; },
@@ -539,6 +539,25 @@ private static class WrapperCodecs
                 { typeof(ByteString), ForBytes(WireFormat.MakeTag(WrappersReflection.WrapperValueFieldNumber, WireFormat.WireType.LengthDelimited)) }
             };
 
+            private static readonly Dictionary<System.Type, object> Readers = new Dictionary<System.Type, object>
+            {
+                // TODO: Provide more optimized readers.
+                { typeof(bool), (Func<CodedInputStream, bool?>)CodedInputStream.ReadBoolWrapper },
+                { typeof(int), (Func<CodedInputStream, int?>)CodedInputStream.ReadInt32Wrapper },
+                { typeof(long), (Func<CodedInputStream, long?>)CodedInputStream.ReadInt64Wrapper },
+                { typeof(uint), (Func<CodedInputStream, uint?>)CodedInputStream.ReadUInt32Wrapper },
+                { typeof(ulong), (Func<CodedInputStream, ulong?>)CodedInputStream.ReadUInt64Wrapper },
+                { typeof(float), BitConverter.IsLittleEndian ?
+                    (Func<CodedInputStream, float?>)CodedInputStream.ReadFloatWrapperLittleEndian :
+                    (Func<CodedInputStream, float?>)CodedInputStream.ReadFloatWrapperSlow },
+                { typeof(double), BitConverter.IsLittleEndian ?
+                    (Func<CodedInputStream, double?>)CodedInputStream.ReadDoubleWrapperLittleEndian :
+                    (Func<CodedInputStream, double?>)CodedInputStream.ReadDoubleWrapperSlow },
+                // `string` and `ByteString` are less performance-sensitive. Do not implement for now.
+                { typeof(string), null },
+                { typeof(ByteString), null },
+            };
+
             /// <summary>
             /// Returns a field codec which effectively wraps a value of type T in a message.
             /// </summary>
@@ -553,6 +572,23 @@ internal static FieldCodec<T> GetCodec<T>()
                 return (FieldCodec<T>) value;
             }
 
+            internal static Func<CodedInputStream, T?> GetReader<T>() where T : struct
+            {
+                object value;
+                if (!Readers.TryGetValue(typeof(T), out value))
+                {
+                    throw new InvalidOperationException("Invalid type argument requested for wrapper reader: " + typeof(T));
+                }
+                if (value == null)
+                {
+                    // Return default unoptimized reader for the wrapper type.
+                    var nestedCodec = GetCodec<T>();
+                    return input => Read<T>(input, nestedCodec);
+                }
+                // Return optimized read for the wrapper type.
+                return (Func<CodedInputStream, T?>)value;
+            }
+
             internal static T Read<T>(CodedInputStream input, FieldCodec<T> codec)
             {
                 int length = input.ReadLength();