Skip to content

Commit

Permalink
Fix source generated formatters for records with string keys (#1798)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry-bym committed Apr 13, 2024
1 parent b7be200 commit c07dc04
Show file tree
Hide file tree
Showing 6 changed files with 346 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ private bool CheckValidMessagePackFormatterAttribute(AttributeData formatterAttr
var constructorParameters = new List<MemberSerializationInfo>();
if (ctor != null)
{
var constructorLookupDictionary = stringMembers.ToLookup(x => x.Key, x => x, StringComparer.OrdinalIgnoreCase);
var constructorLookupDictionary = stringMembers.ToLookup(x => x.Value.Name, x => x, StringComparer.OrdinalIgnoreCase);
do
{
constructorParameters.Clear();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) All contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using VerifyCS = CSharpSourceGeneratorVerifier<MessagePack.SourceGenerator.MessagePackGenerator>;

public class CustomStringKeyFormatterTests
{
private readonly ITestOutputHelper testOutputHelper;

public CustomStringKeyFormatterTests(ITestOutputHelper testOutputHelper)
{
this.testOutputHelper = testOutputHelper;
}

[Fact]
public async Task RecordWithPrimaryConstructor()
{
string testSource = """
using System;
using System.Collections.Generic;
using MessagePack;

namespace TempProject
{
[MessagePackObject]
public record MyMessagePackObject([property: Key("p")] string PhoneNumber, [property: Key("c")] int Count);
}
""";
await VerifyCS.Test.RunDefaultAsync(this.testOutputHelper, testSource);
}

[Fact]
public async Task RecordWithWithInitOnlyProps()
{
string testSource = """
using System;
using System.Collections.Generic;
using MessagePack;

namespace TempProject
{
[MessagePackObject]
public class MyMessagePackObject
{
public MyMessagePackObject(string phoneNumber, int count)
{
PhoneNumber = phoneNumber;
Count = count;
}

[Key("p")]
public string PhoneNumber { get; set; }

[Key("c")]
public int Count { get; set; }
};
}
""";
await VerifyCS.Test.RunDefaultAsync(this.testOutputHelper, testSource);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// <auto-generated />

#pragma warning disable 618, 612, 414, 168, CS1591, SA1129, SA1309, SA1312, SA1403, SA1649

namespace MessagePack {

using MsgPack = global::MessagePack;

partial class GeneratedMessagePackResolver
{
private partial class TempProject {
internal sealed class MyMessagePackObjectFormatter : global::MessagePack.Formatters.IMessagePackFormatter<global::TempProject.MyMessagePackObject>
{
// p
private static global::System.ReadOnlySpan<byte> GetSpan_PhoneNumber() => new byte[1 + 1] { 161, 112 };
// c
private static global::System.ReadOnlySpan<byte> GetSpan_Count() => new byte[1 + 1] { 161, 99 };

public void Serialize(ref global::MessagePack.MessagePackWriter writer, global::TempProject.MyMessagePackObject value, global::MessagePack.MessagePackSerializerOptions options)
{
if (value is null)
{
writer.WriteNil();
return;
}

var formatterResolver = options.Resolver;
writer.WriteMapHeader(2);
writer.WriteRaw(GetSpan_PhoneNumber());
MsgPack::FormatterResolverExtensions.GetFormatterWithVerify<string>(formatterResolver).Serialize(ref writer, value.PhoneNumber, options);
writer.WriteRaw(GetSpan_Count());
writer.Write(value.Count);
}

public global::TempProject.MyMessagePackObject Deserialize(ref global::MessagePack.MessagePackReader reader, global::MessagePack.MessagePackSerializerOptions options)
{
if (reader.TryReadNil())
{
return null;
}

options.Security.DepthStep(ref reader);
var formatterResolver = options.Resolver;
var length = reader.ReadMapHeader();
var __PhoneNumber__ = default(string);
var __Count__ = default(int);

for (int i = 0; i < length; i++)
{
var stringKey = global::MessagePack.Internal.CodeGenHelpers.ReadStringSpan(ref reader);
switch (stringKey.Length)
{
default:
FAIL:
reader.Skip();
continue;
case 1:
switch (global::MessagePack.Internal.AutomataKeyGen.GetKey(ref stringKey))
{
default: goto FAIL;
case 112UL:
__PhoneNumber__ = MsgPack::FormatterResolverExtensions.GetFormatterWithVerify<string>(formatterResolver).Deserialize(ref reader, options);
continue;
case 99UL:
__Count__ = reader.ReadInt32();
continue;
}

}
}

var ____result = new global::TempProject.MyMessagePackObject(__PhoneNumber__, __Count__);
reader.Depth--;
return ____result;
}
}
}}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// <auto-generated />

#pragma warning disable 618, 612, 414, 168, CS1591, SA1129, SA1309, SA1312, SA1403, SA1649

using MsgPack = global::MessagePack;

[assembly: MsgPack::Internal.GeneratedAssemblyMessagePackResolverAttribute(typeof(MessagePack.GeneratedMessagePackResolver), 3, 0)]

namespace MessagePack {

/// <summary>A MessagePack resolver that uses generated formatters for types in this assembly.</summary>
partial class GeneratedMessagePackResolver : MsgPack::IFormatterResolver
{
/// <summary>An instance of this resolver that only returns formatters specifically generated for types in this assembly.</summary>
public static readonly MsgPack::IFormatterResolver Instance = new GeneratedMessagePackResolver();

private GeneratedMessagePackResolver()
{
}

public MsgPack::Formatters.IMessagePackFormatter<T> GetFormatter<T>()
{
return FormatterCache<T>.Formatter;
}

private static class FormatterCache<T>
{
internal static readonly MsgPack::Formatters.IMessagePackFormatter<T> Formatter;

static FormatterCache()
{
var f = GeneratedMessagePackResolverGetFormatterHelper.GetFormatter(typeof(T));
if (f != null)
{
Formatter = (MsgPack::Formatters.IMessagePackFormatter<T>)f;
}
}
}

private static class GeneratedMessagePackResolverGetFormatterHelper
{
private static readonly global::System.Collections.Generic.Dictionary<global::System.Type, int> closedTypeLookup = new(1)
{
{ typeof(global::TempProject.MyMessagePackObject), 0 },
};

internal static object GetFormatter(global::System.Type t)
{
if (closedTypeLookup.TryGetValue(t, out int closedKey))
{
return closedKey switch
{
0 => new TempProject.MyMessagePackObjectFormatter(),
_ => null, // unreachable
};
}

return null;
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// <auto-generated />

#pragma warning disable 618, 612, 414, 168, CS1591, SA1129, SA1309, SA1312, SA1403, SA1649

namespace MessagePack {

using MsgPack = global::MessagePack;

partial class GeneratedMessagePackResolver
{
private partial class TempProject {
internal sealed class MyMessagePackObjectFormatter : global::MessagePack.Formatters.IMessagePackFormatter<global::TempProject.MyMessagePackObject>
{
// p
private static global::System.ReadOnlySpan<byte> GetSpan_PhoneNumber() => new byte[1 + 1] { 161, 112 };
// c
private static global::System.ReadOnlySpan<byte> GetSpan_Count() => new byte[1 + 1] { 161, 99 };

public void Serialize(ref global::MessagePack.MessagePackWriter writer, global::TempProject.MyMessagePackObject value, global::MessagePack.MessagePackSerializerOptions options)
{
if (value is null)
{
writer.WriteNil();
return;
}

var formatterResolver = options.Resolver;
writer.WriteMapHeader(2);
writer.WriteRaw(GetSpan_PhoneNumber());
MsgPack::FormatterResolverExtensions.GetFormatterWithVerify<string>(formatterResolver).Serialize(ref writer, value.PhoneNumber, options);
writer.WriteRaw(GetSpan_Count());
writer.Write(value.Count);
}

public global::TempProject.MyMessagePackObject Deserialize(ref global::MessagePack.MessagePackReader reader, global::MessagePack.MessagePackSerializerOptions options)
{
if (reader.TryReadNil())
{
return null;
}

options.Security.DepthStep(ref reader);
var formatterResolver = options.Resolver;
var length = reader.ReadMapHeader();
var __PhoneNumber__ = default(string);
var __Count__ = default(int);

for (int i = 0; i < length; i++)
{
var stringKey = global::MessagePack.Internal.CodeGenHelpers.ReadStringSpan(ref reader);
switch (stringKey.Length)
{
default:
FAIL:
reader.Skip();
continue;
case 1:
switch (global::MessagePack.Internal.AutomataKeyGen.GetKey(ref stringKey))
{
default: goto FAIL;
case 112UL:
__PhoneNumber__ = MsgPack::FormatterResolverExtensions.GetFormatterWithVerify<string>(formatterResolver).Deserialize(ref reader, options);
continue;
case 99UL:
__Count__ = reader.ReadInt32();
continue;
}

}
}

var ____result = new global::TempProject.MyMessagePackObject(__PhoneNumber__, __Count__);
reader.Depth--;
return ____result;
}
}
}}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// <auto-generated />

#pragma warning disable 618, 612, 414, 168, CS1591, SA1129, SA1309, SA1312, SA1403, SA1649

using MsgPack = global::MessagePack;

[assembly: MsgPack::Internal.GeneratedAssemblyMessagePackResolverAttribute(typeof(MessagePack.GeneratedMessagePackResolver), 3, 0)]

namespace MessagePack {

/// <summary>A MessagePack resolver that uses generated formatters for types in this assembly.</summary>
partial class GeneratedMessagePackResolver : MsgPack::IFormatterResolver
{
/// <summary>An instance of this resolver that only returns formatters specifically generated for types in this assembly.</summary>
public static readonly MsgPack::IFormatterResolver Instance = new GeneratedMessagePackResolver();

private GeneratedMessagePackResolver()
{
}

public MsgPack::Formatters.IMessagePackFormatter<T> GetFormatter<T>()
{
return FormatterCache<T>.Formatter;
}

private static class FormatterCache<T>
{
internal static readonly MsgPack::Formatters.IMessagePackFormatter<T> Formatter;

static FormatterCache()
{
var f = GeneratedMessagePackResolverGetFormatterHelper.GetFormatter(typeof(T));
if (f != null)
{
Formatter = (MsgPack::Formatters.IMessagePackFormatter<T>)f;
}
}
}

private static class GeneratedMessagePackResolverGetFormatterHelper
{
private static readonly global::System.Collections.Generic.Dictionary<global::System.Type, int> closedTypeLookup = new(1)
{
{ typeof(global::TempProject.MyMessagePackObject), 0 },
};

internal static object GetFormatter(global::System.Type t)
{
if (closedTypeLookup.TryGetValue(t, out int closedKey))
{
return closedKey switch
{
0 => new TempProject.MyMessagePackObjectFormatter(),
_ => null, // unreachable
};
}

return null;
}
}
}

}

0 comments on commit c07dc04

Please sign in to comment.