From 3b32c2028bb9abaf27752aa0763b1699dd380ec7 Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Thu, 18 Aug 2022 19:37:11 +0900 Subject: [PATCH] Externalize compilation cache by compilers (#747) This adds the experimental support of the file system compilation cache. Notably, experimental.WithCompilationCacheDirName allows users to configure where the compiler writes the cache into. Versioning/validation of binary compatibility has been done via the release tag (which will be created from the end of this month). More specifically, the cache file starts with a header with the hardcoded wazero version. Fixes #618 Signed-off-by: Takeshi Yoneda Co-authored-by: Crypt Keeper <64215+codefromthecrypt@users.noreply.github.com> --- experimental/compilation_cache.go | 31 ++ internal/compilationcache/compilationcache.go | 42 ++ internal/compilationcache/file_cache.go | 99 +++++ internal/compilationcache/file_cache_test.go | 135 +++++++ internal/engine/compiler/engine.go | 44 +-- internal/engine/compiler/engine_cache.go | 188 +++++++++ internal/engine/compiler/engine_cache_test.go | 370 ++++++++++++++++++ internal/engine/compiler/engine_test.go | 23 -- internal/engine/compiler/impl_amd64.go | 6 +- internal/engine/compiler/impl_arm64.go | 6 +- internal/integration_test/bench/bench_test.go | 32 +- internal/platform/buf_writer.go | 20 + internal/platform/mmap.go | 23 +- internal/platform/mmap_test.go | 15 +- internal/platform/mmap_windows.go | 34 +- internal/platform/platform.go | 9 +- internal/version/version.go | 4 + runtime.go | 4 + runtime_test.go | 14 + version.go | 5 + 20 files changed, 1015 insertions(+), 89 deletions(-) create mode 100644 experimental/compilation_cache.go create mode 100644 internal/compilationcache/compilationcache.go create mode 100644 internal/compilationcache/file_cache.go create mode 100644 internal/compilationcache/file_cache_test.go create mode 100644 internal/engine/compiler/engine_cache.go create mode 100644 internal/engine/compiler/engine_cache_test.go create mode 100644 internal/platform/buf_writer.go create mode 100644 internal/version/version.go create mode 100644 version.go diff --git a/experimental/compilation_cache.go b/experimental/compilation_cache.go new file mode 100644 index 0000000000..9a2bc4a9b3 --- /dev/null +++ b/experimental/compilation_cache.go @@ -0,0 +1,31 @@ +package experimental + +import ( + "context" + + "github.com/tetratelabs/wazero/internal/compilationcache" +) + +// WithCompilationCacheDirName configures the destination directory of the compilation cache. +// Regardless of the usage of this, the compiled functions are cached in memory, but its lifetime is +// bound to the lifetime of wazero.Runtime or wazero.CompiledModule. +// +// With the given non-empty directory, wazero persists the cache into the directory and that cache +// will be used as long as the running wazero version match the version of compilation wazero. +// +// A cache is only valid for use in one wazero.Runtime at a time. Concurrent use +// of a wazero.Runtime is supported, but multiple runtimes must not share the +// same directory. +// +// Note: The embedder must safeguard this directory from external changes. +// +// Usage: +// +// ctx := experimental.WithCompilationCacheDirName(context.Background(), "/home/me/.cache/wazero") +// r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigCompiler()) +func WithCompilationCacheDirName(ctx context.Context, dirname string) context.Context { + if len(dirname) != 0 { + ctx = context.WithValue(ctx, compilationcache.FileCachePathKey{}, dirname) + } + return ctx +} diff --git a/internal/compilationcache/compilationcache.go b/internal/compilationcache/compilationcache.go new file mode 100644 index 0000000000..54954d571e --- /dev/null +++ b/internal/compilationcache/compilationcache.go @@ -0,0 +1,42 @@ +package compilationcache + +import ( + "crypto/sha256" + "io" +) + +// Cache allows the compiler engine to skip compilation of wasm to machine code +// where doing so is redundant for the same wasm binary and version of wazero. +// +// This augments the default in-memory cache of compiled functions, by +// decoupling it from a wazero.Runtime instance. Concretely, a runtime loses +// its cache once closed. This cache allows the runtime to rebuild its +// in-memory cache quicker, significantly reducing first-hit penalty on a hit. +// +// See NewFileCache for the example implementation. +type Cache interface { + // Get is called when the runtime is trying to get the cached compiled functions. + // Implementations are supposed to return compiled function in io.Reader with ok=true + // if the key exists on the cache. In the case of not-found, this should return + // ok=false with err=nil. content.Close() is automatically called by + // the caller of this Get. + // + // Note: the returned content won't go through the validation pass of Wasm binary + // which is applied when the binary is compiled from scratch without cache hit. + Get(key Key) (content io.ReadCloser, ok bool, err error) + // + // Add is called when the runtime is trying to add the new cache entry. + // The given `content` must be un-modified, and returned as-is in Get method. + // + // Note: the `content` is ensured to be safe through the validation phase applied on the Wasm binary. + Add(key Key, content io.Reader) (err error) + // + // Delete is called when the cache on the `key` returned by Get is no longer usable, and + // must be purged. Specifically, this is called happens when the wazero's version has been changed. + // For example, that is when there's a difference between the version of compiling wazero and the + // version of the currently used wazero. + Delete(key Key) (err error) +} + +// Key represents the 256-bit unique identifier assigned to each cache entry. +type Key = [sha256.Size]byte diff --git a/internal/compilationcache/file_cache.go b/internal/compilationcache/file_cache.go new file mode 100644 index 0000000000..fc7e0116d4 --- /dev/null +++ b/internal/compilationcache/file_cache.go @@ -0,0 +1,99 @@ +package compilationcache + +import ( + "context" + "encoding/hex" + "errors" + "io" + "os" + "path" + "sync" +) + +// FileCachePathKey is a context.Context Value key. Its value is a string +// representing the compilation cache directory. +type FileCachePathKey struct{} + +// NewFileCache returns a new Cache implemented by fileCache. +func NewFileCache(ctx context.Context) Cache { + if fsValue := ctx.Value(FileCachePathKey{}); fsValue != nil { + return newFileCache(fsValue.(string)) + } + return nil +} + +func newFileCache(dir string) *fileCache { + return &fileCache{dirPath: dir} +} + +// fileCache persists compiled functions into dirPath. +// +// Note: this can be expanded to do binary signing/verification, set TTL on each entry, etc. +type fileCache struct { + dirPath string + mux sync.RWMutex +} + +type fileReadCloser struct { + *os.File + fc *fileCache +} + +func (fc *fileCache) path(key Key) string { + return path.Join(fc.dirPath, hex.EncodeToString(key[:])) +} + +func (fc *fileCache) Get(key Key) (content io.ReadCloser, ok bool, err error) { + // TODO: take lock per key for more efficiency vs the complexity of impl. + fc.mux.RLock() + unlock := fc.mux.RUnlock + defer func() { + if unlock != nil { + unlock() + } + }() + + f, err := os.Open(fc.path(key)) + if errors.Is(err, os.ErrNotExist) { + return nil, false, nil + } else if err != nil { + return nil, false, err + } else { + // Unlock is done inside the content.Close() at the call site. + unlock = nil + return &fileReadCloser{File: f, fc: fc}, true, nil + } +} + +// Close wraps the os.File Close to release the read lock on fileCache. +func (f *fileReadCloser) Close() (err error) { + defer f.fc.mux.RUnlock() + err = f.File.Close() + return +} + +func (fc *fileCache) Add(key Key, content io.Reader) (err error) { + // TODO: take lock per key for more efficiency vs the complexity of impl. + fc.mux.Lock() + defer fc.mux.Unlock() + + file, err := os.Create(fc.path(key)) + if err != nil { + return + } + defer file.Close() + _, err = io.Copy(file, content) + return +} + +func (fc *fileCache) Delete(key Key) (err error) { + // TODO: take lock per key for more efficiency vs the complexity of impl. + fc.mux.Lock() + defer fc.mux.Unlock() + + err = os.Remove(fc.path(key)) + if errors.Is(err, os.ErrNotExist) { + err = nil + } + return +} diff --git a/internal/compilationcache/file_cache_test.go b/internal/compilationcache/file_cache_test.go new file mode 100644 index 0000000000..84e508ea3a --- /dev/null +++ b/internal/compilationcache/file_cache_test.go @@ -0,0 +1,135 @@ +package compilationcache + +import ( + "bytes" + "io" + "os" + "testing" + + "github.com/tetratelabs/wazero/internal/testing/require" +) + +func TestFileReadCloser_Close(t *testing.T) { + fc := newFileCache(t.TempDir()) + key := Key{1, 2, 3} + + err := fc.Add(key, bytes.NewReader([]byte{1, 2, 3, 4})) + require.NoError(t, err) + + c, ok, err := fc.Get(key) + require.NoError(t, err) + require.True(t, ok) + + // At this point, file is not closed, therefore TryLock should fail. + require.False(t, fc.mux.TryLock()) + + // Close, and then TryLock should succeed this time. + require.NoError(t, c.Close()) + require.True(t, fc.mux.TryLock()) +} + +func TestFileCache_Add(t *testing.T) { + fc := newFileCache(t.TempDir()) + + t.Run("not exist", func(t *testing.T) { + content := []byte{1, 2, 3, 4, 5} + id := Key{1, 2, 3, 4, 5, 6, 7} + err := fc.Add(id, bytes.NewReader(content)) + require.NoError(t, err) + + // Ensures that file exists. + cached, err := os.ReadFile(fc.path(id)) + require.NoError(t, err) + + // Check if the saved content is the same as the given one. + require.Equal(t, content, cached) + }) + + t.Run("already exists", func(t *testing.T) { + content := []byte{1, 2, 3, 4, 5} + + id := Key{1, 2, 3} + + // Writes the pre-existing file for the same ID. + p := fc.path(id) + f, err := os.Create(p) + require.NoError(t, err) + _, err = f.Write(content) + require.NoError(t, err) + require.NoError(t, f.Close()) + + err = fc.Add(id, bytes.NewReader(content)) + require.NoError(t, err) + + // Ensures that file exists. + cached, err := os.ReadFile(fc.path(id)) + require.NoError(t, err) + + // Check if the saved content is the same as the given one. + require.Equal(t, content, cached) + }) +} + +func TestFileCache_Delete(t *testing.T) { + fc := newFileCache(t.TempDir()) + t.Run("non-exist", func(t *testing.T) { + id := Key{0} + err := fc.Delete(id) + require.NoError(t, err) + }) + t.Run("exist", func(t *testing.T) { + id := Key{1, 2, 3} + p := fc.path(id) + f, err := os.Create(p) + require.NoError(t, err) + require.NoError(t, f.Close()) + + // Ensures that file exists now. + f, err = os.Open(p) + require.NoError(t, err) + require.NoError(t, f.Close()) + + // Delete the cache. + err = fc.Delete(id) + require.NoError(t, err) + + // Ensures that file no longer exists. + _, err = os.Open(p) + require.ErrorIs(t, err, os.ErrNotExist) + }) +} + +func TestFileCache_Get(t *testing.T) { + fc := newFileCache(t.TempDir()) + + t.Run("exist", func(t *testing.T) { + content := []byte{1, 2, 3, 4, 5} + id := Key{1, 2, 3} + + // Writes the pre-existing file for the ID. + p := fc.path(id) + f, err := os.Create(p) + require.NoError(t, err) + _, err = f.Write(content) + require.NoError(t, err) + require.NoError(t, f.Close()) + + result, ok, err := fc.Get(id) + require.NoError(t, err) + require.True(t, ok) + defer func() { + require.NoError(t, result.Close()) + }() + + actual, err := io.ReadAll(result) + require.NoError(t, err) + + require.Equal(t, content, actual) + }) + t.Run("not exist", func(t *testing.T) { + _, ok, err := fc.Get(Key{0xf}) + // Non-exist should not be error. + require.NoError(t, err) + require.False(t, ok) + }) +} diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index 106d0fdf8d..ee9544c6db 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -10,7 +10,9 @@ import ( "unsafe" "github.com/tetratelabs/wazero/internal/buildoptions" + "github.com/tetratelabs/wazero/internal/compilationcache" "github.com/tetratelabs/wazero/internal/platform" + "github.com/tetratelabs/wazero/internal/version" "github.com/tetratelabs/wazero/internal/wasm" "github.com/tetratelabs/wazero/internal/wasmdebug" "github.com/tetratelabs/wazero/internal/wasmruntime" @@ -22,9 +24,11 @@ type ( engine struct { enabledFeatures wasm.Features codes map[wasm.ModuleID][]*code // guarded by mutex. + Cache compilationcache.Cache mux sync.RWMutex // setFinalizer defaults to runtime.SetFinalizer, but overridable for tests. - setFinalizer func(obj interface{}, finalizer interface{}) + setFinalizer func(obj interface{}, finalizer interface{}) + wazeroVersion string } // moduleEngine implements wasm.ModuleEngine @@ -411,8 +415,10 @@ func (e *engine) DeleteCompiledModule(module *wasm.Module) { // CompileModule implements the same method as documented on wasm.Engine. func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { - if _, ok := e.getCodes(module); ok { // cache hit! + if _, ok, err := e.getCodes(module); ok { // cache hit! return nil + } else if err != nil { + return err } funcs := make([]*code, 0, len(module.FunctionSection)) @@ -441,8 +447,7 @@ func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { funcs = append(funcs, compiled) } - e.addCodes(module, funcs) - return nil + return e.addCodes(module, funcs) } // NewModuleEngine implements the same method as documented on wasm.Engine. @@ -459,9 +464,11 @@ func (e *engine) NewModuleEngine(name string, module *wasm.Module, importedFunct me.functions = append(me.functions, cf) } - codes, ok := e.getCodes(module) + codes, ok, err := e.getCodes(module) if !ok { return nil, fmt.Errorf("source module for %s must be compiled before instantiation", name) + } else if err != nil { + return nil, err } for i, c := range codes { @@ -485,25 +492,6 @@ func (e *engine) NewModuleEngine(name string, module *wasm.Module, importedFunct return me, nil } -func (e *engine) deleteCodes(module *wasm.Module) { - e.mux.Lock() - defer e.mux.Unlock() - delete(e.codes, module.ID) -} - -func (e *engine) addCodes(module *wasm.Module, fs []*code) { - e.mux.Lock() - defer e.mux.Unlock() - e.codes[module.ID] = fs -} - -func (e *engine) getCodes(module *wasm.Module) (fs []*code, ok bool) { - e.mux.RLock() - defer e.mux.RUnlock() - fs, ok = e.codes[module.ID] - return -} - // Name implements the same method as documented on wasm.ModuleEngine. func (e *moduleEngine) Name() string { return e.name @@ -594,11 +582,17 @@ func NewEngine(ctx context.Context, enabledFeatures wasm.Features) wasm.Engine { return newEngine(ctx, enabledFeatures) } -func newEngine(_ context.Context, enabledFeatures wasm.Features) *engine { +func newEngine(ctx context.Context, enabledFeatures wasm.Features) *engine { + var wazeroVersion string + if v := ctx.Value(version.WazeroVersionKey{}); v != nil { + wazeroVersion = v.(string) + } return &engine{ enabledFeatures: enabledFeatures, codes: map[wasm.ModuleID][]*code{}, setFinalizer: runtime.SetFinalizer, + Cache: compilationcache.NewFileCache(ctx), + wazeroVersion: wazeroVersion, } } diff --git a/internal/engine/compiler/engine_cache.go b/internal/engine/compiler/engine_cache.go new file mode 100644 index 0000000000..2da9b4aace --- /dev/null +++ b/internal/engine/compiler/engine_cache.go @@ -0,0 +1,188 @@ +package compiler + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/tetratelabs/wazero/internal/platform" + "github.com/tetratelabs/wazero/internal/u32" + "github.com/tetratelabs/wazero/internal/u64" + "github.com/tetratelabs/wazero/internal/wasm" +) + +func (e *engine) deleteCodes(module *wasm.Module) { + e.mux.Lock() + defer e.mux.Unlock() + delete(e.codes, module.ID) + + // Note: we do not call e.Cache.Delete, as the lifetime of + // the content is up to the implementation of extencache.Cache interface. +} + +func (e *engine) addCodes(module *wasm.Module, codes []*code) (err error) { + e.addCodesToMemory(module, codes) + err = e.addCodesToCache(module, codes) + return +} + +func (e *engine) getCodes(module *wasm.Module) (codes []*code, ok bool, err error) { + codes, ok = e.getCodesFromMemory(module) + if ok { + return + } + codes, ok, err = e.getCodesFromCache(module) + if ok { + e.addCodesToMemory(module, codes) + } + return +} + +func (e *engine) addCodesToMemory(module *wasm.Module, codes []*code) { + e.mux.Lock() + defer e.mux.Unlock() + e.codes[module.ID] = codes +} + +func (e *engine) getCodesFromMemory(module *wasm.Module) (codes []*code, ok bool) { + e.mux.RLock() + defer e.mux.RUnlock() + codes, ok = e.codes[module.ID] + return +} + +func (e *engine) addCodesToCache(module *wasm.Module, codes []*code) (err error) { + if e.Cache == nil { + return + } + err = e.Cache.Add(module.ID, serializeCodes(e.wazeroVersion, codes)) + return +} + +func (e *engine) getCodesFromCache(module *wasm.Module) (codes []*code, hit bool, err error) { + if e.Cache == nil { + return + } + + // Check if the entries exist in the external cache. + var cached io.ReadCloser + cached, hit, err = e.Cache.Get(module.ID) + if !hit || err != nil { + return + } + defer cached.Close() + + // Otherwise, we hit the cache on external cache. + // We retrieve *code structures from `cached`. + var staleCache bool + codes, staleCache, err = deserializeCodes(e.wazeroVersion, cached) + if err != nil { + hit = false + return + } else if staleCache { + return nil, false, e.Cache.Delete(module.ID) + } + + for i, c := range codes { + c.indexInModule = wasm.Index(i) + c.sourceModule = module + } + return +} + +var ( + wazeroMagic = "WAZERO" + // version must be synced with the tag of the wazero library. + +) + +func serializeCodes(wazeroVersion string, codes []*code) io.Reader { + buf := bytes.NewBuffer(nil) + // First 6 byte: WAZERO header. + buf.WriteString(wazeroMagic) + // Next 1 byte: length of version: + buf.WriteByte(byte(len(wazeroVersion))) + // Version of wazero. + buf.WriteString(wazeroVersion) + // Number of *code (== locally defined functions in the module): 4 bytes. + buf.Write(u32.LeBytes(uint32(len(codes)))) + for _, c := range codes { + // The stack pointer ceil (8 bytes). + buf.Write(u64.LeBytes(c.stackPointerCeil)) + // The length of code segment (8 bytes). + buf.Write(u64.LeBytes(uint64(len(c.codeSegment)))) + // Append the native code. + buf.Write(c.codeSegment) + } + return bytes.NewReader(buf.Bytes()) +} + +func deserializeCodes(wazeroVersion string, reader io.Reader) (codes []*code, staleCache bool, err error) { + cacheHeaderSize := len(wazeroMagic) + 1 /* version size */ + len(wazeroVersion) + 4 /* number of functions */ + + // Read the header before the native code. + header := make([]byte, cacheHeaderSize) + n, err := reader.Read(header) + if err != nil { + return nil, false, err + } + + if n != cacheHeaderSize { + return nil, false, fmt.Errorf("invalid header length: %d", n) + } + + // Check the version compatibility. + versionSize := int(header[len(wazeroMagic)]) + + cachedVersionBegin, cachedVersionEnd := len(wazeroMagic)+1, len(wazeroMagic)+1+versionSize + if cachedVersionEnd >= len(header) { + staleCache = true + return + } else if cachedVersion := string(header[cachedVersionBegin:cachedVersionEnd]); cachedVersion != wazeroVersion { + staleCache = true + return + } + + functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:]) + codes = make([]*code, 0, functionsNum) + + var eightBytes [8]byte + for i := uint32(0); i < functionsNum; i++ { + c := &code{} + + // Read the stack pointer ceil. + _, err = reader.Read(eightBytes[:]) + if err != nil { + err = fmt.Errorf("reading stack pointer ceil: %v", err) + break + } + + c.stackPointerCeil = binary.LittleEndian.Uint64(eightBytes[:]) + + // Read (and mmap) the native code. + _, err = reader.Read(eightBytes[:]) + if err != nil { + err = fmt.Errorf("reading native code size: %v", err) + break + } + + c.codeSegment, err = platform.MmapCodeSegment(reader, int(binary.LittleEndian.Uint64(eightBytes[:]))) + if err != nil { + err = fmt.Errorf("mmaping function: %v", err) + break + } + codes = append(codes, c) + } + + if err != nil { + for _, c := range codes { + if errMunmap := platform.MunmapCodeSegment(c.codeSegment); errMunmap != nil { + // Munmap failure shouldn't happen. + panic(errMunmap) + } + } + codes = nil + } + return +} diff --git a/internal/engine/compiler/engine_cache_test.go b/internal/engine/compiler/engine_cache_test.go new file mode 100644 index 0000000000..31d2abedd0 --- /dev/null +++ b/internal/engine/compiler/engine_cache_test.go @@ -0,0 +1,370 @@ +package compiler + +import ( + "bytes" + "fmt" + "io" + "testing" + + "github.com/tetratelabs/wazero/internal/testing/require" + "github.com/tetratelabs/wazero/internal/u32" + "github.com/tetratelabs/wazero/internal/u64" + "github.com/tetratelabs/wazero/internal/wasm" +) + +var testVersion string + +func concat(ins ...[]byte) (ret []byte) { + for _, in := range ins { + ret = append(ret, in...) + } + return +} + +func TestSerializeCodes(t *testing.T) { + tests := []struct { + in []*code + exp []byte + }{ + { + in: []*code{{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}}}, + exp: concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(1), // number of functions. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + ), + }, + { + in: []*code{ + {stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}}, + {stackPointerCeil: 0xffffffff, codeSegment: []byte{1, 2, 3}}, + }, + exp: concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(2), // number of functions. + // Function index = 0. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + // Function index = 1. + u64.LeBytes(0xffffffff), // stack pointer ceil. + u64.LeBytes(3), // length of code. + []byte{1, 2, 3}, // code. + ), + }, + } + + for i, tc := range tests { + actual, err := io.ReadAll(serializeCodes(testVersion, tc.in)) + require.NoError(t, err, i) + require.Equal(t, tc.exp, actual, i) + } +} + +func TestDeserializeCodes(t *testing.T) { + tests := []struct { + name string + in []byte + expCodes []*code + expStaleCache bool + expErr string + }{ + { + + name: "invalid header", + in: []byte{1}, + expErr: "invalid header length: 1", + }, + { + + name: "version mismatch", + in: concat( + []byte(wazeroMagic), + []byte{byte(len("1233123.1.1"))}, + []byte("1233123.1.1"), + u32.LeBytes(1), // number of functions. + ), + expStaleCache: true, + }, + { + + name: "version mismatch", + in: concat( + []byte(wazeroMagic), + []byte{byte(len("1"))}, + []byte("1"), + u32.LeBytes(1), // number of functions. + ), + expStaleCache: true, + }, + { + name: "one function", + in: concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(1), // number of functions. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + ), + expCodes: []*code{ + {stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}}, + }, + expStaleCache: false, + expErr: "", + }, + { + name: "two functions", + in: concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(2), // number of functions. + // Function index = 0. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + // Function index = 1. + u64.LeBytes(0xffffffff), // stack pointer ceil. + u64.LeBytes(3), // length of code. + []byte{1, 2, 3}, // code. + ), + expCodes: []*code{ + {stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}}, + {stackPointerCeil: 0xffffffff, codeSegment: []byte{1, 2, 3}}, + }, + expStaleCache: false, + expErr: "", + }, + { + name: "reading stack pointer", + in: concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(2), // number of functions. + // Function index = 0. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + // Function index = 1. + ), + expErr: "reading stack pointer ceil: EOF", + }, + { + name: "reading native code size", + in: concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(2), // number of functions. + // Function index = 0. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + // Function index = 1. + u64.LeBytes(12345), // stack pointer ceil. + ), + expErr: "reading native code size: EOF", + }, + { + name: "mmapping", + in: concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(2), // number of functions. + // Function index = 0. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + // Function index = 1. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(5), // length of code. + // Lack of code here. + ), + expErr: "mmaping function: EOF", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + codes, staleCache, err := deserializeCodes(testVersion, bytes.NewReader(tc.in)) + if tc.expErr != "" { + require.EqualError(t, err, tc.expErr) + } else { + require.NoError(t, err) + } + + require.Equal(t, tc.expCodes, codes) + require.Equal(t, tc.expStaleCache, staleCache) + }) + } +} + +func TestEngine_getCodesFromCache(t *testing.T) { + tests := []struct { + name string + ext *testCache + key wasm.ModuleID + expCodes []*code + expHit bool + expErr string + expDeleted bool + }{ + {name: "extern cache not given"}, + { + name: "not hit", + ext: &testCache{caches: map[wasm.ModuleID][]byte{}}, + }, + { + name: "error in Cache.Get", + ext: &testCache{caches: map[wasm.ModuleID][]byte{{}: {}}}, + expErr: "some error from extern cache", + }, + { + name: "error in deserialization", + ext: &testCache{caches: map[wasm.ModuleID][]byte{{}: {1, 2, 3}}}, + expErr: "invalid header length: 3", + }, + { + name: "stale cache", + ext: &testCache{caches: map[wasm.ModuleID][]byte{{}: concat( + []byte(wazeroMagic), + []byte{byte(len("1233123.1.1"))}, + []byte("1233123.1.1"), + u32.LeBytes(1), // number of functions. + )}}, + expDeleted: true, + }, + { + name: "hit", + ext: &testCache{caches: map[wasm.ModuleID][]byte{ + {}: concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(2), // number of functions. + // Function index = 0. + u64.LeBytes(12345), // stack pointer ceil. + u64.LeBytes(5), // length of code. + []byte{1, 2, 3, 4, 5}, // code. + // Function index = 1. + u64.LeBytes(0xffffffff), // stack pointer ceil. + u64.LeBytes(3), // length of code. + []byte{1, 2, 3}, // code. + ), + }}, + expHit: true, + expCodes: []*code{ + {stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}, indexInModule: 0}, + {stackPointerCeil: 0xffffffff, codeSegment: []byte{1, 2, 3}, indexInModule: 1}, + }, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + m := &wasm.Module{ID: tc.key} + for _, expC := range tc.expCodes { + expC.sourceModule = m + } + + e := engine{} + if tc.ext != nil { + e.Cache = tc.ext + } + + codes, hit, err := e.getCodesFromCache(m) + if tc.expErr != "" { + require.EqualError(t, err, tc.expErr) + } else { + require.NoError(t, err) + } + + require.Equal(t, tc.expHit, hit) + require.Equal(t, tc.expCodes, codes) + + if tc.expDeleted { + require.Equal(t, tc.ext.deleted, tc.key) + } + }) + } +} + +func TestEngine_addCodesToCache(t *testing.T) { + t.Run("not defined", func(t *testing.T) { + e := engine{} + err := e.addCodesToCache(nil, nil) + require.NoError(t, err) + }) + t.Run("add", func(t *testing.T) { + ext := &testCache{caches: map[wasm.ModuleID][]byte{}} + e := engine{Cache: ext} + m := &wasm.Module{} + codes := []*code{{stackPointerCeil: 123, codeSegment: []byte{1, 2, 3}}} + err := e.addCodesToCache(m, codes) + require.NoError(t, err) + + content, ok := ext.caches[m.ID] + require.True(t, ok) + require.Equal(t, concat( + []byte(wazeroMagic), + []byte{byte(len(testVersion))}, + []byte(testVersion), + u32.LeBytes(1), // number of functions. + u64.LeBytes(123), // stack pointer ceil. + u64.LeBytes(3), // length of code. + []byte{1, 2, 3}, // code. + ), content) + }) +} + +// testCache implements compilationcache.Cache +type testCache struct { + caches map[wasm.ModuleID][]byte + deleted wasm.ModuleID +} + +// Get implements compilationcache.Cache Get +func (tc *testCache) Get(key wasm.ModuleID) (content io.ReadCloser, ok bool, err error) { + var raw []byte + raw, ok = tc.caches[key] + if !ok { + return + } + + if len(raw) == 0 { + ok = false + err = fmt.Errorf("some error from extern cache") + return + } + + content = io.NopCloser(bytes.NewReader(raw)) + return +} + +// Add implements compilationcache.Cache Add +func (tc *testCache) Add(key wasm.ModuleID, content io.Reader) (err error) { + raw, err := io.ReadAll(content) + if err != nil { + return err + } + tc.caches[key] = raw + return +} + +// Delete implements compilationcache.Cache Delete +func (tc *testCache) Delete(key wasm.ModuleID) (err error) { + tc.deleted = key + return +} diff --git a/internal/engine/compiler/engine_test.go b/internal/engine/compiler/engine_test.go index bde9a885dd..1be1c3146c 100644 --- a/internal/engine/compiler/engine_test.go +++ b/internal/engine/compiler/engine_test.go @@ -284,29 +284,6 @@ func TestCompiler_SliceAllocatedOnHeap(t *testing.T) { } } -// TODO: move most of this logic to enginetest.go so that there is less drift between interpreter and compiler -func TestEngine_Cachedcodes(t *testing.T) { - e := newEngine(context.Background(), wasm.Features20191205) - exp := []*code{ - {codeSegment: []byte{0x0}}, - {codeSegment: []byte{0x0}}, - } - m := &wasm.Module{} - - e.addCodes(m, exp) - - actual, ok := e.getCodes(m) - require.True(t, ok) - require.Equal(t, len(exp), len(actual)) - for i := range actual { - require.Equal(t, exp[i], actual[i]) - } - - e.deleteCodes(m) - _, ok = e.getCodes(m) - require.False(t, ok) -} - func TestCallEngine_builtinFunctionTableGrow(t *testing.T) { ce := &callEngine{ valueStack: []uint64{ diff --git a/internal/engine/compiler/impl_amd64.go b/internal/engine/compiler/impl_amd64.go index bcb71a83cd..919f744df0 100644 --- a/internal/engine/compiler/impl_amd64.go +++ b/internal/engine/compiler/impl_amd64.go @@ -7,6 +7,7 @@ package compiler // e.g. MOVQ will be given as amd64.MOVQ. import ( + "bytes" "fmt" "math" "runtime" @@ -191,10 +192,7 @@ func (c *amd64Compiler) compile() (code []byte, stackPointerCeil uint64, err err return } - code, err = platform.MmapCodeSegment(code) - if err != nil { - return - } + code, err = platform.MmapCodeSegment(bytes.NewReader(code), len(code)) return } diff --git a/internal/engine/compiler/impl_arm64.go b/internal/engine/compiler/impl_arm64.go index 0a38e7d9f4..a346d8ba44 100644 --- a/internal/engine/compiler/impl_arm64.go +++ b/internal/engine/compiler/impl_arm64.go @@ -4,6 +4,7 @@ package compiler import ( + "bytes" "errors" "fmt" "math" @@ -112,10 +113,7 @@ func (c *arm64Compiler) compile() (code []byte, stackPointerCeil uint64, err err return } - code, err = platform.MmapCodeSegment(original) - if err != nil { - return - } + code, err = platform.MmapCodeSegment(bytes.NewReader(original), len(original)) return } diff --git a/internal/integration_test/bench/bench_test.go b/internal/integration_test/bench/bench_test.go index bf3634ba0a..6c3757255a 100644 --- a/internal/integration_test/bench/bench_test.go +++ b/internal/integration_test/bench/bench_test.go @@ -10,6 +10,8 @@ import ( "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/experimental" + "github.com/tetratelabs/wazero/internal/platform" "github.com/tetratelabs/wazero/wasi_snapshot_preview1" ) @@ -50,11 +52,39 @@ func BenchmarkInitialization(b *testing.B) { } } -func runInitializationBench(b *testing.B, r wazero.Runtime) { +func BenchmarkCompilation(b *testing.B) { + if !platform.CompilerSupported() { + b.Skip() + } + + // Note: recreate runtime each time in the loop to ensure that + // recompilation happens if the extern cache is not used. + b.Run("with extern cache", func(b *testing.B) { + ctx := experimental.WithCompilationCacheDirName(context.Background(), b.TempDir()) + for i := 0; i < b.N; i++ { + r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigCompiler()) + runCompilation(b, r) + } + }) + b.Run("without extern cache", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + r := wazero.NewRuntimeWithConfig(context.Background(), wazero.NewRuntimeConfigCompiler()) + runCompilation(b, r) + } + }) +} + +func runCompilation(b *testing.B, r wazero.Runtime) wazero.CompiledModule { compiled, err := r.CompileModule(testCtx, caseWasm, wazero.NewCompileConfig()) if err != nil { b.Fatal(err) } + return compiled +} + +func runInitializationBench(b *testing.B, r wazero.Runtime) { + compiled := runCompilation(b, r) defer compiled.Close(testCtx) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/internal/platform/buf_writer.go b/internal/platform/buf_writer.go new file mode 100644 index 0000000000..5d53d16788 --- /dev/null +++ b/internal/platform/buf_writer.go @@ -0,0 +1,20 @@ +package platform + +// bufWriter implements io.Writer. +// +// This is implemented because bytes.Buffer cannot write from the beginning of the underlying buffer +// without changing the memory location. In this case, the underlying buffer is memory-mapped region, +// and we have to write into that region via io.Copy since sometimes the original native code exists +// as a file for external-cached cases. +type bufWriter struct { + underlying []byte + pos int +} + +// Write implements io.Writer Write. +func (b *bufWriter) Write(p []byte) (n int, err error) { + copy(b.underlying[b.pos:], p) + n = len(p) + b.pos += n + return +} diff --git a/internal/platform/mmap.go b/internal/platform/mmap.go index 3d11e1424a..70c2cdd120 100644 --- a/internal/platform/mmap.go +++ b/internal/platform/mmap.go @@ -4,6 +4,7 @@ package platform import ( + "io" "syscall" "unsafe" ) @@ -14,11 +15,11 @@ func munmapCodeSegment(code []byte) error { // mmapCodeSegmentAMD64 gives all read-write-exec permission to the mmap region // to enter the function. Otherwise, segmentation fault exception is raised. -func mmapCodeSegmentAMD64(code []byte) ([]byte, error) { +func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) { mmapFunc, err := syscall.Mmap( -1, 0, - len(code), + size, // The region must be RWX: RW for writing native codes, X for executing the region. syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC, // Anonymous as this is not an actual file, but a memory, @@ -28,19 +29,21 @@ func mmapCodeSegmentAMD64(code []byte) ([]byte, error) { if err != nil { return nil, err } - copy(mmapFunc, code) - return mmapFunc, nil + + w := &bufWriter{underlying: mmapFunc} + _, err = io.CopyN(w, code, int64(size)) + return mmapFunc, err } // mmapCodeSegmentARM64 cannot give all read-write-exec permission to the mmap region. // Otherwise, the mmap systemcall would raise an error. Here we give read-write // to the region at first, write the native code and then change the perm to -// read-exec so we can execute the native code. -func mmapCodeSegmentARM64(code []byte) ([]byte, error) { +// read-exec, so we can execute the native code. +func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) { mmapFunc, err := syscall.Mmap( -1, 0, - len(code), + size, // The region must be RW: RW for writing native codes. syscall.PROT_READ|syscall.PROT_WRITE, // Anonymous as this is not an actual file, but a memory, @@ -51,7 +54,11 @@ func mmapCodeSegmentARM64(code []byte) ([]byte, error) { return nil, err } - copy(mmapFunc, code) + w := &bufWriter{underlying: mmapFunc} + _, err = io.CopyN(w, code, int64(size)) + if err != nil { + return nil, err + } // Then we're done with writing code, change the permission to RX. err = mprotect(mmapFunc, syscall.PROT_READ|syscall.PROT_EXEC) diff --git a/internal/platform/mmap_test.go b/internal/platform/mmap_test.go index ac7af344e0..1747748ad1 100644 --- a/internal/platform/mmap_test.go +++ b/internal/platform/mmap_test.go @@ -1,6 +1,7 @@ package platform import ( + "bytes" "crypto/rand" "io" "testing" @@ -8,22 +9,23 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) -var testCode, _ = io.ReadAll(io.LimitReader(rand.Reader, 8*1024)) +var testCodeBuf, _ = io.ReadAll(io.LimitReader(rand.Reader, 8*1024)) func Test_MmapCodeSegment(t *testing.T) { if !CompilerSupported() { t.Skip() } - newCode, err := MmapCodeSegment(testCode) + testCodeReader := bytes.NewReader(testCodeBuf) + newCode, err := MmapCodeSegment(testCodeReader, testCodeReader.Len()) require.NoError(t, err) // Verify that the mmap is the same as the original. - require.Equal(t, testCode, newCode) + require.Equal(t, testCodeBuf, newCode) // TODO: test newCode can executed. t.Run("panic on zero length", func(t *testing.T) { captured := require.CapturePanic(func() { - _, _ = MmapCodeSegment(make([]byte, 0)) + _, _ = MmapCodeSegment(bytes.NewBuffer(make([]byte, 0)), 0) }) require.EqualError(t, captured, "BUG: MmapCodeSegment with zero length") }) @@ -35,9 +37,10 @@ func Test_MunmapCodeSegment(t *testing.T) { } // Errors if never mapped - require.Error(t, MunmapCodeSegment(testCode)) + require.Error(t, MunmapCodeSegment(testCodeBuf)) - newCode, err := MmapCodeSegment(testCode) + testCodeReader := bytes.NewReader(testCodeBuf) + newCode, err := MmapCodeSegment(testCodeReader, testCodeReader.Len()) require.NoError(t, err) // First munmap should succeed. require.NoError(t, MunmapCodeSegment(newCode)) diff --git a/internal/platform/mmap_windows.go b/internal/platform/mmap_windows.go index ecb40dd7d4..1dd88dcca2 100644 --- a/internal/platform/mmap_windows.go +++ b/internal/platform/mmap_windows.go @@ -4,6 +4,7 @@ package platform import ( "fmt" + "io" "reflect" "syscall" "unsafe" @@ -30,9 +31,8 @@ func munmapCodeSegment(code []byte) error { // allocateMemory commits the memory region via the "VirtualAlloc" function. // See https://docs.microsoft.com/en-us/windows/win32/api/memoryapi/nf-memoryapi-virtualalloc -func allocateMemory(code []byte, protect uintptr) (uintptr, error) { +func allocateMemory(size uintptr, protect uintptr) (uintptr, error) { address := uintptr(0) // TODO: document why zero - size := uintptr(len(code)) alloctype := windows_MEM_COMMIT if r, _, err := procVirtualAlloc.Call(address, size, alloctype, protect); r == 0 { return 0, fmt.Errorf("compiler: VirtualAlloc error: %w", ensureErr(err)) @@ -60,8 +60,8 @@ func virtualProtect(address, size, newprotect uintptr, oldprotect *uint32) error return nil } -func mmapCodeSegmentAMD64(code []byte) ([]byte, error) { - p, err := allocateMemory(code, windows_PAGE_EXECUTE_READWRITE) +func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) { + p, err := allocateMemory(uintptr(size), windows_PAGE_EXECUTE_READWRITE) if err != nil { return nil, err } @@ -69,14 +69,16 @@ func mmapCodeSegmentAMD64(code []byte) ([]byte, error) { var mem []byte sh := (*reflect.SliceHeader)(unsafe.Pointer(&mem)) sh.Data = p - sh.Len = len(code) - sh.Cap = len(code) - copy(mem, code) - return mem, nil + sh.Len = size + sh.Cap = size + + w := &bufWriter{underlying: mem} + _, err = io.CopyN(w, code, int64(size)) + return mem, err } -func mmapCodeSegmentARM64(code []byte) ([]byte, error) { - p, err := allocateMemory(code, windows_PAGE_READWRITE) +func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) { + p, err := allocateMemory(uintptr(size), windows_PAGE_READWRITE) if err != nil { return nil, err } @@ -84,12 +86,16 @@ func mmapCodeSegmentARM64(code []byte) ([]byte, error) { var mem []byte sh := (*reflect.SliceHeader)(unsafe.Pointer(&mem)) sh.Data = p - sh.Len = len(code) - sh.Cap = len(code) - copy(mem, code) + sh.Len = size + sh.Cap = size + w := &bufWriter{underlying: mem} + _, err = io.CopyN(w, code, int64(size)) + if err != nil { + return nil, err + } old := uint32(windows_PAGE_READWRITE) - err = virtualProtect(p, uintptr(len(code)), windows_PAGE_EXECUTE_READ, &old) + err = virtualProtect(p, uintptr(size), windows_PAGE_EXECUTE_READ, &old) if err != nil { return nil, err } diff --git a/internal/platform/platform.go b/internal/platform/platform.go index 45e924aee0..03ef17f6ba 100644 --- a/internal/platform/platform.go +++ b/internal/platform/platform.go @@ -6,6 +6,7 @@ package platform import ( "errors" + "io" "runtime" ) @@ -29,14 +30,14 @@ func CompilerSupported() bool { // MmapCodeSegment copies the code into the executable region and returns the byte slice of the region. // // See https://man7.org/linux/man-pages/man2/mmap.2.html for mmap API and flags. -func MmapCodeSegment(code []byte) ([]byte, error) { - if len(code) == 0 { +func MmapCodeSegment(code io.Reader, size int) ([]byte, error) { + if size == 0 { panic(errors.New("BUG: MmapCodeSegment with zero length")) } if runtime.GOARCH == "amd64" { - return mmapCodeSegmentAMD64(code) + return mmapCodeSegmentAMD64(code, size) } else { - return mmapCodeSegmentARM64(code) + return mmapCodeSegmentARM64(code, size) } } diff --git a/internal/version/version.go b/internal/version/version.go new file mode 100644 index 0000000000..a3651feb33 --- /dev/null +++ b/internal/version/version.go @@ -0,0 +1,4 @@ +package version + +// WazeroVersionKey is the key for holding wazero's version in context.Context. +type WazeroVersionKey struct{} diff --git a/runtime.go b/runtime.go index 90138930fa..80232a9463 100644 --- a/runtime.go +++ b/runtime.go @@ -7,6 +7,7 @@ import ( "github.com/tetratelabs/wazero/api" experimentalapi "github.com/tetratelabs/wazero/experimental" + "github.com/tetratelabs/wazero/internal/version" "github.com/tetratelabs/wazero/internal/wasm" binaryformat "github.com/tetratelabs/wazero/internal/wasm/binary" ) @@ -126,6 +127,9 @@ func NewRuntime(ctx context.Context) Runtime { // NewRuntimeWithConfig returns a runtime with the given configuration. func NewRuntimeWithConfig(ctx context.Context, rConfig RuntimeConfig) Runtime { + if v := ctx.Value(version.WazeroVersionKey{}); v == nil { + ctx = context.WithValue(ctx, version.WazeroVersionKey{}, wazeroVersion) + } config := rConfig.(*runtimeConfig) store, ns := wasm.NewStore(config.enabledFeatures, config.newEngine(ctx, config.enabledFeatures)) return &runtime{ diff --git a/runtime_test.go b/runtime_test.go index 15a41b814a..22b4142461 100644 --- a/runtime_test.go +++ b/runtime_test.go @@ -9,6 +9,7 @@ import ( "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/internal/leb128" "github.com/tetratelabs/wazero/internal/testing/require" + "github.com/tetratelabs/wazero/internal/version" "github.com/tetratelabs/wazero/internal/wasm" binaryformat "github.com/tetratelabs/wazero/internal/wasm/binary" "github.com/tetratelabs/wazero/sys" @@ -20,6 +21,19 @@ var ( testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") ) +func TestNewRuntimeWithConfig_version(t *testing.T) { + cfg := NewRuntimeConfig().(*runtimeConfig) + oldNewEngine := cfg.newEngine + cfg.newEngine = func(ctx context.Context, features wasm.Features) wasm.Engine { + // Ensures that wazeroVersion is propagated to the engine. + v := ctx.Value(version.WazeroVersionKey{}) + require.NotNil(t, v) + require.Equal(t, wazeroVersion, v.(string)) + return oldNewEngine(ctx, features) + } + _ = NewRuntimeWithConfig(testCtx, cfg) +} + func TestRuntime_CompileModule(t *testing.T) { tests := []struct { name string diff --git a/version.go b/version.go new file mode 100644 index 0000000000..3c0d29a6a8 --- /dev/null +++ b/version.go @@ -0,0 +1,5 @@ +package wazero + +// wazeroVersion holds the current version of wazero. +// TODO: use debug.ReadBuildInfo automatically set wazeroVersion to the release tag. +var wazeroVersion = "dev"