Skip to content

Commit

Permalink
[mono][wasm] Fix function signature mismatch in m2n invoke (#101106)
Browse files Browse the repository at this point in the history
 Fix signature mismatch
  • Loading branch information
mkhamoyan committed Apr 23, 2024
1 parent 7d91bf5 commit 40bc2d8
Show file tree
Hide file tree
Showing 13 changed files with 157 additions and 64 deletions.
10 changes: 3 additions & 7 deletions src/mono/browser/runtime/runtime.c
Expand Up @@ -65,6 +65,7 @@
int mono_wasm_enable_gc = 1;

/* Missing from public headers */
char *mono_fixup_symbol_name (char *key);
void mono_icall_table_init (void);
void mono_wasm_enable_debugging (int);
void mono_ee_interp_init (const char *opts);
Expand Down Expand Up @@ -213,13 +214,8 @@ get_native_to_interp (MonoMethod *method, void *extra_arg)

assert (strlen (name) < 100);
snprintf (key, sizeof(key), "%s_%s_%s", name, class_name, method_name);
len = strlen (key);
for (int i = 0; i < len; ++i) {
if (key [i] == '.')
key [i] = '_';
}

addr = wasm_dl_get_native_to_interp (key, extra_arg);
char* fixedName = mono_fixup_symbol_name(key);
addr = wasm_dl_get_native_to_interp (fixedName, extra_arg);
MONO_EXIT_GC_UNSAFE;
return addr;
}
Expand Down
29 changes: 29 additions & 0 deletions src/mono/mono/metadata/native-library.c
Expand Up @@ -1222,3 +1222,32 @@ mono_loader_install_pinvoke_override (PInvokeOverrideFn override_fn)
{
pinvoke_override = override_fn;
}

// Keep synced with FixupSymbolName from src/tasks/Common/Utils.cs
char* mono_fixup_symbol_name (char *key) {
char* fixedName = malloc(256);
int sb_index = 0;
int len = (int)strlen (key);

for (int i = 0; i < len; ++i) {
unsigned char b = key[i];
if ((b >= '0' && b <= '9') ||
(b >= 'a' && b <= 'z') ||
(b >= 'A' && b <= 'Z') ||
(b == '_')) {
fixedName[sb_index++] = b;
}
else if (b == '.' || b == '-' || b == '+' || b == '<' || b == '>') {
fixedName[sb_index++] = '_';
}
else {
// Append the hexadecimal representation of b between underscores
sprintf(&fixedName[sb_index], "_%X_", b);
sb_index += 4; // Move the index after the appended hexadecimal characters
}
}

// Null-terminate the fixedName string
fixedName[sb_index] = '\0';
return fixedName;
}
3 changes: 3 additions & 0 deletions src/mono/mono/metadata/native-library.h
Expand Up @@ -35,4 +35,7 @@ mono_lookup_pinvoke_qcall_internal (const char *name);
void
mono_loader_install_pinvoke_override (PInvokeOverrideFn override_fn);

char *
mono_fixup_symbol_name (char *key);

#endif
18 changes: 12 additions & 6 deletions src/mono/mono/mini/aot-compiler.c
Expand Up @@ -51,6 +51,7 @@
#include <mono/metadata/mempool-internals.h>
#include <mono/metadata/mono-basic-block.h>
#include <mono/metadata/mono-endian.h>
#include <mono/metadata/native-library.h>
#include <mono/metadata/threads-types.h>
#include <mono/metadata/custom-attrs-internals.h>
#include <mono/utils/mono-logger-internals.h>
Expand Down Expand Up @@ -12407,22 +12408,24 @@ emit_file_info (MonoAotCompile *acfg)

if (acfg->aot_opts.static_link) {
char symbol [MAX_SYMBOL_SIZE];
char *p;

/*
* Emit a global symbol which can be passed by an embedding app to
* mono_aot_register_module (). The symbol points to a pointer to the file info
* structure.
*/
sprintf (symbol, "%smono_aot_module_%s_info", acfg->user_symbol_prefix, acfg->image->assembly->aname.name);

#ifdef TARGET_WASM
acfg->static_linking_symbol = g_strdup (mono_fixup_symbol_name(symbol));
#else
/* Get rid of characters which cannot occur in symbols */
p = symbol;
char *p = symbol;
for (p = symbol; *p; ++p) {
if (!(isalnum (*p) || *p == '_'))
*p = '_';
}
acfg->static_linking_symbol = g_strdup (symbol);
#endif
}

if (acfg->llvm)
Expand Down Expand Up @@ -14860,7 +14863,6 @@ aot_assembly (MonoAssembly *ass, guint32 jit_opts, MonoAotOptions *aot_options)
{
MonoImage *image = ass->image;
MonoAotCompile *acfg;
char *p;
int res;
TV_DECLARE (atv);
TV_DECLARE (btv);
Expand Down Expand Up @@ -15120,13 +15122,17 @@ aot_assembly (MonoAssembly *ass, guint32 jit_opts, MonoAotOptions *aot_options)
acfg->flags = (MonoAotFileFlags)(acfg->flags | MONO_AOT_FILE_FLAG_LLVM_ONLY);

acfg->assembly_name_sym = g_strdup (get_assembly_prefix (acfg->image));
/* Get rid of characters which cannot occur in symbols */
#ifdef TARGET_WASM
acfg->global_prefix = g_strdup_printf ("mono_aot_%s", g_strdup(mono_fixup_symbol_name (acfg->assembly_name_sym)));
#else
char *p;
/* Get rid of characters which cannot occur in symbols */
for (p = acfg->assembly_name_sym; *p; ++p) {
if (!(isalnum (*p) || *p == '_'))
*p = '_';
}

acfg->global_prefix = g_strdup_printf ("mono_aot_%s", acfg->assembly_name_sym);
#endif
acfg->plt_symbol = g_strdup_printf ("%s_plt", acfg->global_prefix);
acfg->got_symbol = g_strdup_printf ("%s_got", acfg->global_prefix);
if (acfg->llvm) {
Expand Down
11 changes: 8 additions & 3 deletions src/mono/mono/mini/mini-llvm.c
Expand Up @@ -12,6 +12,7 @@
#include <mono/metadata/debug-helpers.h>
#include <mono/metadata/debug-internals.h>
#include <mono/metadata/mempool-internals.h>
#include <mono/metadata/native-library.h>
#include <mono/metadata/environment.h>
#include <mono/metadata/object-internals.h>
#include <mono/metadata/abi-details.h>
Expand Down Expand Up @@ -14519,17 +14520,21 @@ emit_aot_file_info (MonoLLVMModule *module)
LLVMSetInitializer (info_var, LLVMConstNamedStruct (module->info_var_type, fields, nfields));

if (module->static_link) {
char *s, *p;
char *s;
LLVMValueRef var;

s = g_strdup_printf ("mono_aot_module_%s_info", module->assembly->aname.name);
#ifdef TARGET_WASM
var = LLVMAddGlobal (module->lmodule, pointer_type (LLVMInt8Type ()), g_strdup (mono_fixup_symbol_name(s)));
#else
/* Get rid of characters which cannot occur in symbols */
p = s;
char *p = s;
for (p = s; *p; ++p) {
if (!(isalnum (*p) || *p == '_'))
*p = '_';
*p = '_';
}
var = LLVMAddGlobal (module->lmodule, pointer_type (LLVMInt8Type ()), s);
#endif
g_free (s);
LLVMSetInitializer (var, LLVMConstBitCast (LLVMGetNamedGlobal (module->lmodule, "mono_aot_file_info"), pointer_type (LLVMInt8Type ())));
LLVMSetLinkage (var, LLVMExternalLinkage);
Expand Down
4 changes: 3 additions & 1 deletion src/mono/wasm/Wasm.Build.Tests/Common/TestUtils.cs
Expand Up @@ -91,7 +91,9 @@ public static void AssertEqual(object expected, object actual, string label)
$"[{label}]\n");
}

private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' };
private static readonly char[] s_charsToReplace = new[] { '.', '-', '+', '<', '>' };
// Keep synced with FixupSymbolName from src/tasks/Common/Utils.cs
// and with mono_fixup_symbol_name from src/mono/mono/metadata/native-library.c
public static string FixupSymbolName(string name)
{
UTF8Encoding utf8 = new();
Expand Down
26 changes: 26 additions & 0 deletions src/mono/wasm/Wasm.Build.Tests/PInvokeTableGeneratorTests.cs
Expand Up @@ -884,5 +884,31 @@ public static unsafe int Main(string[] argv)
[BuildAndRun(host: RunHost.Chrome, aot: false)]
public void EnsureWasmAbiRulesAreFollowedInInterpreter(BuildArgs buildArgs, RunHost host, string id) =>
EnsureWasmAbiRulesAreFollowed(buildArgs, host, id);

[Theory]
[BuildAndRun(host: RunHost.Chrome, aot: false)]
public void UCOWithSpecialCharacters(BuildArgs buildArgs, RunHost host, string id)
{
var extraProperties = "<AllowUnsafeBlocks>true</AllowUnsafeBlocks>";
var extraItems = @"<NativeFileReference Include=""local.c"" />";

buildArgs = ExpandBuildArgs(buildArgs,
extraItems: extraItems,
extraProperties: extraProperties);

(string libraryDir, string output) = BuildProject(buildArgs,
id: id,
new BuildProjectOptions(
InitProject: () =>
{
File.Copy(Path.Combine(BuildEnvironment.TestAssetsPath, "Wasm.Buid.Tests.Programs", "UnmanagedCallback.cs"), Path.Combine(_projectDir!, "Program.cs"));
File.Copy(Path.Combine(BuildEnvironment.TestAssetsPath, "native-libs", "local.c"), Path.Combine(_projectDir!, "local.c"));
},
Publish: true,
DotnetWasmFromRuntimePack: false));

var runOutput = RunAndTestWasmApp(buildArgs, buildDir: _projectDir, expectedExitCode: 42, host: host, id: id);
Assert.Contains("ManagedFunc returned 42", runOutput);
}
}
}
@@ -0,0 +1,28 @@
using System;
using System.Runtime.InteropServices;

public unsafe partial class Test
{
public unsafe static int Main(string[] args)
{
((IntPtr)(delegate* unmanaged<int,int>)&Interop.Managed8\u4F60Func).ToString();

Console.WriteLine($"main: {args.Length}");
Interop.UnmanagedFunc();
return 42;
}
}

file partial class Interop
{
[UnmanagedCallersOnly(EntryPoint = "ManagedFunc")]
public static int Managed8\u4F60Func(int number)
{
// called from UnmanagedFunc
Console.WriteLine($"Managed8\u4F60Func({number}) -> 42");
return 42;
}

[DllImport("local", EntryPoint = "UnmanagedFunc")]
public static extern void UnmanagedFunc(); // calls ManagedFunc
}
10 changes: 10 additions & 0 deletions src/mono/wasm/testassets/native-libs/local.c
@@ -0,0 +1,10 @@
#include <stdio.h>
int ManagedFunc(int number);

void UnmanagedFunc()
{
int ret = 0;
printf("UnmanagedFunc calling ManagedFunc\n");
ret = ManagedFunc(123);
printf("ManagedFunc returned %d\n", ret);
}
21 changes: 1 addition & 20 deletions src/tasks/AotCompilerTask/MonoAOTCompiler.cs
Expand Up @@ -1237,26 +1237,7 @@ private string FixupSymbolName(string name)
if (_symbolNameFixups.TryGetValue(name, out string? fixedName))
return fixedName;

UTF8Encoding utf8 = new();
byte[] bytes = utf8.GetBytes(name);
StringBuilder sb = new();

foreach (byte b in bytes)
{
if ((b >= (byte)'0' && b <= (byte)'9') ||
(b >= (byte)'a' && b <= (byte)'z') ||
(b >= (byte)'A' && b <= (byte)'Z') ||
(b == (byte)'_'))
{
sb.Append((char)b);
}
else
{
sb.Append('_');
}
}

fixedName = sb.ToString();
fixedName = Utils.FixupSymbolName(name);
_symbolNameFixups[name] = fixedName;
return fixedName;
}
Expand Down
32 changes: 32 additions & 0 deletions src/tasks/Common/Utils.cs
Expand Up @@ -11,6 +11,7 @@
using System.Reflection.Metadata;
using System.Security.Cryptography;
using System.Text;
using System.Linq;
using Microsoft.Build.Framework;
using Microsoft.Build.Utilities;

Expand All @@ -33,6 +34,8 @@ public enum HashEncodingType

private static readonly object s_SyncObj = new object();

private static readonly char[] s_charsToReplace = new[] { '.', '-', '+', '<', '>' };

public static string GetEmbeddedResource(string file)
{
using Stream stream = typeof(Utils).Assembly
Expand Down Expand Up @@ -411,4 +414,33 @@ private static bool IsManagedAssembly(PEReader peReader)
return false;
}
}

// Keep synced with mono_fixup_symbol_name from src/mono/mono/metadata/native-library.c
public static string FixupSymbolName(string name)
{
UTF8Encoding utf8 = new();
byte[] bytes = utf8.GetBytes(name);
StringBuilder sb = new();

foreach (byte b in bytes)
{
if ((b >= (byte)'0' && b <= (byte)'9') ||
(b >= (byte)'a' && b <= (byte)'z') ||
(b >= (byte)'A' && b <= (byte)'Z') ||
(b == (byte)'_'))
{
sb.Append((char)b);
}
else if (s_charsToReplace.Contains((char)b))
{
sb.Append('_');
}
else
{
sb.Append($"_{b:X}_");
}
}

return sb.ToString();
}
}
27 changes: 1 addition & 26 deletions src/tasks/WasmAppBuilder/ManagedToNativeGenerator.cs
Expand Up @@ -36,8 +36,6 @@ public class ManagedToNativeGenerator : Task
[Output]
public string[]? FileWrites { get; private set; }

private static readonly char[] s_charsToReplace = new[] { '.', '-', '+', '<', '>' };

public override bool Execute()
{
if (Assemblies!.Length == 0)
Expand Down Expand Up @@ -108,30 +106,7 @@ string FixupSymbolName(string name)
if (_symbolNameFixups.TryGetValue(name, out string? fixedName))
return fixedName;

UTF8Encoding utf8 = new();
byte[] bytes = utf8.GetBytes(name);
StringBuilder sb = new();

foreach (byte b in bytes)
{
if ((b >= (byte)'0' && b <= (byte)'9') ||
(b >= (byte)'a' && b <= (byte)'z') ||
(b >= (byte)'A' && b <= (byte)'Z') ||
(b == (byte)'_'))
{
sb.Append((char)b);
}
else if (s_charsToReplace.Contains((char)b))
{
sb.Append('_');
}
else
{
sb.Append($"_{b:X}_");
}
}

fixedName = sb.ToString();
fixedName = Utils.FixupSymbolName(name);
_symbolNameFixups[name] = fixedName;
return fixedName;
}
Expand Down
2 changes: 1 addition & 1 deletion src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs
Expand Up @@ -304,7 +304,7 @@ private string DelegateKey(PInvokeCallback export)
// it needs to match the key generated in get_native_to_interp
var method = export.Method;
string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!;
return $"\"{module_symbol}_{method.DeclaringType.Name}_{method.Name}\"".Replace('.', '_');
return $"\"{_fixupSymbolName($"{module_symbol}_{method.DeclaringType.Name}_{method.Name}")}\"";
}

#pragma warning disable SYSLIB1045 // framework doesn't support GeneratedRegexAttribute
Expand Down

0 comments on commit 40bc2d8

Please sign in to comment.