diff --git a/crates/tests/tests/custom_sections.rs b/crates/tests/tests/custom_sections.rs index 51e93236..9ff85036 100644 --- a/crates/tests/tests/custom_sections.rs +++ b/crates/tests/tests/custom_sections.rs @@ -86,11 +86,7 @@ fn smoke_test_code_transform() { APPLIED_CODE_TRANSFORM.store(1, Ordering::SeqCst); assert!(!transform.is_empty()); for (input_offset, output_offset) in transform.iter().cloned() { - println!("0x{:x} -> 0x{:x}", input_offset, output_offset); - - // TODO: this assert currently fails - // - // assert_eq!(input_offset + 3, output_offset); + assert_eq!(input_offset + 3, output_offset); } } } @@ -113,15 +109,6 @@ fn smoke_test_code_transform() { module.emit_wasm() }; - use std::fs; - use std::path::Path; - - fs::write("input.wasm", &wasm).unwrap(); - println!( - "=== input.wasm ===\n {}", - walrus_tests_utils::wasm2wat(Path::new("input.wasm")) - ); - config.preserve_code_transform(true); let mut module = config.parse(&wasm).unwrap(); @@ -138,15 +125,7 @@ fn smoke_test_code_transform() { // Emit the new, transformed wasm. This should trigger the // `apply_code_transform` method to be called. - let wasm = module.emit_wasm(); - - fs::write("output.wasm", &wasm).unwrap(); - println!( - "=== output.wasm ===\n {}", - walrus_tests_utils::wasm2wat(Path::new("output.wasm")) - ); + let _wasm = module.emit_wasm(); assert_eq!(APPLIED_CODE_TRANSFORM.load(Ordering::SeqCst), 1); - - panic!("TODO: make the commented out assertion in `apply_code_transform` pass"); } diff --git a/src/module/functions/local_function/emit.rs b/src/module/functions/local_function/emit.rs index f4bd9b7c..478c7348 100644 --- a/src/module/functions/local_function/emit.rs +++ b/src/module/functions/local_function/emit.rs @@ -61,18 +61,22 @@ impl Emit<'_, '_> { let old = self.id; self.id = id; - if let Some(map) = self.map.as_mut() { - map.insert(old, self.encoder.pos()); - } + // The match below provides position where the Expr was + // encoded. + let mut encoded_at = None; match self.func.get(id) { - Const(e) => e.value.emit(self.encoder), + Const(e) => { + encoded_at = Some(self.encoder.pos()); + e.value.emit(self.encoder); + } Block(e) => self.visit_block(e), BrTable(e) => self.visit_br_table(e), IfElse(e) => self.visit_if_else(e), Drop(e) => { self.visit(e.expr); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x1a); // drop } @@ -80,6 +84,7 @@ impl Emit<'_, '_> { for x in e.values.iter() { self.visit(*x); } + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x0f); // return } @@ -95,6 +100,7 @@ impl Emit<'_, '_> { MemorySize(e) => { let idx = self.indices.get_memory_index(e.memory); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x3f); // memory.size self.encoder.u32(idx); } @@ -102,6 +108,7 @@ impl Emit<'_, '_> { MemoryGrow(e) => { self.visit(e.pages); let idx = self.indices.get_memory_index(e.memory); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x40); // memory.grow self.encoder.u32(idx); } @@ -110,6 +117,7 @@ impl Emit<'_, '_> { self.visit(e.memory_offset); self.visit(e.data_offset); self.visit(e.len); + encoded_at = Some(self.encoder.pos()); self.encoder.raw(&[0xfc, 0x08]); // memory.init let idx = self.indices.get_data_index(e.data); self.encoder.u32(idx); @@ -119,6 +127,7 @@ impl Emit<'_, '_> { } DataDrop(e) => { + encoded_at = Some(self.encoder.pos()); self.encoder.raw(&[0xfc, 0x09]); // data.drop let idx = self.indices.get_data_index(e.data); self.encoder.u32(idx); @@ -128,6 +137,7 @@ impl Emit<'_, '_> { self.visit(e.dst_offset); self.visit(e.src_offset); self.visit(e.len); + encoded_at = Some(self.encoder.pos()); self.encoder.raw(&[0xfc, 0x0a]); // memory.copy let idx = self.indices.get_memory_index(e.src); assert_eq!(idx, 0); @@ -141,6 +151,7 @@ impl Emit<'_, '_> { self.visit(e.offset); self.visit(e.value); self.visit(e.len); + encoded_at = Some(self.encoder.pos()); self.encoder.raw(&[0xfc, 0x0b]); // memory.fill let idx = self.indices.get_memory_index(e.memory); assert_eq!(idx, 0); @@ -153,6 +164,7 @@ impl Emit<'_, '_> { self.visit(e.lhs); self.visit(e.rhs); + encoded_at = Some(self.encoder.pos()); match e.op { I32Eq => self.encoder.byte(0x46), I32Ne => self.encoder.byte(0x47), @@ -347,6 +359,8 @@ impl Emit<'_, '_> { use crate::ir::UnaryOp::*; self.visit(e.expr); + + encoded_at = Some(self.encoder.pos()); match e.op { I32Eqz => self.encoder.byte(0x45), I32Clz => self.encoder.byte(0x67), @@ -493,10 +507,12 @@ impl Emit<'_, '_> { self.visit(e.alternative); self.visit(e.consequent); self.visit(e.condition); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x1b); // select } Unreachable(_) => { + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x00); // unreachable } @@ -505,6 +521,7 @@ impl Emit<'_, '_> { self.visit(*x); } let target = self.branch_target(e.block); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x0c); // br self.encoder.u32(target); } @@ -515,6 +532,7 @@ impl Emit<'_, '_> { } self.visit(e.condition); let target = self.branch_target(e.block); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x0d); // br_if self.encoder.u32(target); } @@ -524,6 +542,7 @@ impl Emit<'_, '_> { self.visit(*x); } let idx = self.indices.get_func_index(e.func); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x10); // call self.encoder.u32(idx); } @@ -535,6 +554,7 @@ impl Emit<'_, '_> { self.visit(e.func); let idx = self.indices.get_type_index(e.ty); let table = self.indices.get_table_index(e.table); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x11); // call_indirect self.encoder.u32(idx); self.encoder.u32(table); @@ -542,6 +562,7 @@ impl Emit<'_, '_> { LocalGet(e) => { let idx = self.local_indices[&e.local]; + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x20); // local.get self.encoder.u32(idx); } @@ -549,6 +570,7 @@ impl Emit<'_, '_> { LocalSet(e) => { self.visit(e.value); let idx = self.local_indices[&e.local]; + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x21); // local.set self.encoder.u32(idx); } @@ -556,12 +578,14 @@ impl Emit<'_, '_> { LocalTee(e) => { self.visit(e.value); let idx = self.local_indices[&e.local]; + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x22); // local.tee self.encoder.u32(idx); } GlobalGet(e) => { let idx = self.indices.get_global_index(e.global); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x23); // global.get self.encoder.u32(idx); } @@ -569,6 +593,7 @@ impl Emit<'_, '_> { GlobalSet(e) => { self.visit(e.value); let idx = self.indices.get_global_index(e.global); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x24); // global.set self.encoder.u32(idx); } @@ -577,6 +602,7 @@ impl Emit<'_, '_> { use crate::ir::ExtendedLoad::*; use crate::ir::LoadKind::*; self.visit(e.address); + encoded_at = Some(self.encoder.pos()); match e.kind { I32 { atomic: false } => self.encoder.byte(0x28), // i32.load I32 { atomic: true } => self.encoder.raw(&[0xfe, 0x10]), // i32.atomic.load @@ -618,6 +644,7 @@ impl Emit<'_, '_> { use crate::ir::StoreKind::*; self.visit(e.address); self.visit(e.value); + encoded_at = Some(self.encoder.pos()); match e.kind { I32 { atomic: false } => self.encoder.byte(0x36), // i32.store I32 { atomic: true } => self.encoder.raw(&[0xfe, 0x17]), // i32.atomic.store @@ -647,6 +674,7 @@ impl Emit<'_, '_> { self.visit(e.address); self.visit(e.value); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0xfe); self.encoder.byte(match (e.op, e.width) { (Add, I32) => 0x1e, @@ -708,6 +736,7 @@ impl Emit<'_, '_> { self.visit(e.expected); self.visit(e.replacement); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0xfe); self.encoder.byte(match e.width { I32 => 0x48, @@ -726,6 +755,7 @@ impl Emit<'_, '_> { self.visit(e.address); self.visit(e.count); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0xfe); self.encoder.byte(0x00); self.memarg(e.memory, &e.arg); @@ -736,6 +766,7 @@ impl Emit<'_, '_> { self.visit(e.expected); self.visit(e.timeout); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0xfe); self.encoder.byte(if e.sixty_four { 0x02 } else { 0x01 }); self.memarg(e.memory, &e.arg); @@ -743,6 +774,7 @@ impl Emit<'_, '_> { TableGet(e) => { self.visit(e.index); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x25); let idx = self.indices.get_table_index(e.table); self.encoder.u32(idx); @@ -750,6 +782,7 @@ impl Emit<'_, '_> { TableSet(e) => { self.visit(e.index); self.visit(e.value); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0x26); let idx = self.indices.get_table_index(e.table); self.encoder.u32(idx); @@ -757,20 +790,24 @@ impl Emit<'_, '_> { TableGrow(e) => { self.visit(e.value); self.visit(e.amount); + encoded_at = Some(self.encoder.pos()); self.encoder.raw(&[0xfc, 0x0f]); let idx = self.indices.get_table_index(e.table); self.encoder.u32(idx); } TableSize(e) => { + encoded_at = Some(self.encoder.pos()); self.encoder.raw(&[0xfc, 0x10]); let idx = self.indices.get_table_index(e.table); self.encoder.u32(idx); } RefNull(_e) => { + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0xd0); } RefIsNull(e) => { self.visit(e.value); + encoded_at = Some(self.encoder.pos()); self.encoder.byte(0xd1); } @@ -778,16 +815,23 @@ impl Emit<'_, '_> { self.visit(e.v1); self.visit(e.v2); self.visit(e.mask); + encoded_at = Some(self.encoder.pos()); self.simd(0x50); } V128Shuffle(e) => { self.visit(e.lo); self.visit(e.hi); + encoded_at = Some(self.encoder.pos()); self.simd(0x03); self.encoder.raw(&e.indices); } } + if let (Some(pos), Some(map)) = (encoded_at, self.map.as_mut()) { + // Save the encoded_at position for the specified ExprId. + map.insert(id, pos); + } + self.id = old; } @@ -797,15 +841,23 @@ impl Emit<'_, '_> { ) as u32 } + fn map_encoder_pos(&mut self, id: ExprId) { + if let Some(map) = self.map.as_mut() { + map.insert(id, self.encoder.pos()); + } + } + fn visit_block(&mut self, e: &Block) { self.blocks.push(Block::new_id(self.id)); match e.kind { BlockKind::Block => { + self.map_encoder_pos(self.id); self.encoder.byte(0x02); // block self.block_type(&e.results); } BlockKind::Loop => { + self.map_encoder_pos(self.id); self.encoder.byte(0x03); // loop self.block_type(&e.results); } @@ -829,6 +881,7 @@ impl Emit<'_, '_> { fn visit_if_else(&mut self, e: &IfElse) { self.visit(e.condition); + self.map_encoder_pos(self.id); self.encoder.byte(0x04); // if let consequent = self.func.block(e.consequent); self.block_type(&consequent.results); @@ -848,6 +901,7 @@ impl Emit<'_, '_> { } self.visit(e.which); + self.map_encoder_pos(self.id); self.encoder.byte(0x0e); // br_table self.encoder.usize(e.blocks.len()); for b in e.blocks.iter() { diff --git a/src/module/functions/mod.rs b/src/module/functions/mod.rs index 0bf9dc50..e0256b83 100644 --- a/src/module/functions/mod.rs +++ b/src/module/functions/mod.rs @@ -489,8 +489,9 @@ impl Emit for ModuleFunctions { cx.indices.locals.reserve(bytes.len()); for (wasm, id, used_locals, local_indices, map) in bytes { + cx.encoder.usize(wasm.len()); let code_offset = cx.encoder.pos(); - cx.encoder.bytes(&wasm); + cx.encoder.raw(&wasm); if let Some(map) = map { append_code_offsets(&mut cx.code_transform, code_offset, map); }