Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track original wasm instructions location #71

Merged
merged 1 commit into from Oct 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/macro/src/lib.rs
Expand Up @@ -509,7 +509,7 @@ fn create_visit(variants: &[WalrusVariant]) -> impl quote::ToTokens {

/// Visit `Instr`.
#[inline]
fn visit_instr(&mut self, instr: &'instr Instr) {
fn visit_instr(&mut self, instr: &'instr Instr, instr_loc: &'instr InstrLocId) {
// ...
}

Expand Down Expand Up @@ -593,7 +593,7 @@ fn create_visit(variants: &[WalrusVariant]) -> impl quote::ToTokens {

/// Visit `Instr`.
#[inline]
fn visit_instr_mut(&mut self, instr: &mut Instr) {
fn visit_instr_mut(&mut self, instr: &mut Instr, instr_loc: &mut InstrLocId) {
// ...
}

Expand Down
65 changes: 64 additions & 1 deletion crates/tests/tests/custom_sections.rs
@@ -1,7 +1,7 @@
//! Tests for working with custom sections that `walrus` doesn't know about.

use std::borrow::Cow;
use walrus::{CustomSection, IdsToIndices, Module, ModuleConfig};
use walrus::{CodeTransform, CustomSection, IdsToIndices, Module, ModuleConfig, ValType};

#[derive(Clone, Debug, Default, PartialEq, Eq)]
struct HelloCustomSection(String);
Expand Down Expand Up @@ -63,3 +63,66 @@ fn round_trip_unkown_custom_sections() {
let new_wasm = module.emit_wasm();
assert_eq!(wasm, new_wasm);
}

// Smoke test for code-transform tracking: a `(drop (i32.const 0))` pair is
// prepended to the function body, so every original instruction is expected
// to land 3 bytes further into the emitted code (the encoded size of that
// pair).
#[test]
fn smoke_test_code_transform() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Set to 1 once `apply_code_transform` has been invoked on the section.
    static APPLIED_CODE_TRANSFORM: AtomicUsize = AtomicUsize::new(0);

    // An empty custom section whose only purpose is to observe the code
    // transform handed to it while the module is being emitted.
    #[derive(Debug)]
    struct CheckCodeTransform;

    impl CustomSection for CheckCodeTransform {
        fn name(&self) -> &str {
            "check-code-transform"
        }

        fn data(&self, _: &IdsToIndices) -> Cow<[u8]> {
            Cow::Owned(Vec::new())
        }

        fn apply_code_transform(&mut self, transform: &CodeTransform) {
            APPLIED_CODE_TRANSFORM.store(1, Ordering::SeqCst);
            assert!(!transform.is_empty());
            // Every recorded location must have moved by exactly 3 bytes.
            for &(input_offset, output_offset) in transform.iter() {
                assert_eq!(input_offset.data() as usize + 3, output_offset);
            }
        }
    }

    let mut config = ModuleConfig::new();
    config.generate_producers_section(false);

    // Build a tiny module exporting `f`, whose body is a single
    // `i32.const 1337`, and serialize it.
    let wasm = {
        let mut module = Module::with_config(config.clone());

        let mut builder = walrus::FunctionBuilder::new(&mut module.types, &[], &[ValType::I32]);
        builder.func_body().i32_const(1337);
        let f_id = builder.finish(Vec::new(), &mut module.funcs);

        module.exports.add("f", f_id);

        module.emit_wasm()
    };

    // Re-parse the bytes, this time asking for the code transform to be
    // preserved, and attach the checking section.
    config.preserve_code_transform(true);

    let mut module = config.parse(&wasm).unwrap();
    module.customs.add(CheckCodeTransform);

    // Prepend `i32.const 0; drop` to every local function.
    for (_id, func) in module.funcs.iter_local_mut() {
        let func_builder = func.builder_mut();
        func_builder
            .func_body()
            .const_at(0, walrus::ir::Value::I32(0));
        func_builder.func_body().drop_at(1);
    }

    // Emitting the transformed module should invoke `apply_code_transform`
    // on the custom section added above.
    let _wasm = module.emit_wasm();

    let applied = APPLIED_CODE_TRANSFORM.load(Ordering::SeqCst);
    assert_eq!(applied, 1);
}
4 changes: 2 additions & 2 deletions crates/tests/tests/spec-tests.rs
Expand Up @@ -100,14 +100,14 @@ fn run(wast: &Path) -> Result<(), anyhow::Error> {
}
cmd => {
let wasm = fs::read(&path)?;
let wasm = config
let mut wasm = config
.parse(&wasm)
.context(format!("error parsing wasm (line {})", line))?;
let wasm1 = wasm.emit_wasm();
fs::write(&path, &wasm1)?;
let wasm2 = config
.parse(&wasm1)
.map(|m| m.emit_wasm())
.map(|mut m| m.emit_wasm())
.context(format!("error re-parsing wasm (line {})", line))?;
if wasm1 != wasm2 {
panic!("wasm module at line {} isn't deterministic", line);
Expand Down
8 changes: 4 additions & 4 deletions examples/round-trip.rs
Expand Up @@ -2,10 +2,10 @@

fn main() -> anyhow::Result<()> {
env_logger::init();
let a = std::env::args().nth(1).ok_or_else(|| {
anyhow::anyhow!("must provide the input wasm file as the first argument")
})?;
let m = walrus::Module::from_file(&a)?;
let a = std::env::args()
.nth(1)
.ok_or_else(|| anyhow::anyhow!("must provide the input wasm file as the first argument"))?;
let mut m = walrus::Module::from_file(&a)?;
let wasm = m.emit_wasm();
if let Some(destination) = std::env::args().nth(2) {
std::fs::write(destination, wasm)?;
Expand Down
4 changes: 2 additions & 2 deletions fuzz/fuzz_targets/raw.rs
Expand Up @@ -4,12 +4,12 @@
extern crate libfuzzer_sys;

fuzz_target!(|data: &[u8]| {
let module = match walrus::Module::from_buffer(data) {
let mut module = match walrus::Module::from_buffer(data) {
Ok(m) => m,
Err(_) => return,
};
let serialized = module.emit_wasm();
let module =
let mut module =
walrus::Module::from_buffer(&serialized).expect("we should only emit valid Wasm data");
let reserialized = module.emit_wasm();
assert_eq!(
Expand Down
4 changes: 2 additions & 2 deletions src/dot.rs
Expand Up @@ -385,13 +385,13 @@ impl Dot for LocalFunction {

impl DotNode for InstrSeq {
fn fields(&self, fields: &mut impl FieldAggregator) {
for (i, instr) in self.instrs.iter().enumerate() {
for (i, (instr, _)) in self.instrs.iter().enumerate() {
fields.add_field_with_port(&i.to_string(), &format!("{:?}", instr));
}
}

fn edges(&self, edges: &mut impl EdgeAggregator) {
for (i, instr) in self.instrs.iter().enumerate() {
for (i, (instr, _)) in self.instrs.iter().enumerate() {
let port = i.to_string();
instr.visit(&mut DotVisitor { port, edges });
}
Expand Down
3 changes: 2 additions & 1 deletion src/emit.rs
Expand Up @@ -5,8 +5,8 @@
use crate::encode::{Encoder, MAX_U32_LENGTH};
use crate::ir::Local;
use crate::map::{IdHashMap, IdHashSet};
use crate::{CodeTransform, Global, GlobalId, Memory, MemoryId, Module, Table, TableId};
use crate::{Data, DataId, Element, ElementId, Function, FunctionId};
use crate::{Global, GlobalId, Memory, MemoryId, Module, Table, TableId};
use crate::{Type, TypeId};
use std::ops::{Deref, DerefMut};

Expand All @@ -15,6 +15,7 @@ pub struct EmitContext<'a> {
pub indices: &'a mut IdsToIndices,
pub encoder: Encoder<'a>,
pub locals: IdHashMap<Function, IdHashSet<Local>>,
pub code_transform: CodeTransform,
}

pub struct SubContext<'a, 'cx> {
Expand Down
10 changes: 6 additions & 4 deletions src/function_builder.rs
Expand Up @@ -164,19 +164,21 @@ impl InstrSeqBuilder<'_> {
}

/// Get this instruction sequence's instructions.
pub fn instrs(&self) -> &[Instr] {
pub fn instrs(&self) -> &[(Instr, InstrLocId)] {
&self.builder.arena[self.id]
}

/// Get this instruction sequence's instructions mutably.
pub fn instrs_mut(&mut self) -> &mut Vec<Instr> {
pub fn instrs_mut(&mut self) -> &mut Vec<(Instr, InstrLocId)> {
&mut self.builder.arena[self.id].instrs
}

/// Pushes a new instruction onto this builder's sequence.
#[inline]
pub fn instr(&mut self, instr: impl Into<Instr>) -> &mut Self {
self.builder.arena[self.id].instrs.push(instr.into());
self.builder.arena[self.id]
.instrs
.push((instr.into(), Default::default()));
self
}

Expand All @@ -189,7 +191,7 @@ impl InstrSeqBuilder<'_> {
pub fn instr_at(&mut self, position: usize, instr: impl Into<Instr>) -> &mut Self {
self.builder.arena[self.id]
.instrs
.insert(position, instr.into());
.insert(position, (instr.into(), Default::default()));
self
}

Expand Down
40 changes: 36 additions & 4 deletions src/ir/mod.rs
Expand Up @@ -119,6 +119,38 @@ impl From<TypeId> for InstrSeqType {
}
}

/// A symbolic original wasm operator source location.
///
/// The wrapped `u32` is normally the offset of the instruction in the
/// original wasm bytecode. The value `0xffff_ffff` is reserved as the
/// "default" marker, which is what [`Default::default`] produces and what
/// [`InstrLocId::is_default`] tests for.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct InstrLocId(u32);

/// Reserved payload that marks an `InstrLocId` as the default value.
const DEFAULT_INSTR_LOC_ID: u32 = 0xffff_ffff;

impl InstrLocId {
    /// Create an `InstrLocId` from the provided data. Normally the data is
    /// the wasm bytecode offset.
    ///
    /// # Panics
    ///
    /// Panics if `data` is `0xffff_ffff`, which is reserved for the default
    /// value.
    pub fn new(data: u32) -> Self {
        assert!(
            data != DEFAULT_INSTR_LOC_ID,
            "0xffff_ffff is reserved for the default `InstrLocId`"
        );
        InstrLocId(data)
    }

    /// Returns `true` if this is the default value, i.e. it holds the
    /// reserved `0xffff_ffff` payload rather than data passed to
    /// [`InstrLocId::new`].
    pub fn is_default(&self) -> bool {
        self.0 == DEFAULT_INSTR_LOC_ID
    }

    /// Returns the data this id was created with.
    ///
    /// # Panics
    ///
    /// Panics if this is the default value; check
    /// [`InstrLocId::is_default`] first when that is possible.
    pub fn data(&self) -> u32 {
        assert!(
            !self.is_default(),
            "the default `InstrLocId` carries no data"
        );
        self.0
    }
}

impl Default for InstrLocId {
    /// Returns the default id, which holds the reserved `0xffff_ffff`
    /// payload.
    fn default() -> Self {
        InstrLocId(DEFAULT_INSTR_LOC_ID)
    }
}

/// A sequence of instructions.
#[derive(Debug)]
pub struct InstrSeq {
Expand All @@ -130,21 +162,21 @@ pub struct InstrSeq {
pub ty: InstrSeqType,

/// The instructions that make up the body of this block.
pub instrs: Vec<Instr>,
pub instrs: Vec<(Instr, InstrLocId)>,
}

impl Deref for InstrSeq {
type Target = Vec<Instr>;
type Target = Vec<(Instr, InstrLocId)>;

#[inline]
fn deref(&self) -> &Vec<Instr> {
fn deref(&self) -> &Vec<(Instr, InstrLocId)> {
&self.instrs
}
}

impl DerefMut for InstrSeq {
#[inline]
fn deref_mut(&mut self) -> &mut Vec<Instr> {
fn deref_mut(&mut self) -> &mut Vec<(Instr, InstrLocId)> {
&mut self.instrs
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/ir/traversals.rs
Expand Up @@ -78,10 +78,10 @@ pub fn dfs_in_order<'instr>(
seq.visit(visitor);
}

'traversing_instrs: for (index, instr) in seq.instrs.iter().enumerate().skip(index) {
'traversing_instrs: for (index, (instr, loc)) in seq.instrs.iter().enumerate().skip(index) {
// Visit this instruction.
log::trace!("dfs_in_order: visit_instr({:?})", instr);
visitor.visit_instr(instr);
visitor.visit_instr(instr, loc);

// Visit every other resource that this instruction references,
// e.g. `MemoryId`s, `FunctionId`s and all that.
Expand Down Expand Up @@ -192,8 +192,8 @@ pub fn dfs_pre_order_mut(
visitor.start_instr_seq_mut(seq);
seq.visit_mut(visitor);

for instr in &mut seq.instrs {
visitor.visit_instr_mut(instr);
for (instr, loc) in &mut seq.instrs {
visitor.visit_instr_mut(instr, loc);
instr.visit_mut(visitor);

match instr {
Expand Down
29 changes: 29 additions & 0 deletions src/module/config.rs
@@ -1,4 +1,5 @@
use crate::error::Result;
use crate::ir::InstrLocId;
use crate::module::Module;
use crate::parse::IndicesToIds;
use std::fmt;
Expand All @@ -13,8 +14,10 @@ pub struct ModuleConfig {
pub(crate) skip_strict_validate: bool,
pub(crate) skip_producers_section: bool,
pub(crate) skip_name_section: bool,
pub(crate) preserve_code_transform: bool,
pub(crate) on_parse:
Option<Box<dyn Fn(&mut Module, &IndicesToIds) -> Result<()> + Sync + Send + 'static>>,
pub(crate) on_instr_loc: Option<Box<dyn Fn(&usize) -> InstrLocId + Sync + Send + 'static>>,
}

impl Clone for ModuleConfig {
Expand All @@ -28,9 +31,11 @@ impl Clone for ModuleConfig {
skip_strict_validate: self.skip_strict_validate,
skip_producers_section: self.skip_producers_section,
skip_name_section: self.skip_name_section,
preserve_code_transform: self.preserve_code_transform,

// ... and this is left empty.
on_parse: None,
on_instr_loc: None,
}
}
}
Expand All @@ -46,7 +51,9 @@ impl fmt::Debug for ModuleConfig {
ref skip_strict_validate,
ref skip_producers_section,
ref skip_name_section,
ref preserve_code_transform,
ref on_parse,
ref on_instr_loc,
} = self;

f.debug_struct("ModuleConfig")
Expand All @@ -59,7 +66,9 @@ impl fmt::Debug for ModuleConfig {
.field("skip_strict_validate", skip_strict_validate)
.field("skip_producers_section", skip_producers_section)
.field("skip_name_section", skip_name_section)
.field("preserve_code_transform", preserve_code_transform)
.field("on_parse", &on_parse.as_ref().map(|_| ".."))
.field("on_instr_loc", &on_instr_loc.as_ref().map(|_| ".."))
.finish()
}
}
Expand Down Expand Up @@ -171,6 +180,26 @@ impl ModuleConfig {
self
}

/// Registers the callback that produces an `InstrLocId` from a `usize`.
///
/// The callback is boxed and stored on this configuration, replacing any
/// callback registered earlier. NOTE(review): presumably the `usize` is the
/// instruction's offset in the input wasm and the callback runs while
/// parsing; confirm against the parser.
///
/// Note that cloning a `ModuleConfig` will result in a config that does not
/// have an `on_instr_loc` function, even if the original did.
///
/// Returns `&mut ModuleConfig` so that configuration calls can be chained.
pub fn on_instr_loc<F>(&mut self, f: F) -> &mut ModuleConfig
where
    F: Fn(&usize) -> InstrLocId + Send + Sync + 'static,
{
    // Spell out the trait-object type instead of relying on `as _`.
    let callback: Box<dyn Fn(&usize) -> InstrLocId + Sync + Send + 'static> = Box::new(f);
    self.on_instr_loc = Some(callback);
    self
}

/// Sets whether the code transform is preserved during parsing.
///
/// By default this flag is `false`.
///
/// Returns `&mut ModuleConfig` so that configuration calls can be chained.
pub fn preserve_code_transform(&mut self, preserve: bool) -> &mut ModuleConfig {
self.preserve_code_transform = preserve;
self
}

/// Parses an in-memory WebAssembly file into a `Module` using this
/// configuration.
pub fn parse(&self, wasm: &[u8]) -> Result<Module> {
Expand Down