Skip to content

Commit

Permalink
fix(napi-derive): unsound behavior while using reference and async to…
Browse files Browse the repository at this point in the history
…gether
  • Loading branch information
Xaeroxe committed Nov 21, 2022
1 parent 9189045 commit 618d0f8
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 28 deletions.
3 changes: 2 additions & 1 deletion crates/backend/src/ast.rs
Expand Up @@ -27,6 +27,7 @@ pub struct NapiFn {
pub enumerable: bool,
pub configurable: bool,
pub catch_unwind: bool,
pub unsafe_: bool,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -64,7 +65,7 @@ pub enum FnKind {
Setter,
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FnSelf {
Value,
Ref,
Expand Down
171 changes: 144 additions & 27 deletions crates/backend/src/codegen/fn.rs
@@ -1,9 +1,10 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::ToTokens;
use syn::spanned::Spanned;

use crate::{
codegen::{get_intermediate_ident, get_register_ident, js_mod_to_token_stream},
BindgenResult, CallbackArg, FnKind, FnSelf, NapiFn, NapiFnArgKind, TryToTokens,
BindgenResult, CallbackArg, Diagnostic, FnKind, FnSelf, NapiFn, NapiFnArgKind, TryToTokens,
};

impl TryToTokens for NapiFn {
Expand All @@ -12,13 +13,78 @@ impl TryToTokens for NapiFn {
let intermediate_ident = get_intermediate_ident(&name_str);
let args_len = self.args.len();

let (arg_conversions, arg_names) = self.gen_arg_conversions()?;
let ArgConversions {
arg_conversions,
args: arg_names,
refs,
mut_ref_spans,
unsafe_,
} = self.gen_arg_conversions()?;
// The JS engine can't properly track mutability in an async context, so refuse to compile
// code that tries to use async and mutability together without `unsafe` mark.
if self.is_async && !mut_ref_spans.is_empty() && !unsafe_ {
return Diagnostic::from_vec(
mut_ref_spans
.into_iter()
.map(|s| Diagnostic::span_error(s, "mutable reference is unsafe with async"))
.collect(),
);
}
if Some(FnSelf::MutRef) == self.fn_self && self.is_async {
return Err(Diagnostic::span_error(
self.name.span(),
"&mut self is incompatible with async napi methods",
));
}
let arg_ref_count = refs.len();
let receiver = self.gen_fn_receiver();
let receiver_ret_name = Ident::new("_ret", Span::call_site());
let ret = self.gen_fn_return(&receiver_ret_name);
let register = self.gen_fn_register();
let attrs = &self.attrs;

let build_ref_container = if self.is_async {
quote! {
struct NapiRefContainer([napi::sys::napi_ref; #arg_ref_count]);
impl NapiRefContainer {
fn drop(self, env: napi::sys::napi_env) {
for r in self.0.into_iter() {
assert_eq!(
unsafe { napi::sys::napi_delete_reference(env, r) },
napi::sys::Status::napi_ok,
"failed to delete napi ref"
);
}
}
}
unsafe impl Send for NapiRefContainer {}
unsafe impl Sync for NapiRefContainer {}
let _make_ref = |a: ::std::ptr::NonNull<napi::bindgen_prelude::sys::napi_value__>| {
let mut node_ref = ::std::mem::MaybeUninit::uninit();
assert_eq!(unsafe {
napi::bindgen_prelude::sys::napi_create_reference(env, a.as_ptr(), 1, node_ref.as_mut_ptr())
},
napi::bindgen_prelude::sys::Status::napi_ok,
"failed to create napi ref"
);
unsafe { node_ref.assume_init() }
};
let mut _args_array = [::std::ptr::null_mut::<napi::bindgen_prelude::sys::napi_ref__>(); #arg_ref_count];
let mut _arg_write_index = 0;

#(#refs)*

#[cfg(debug_assert)]
{
for a in &_args_array {
assert!(!a.is_null(), "failed to initialize napi ref");
}
}
let _args_ref = NapiRefContainer(_args_array);
}
} else {
quote! {}
};
let native_call = if !self.is_async {
quote! {
napi::bindgen_prelude::within_runtime_if_available(move || {
Expand All @@ -35,16 +101,26 @@ impl TryToTokens for NapiFn {
quote! { Ok(#receiver(#(#arg_names),*).await) }
};
quote! {
napi::bindgen_prelude::execute_tokio_future(env, async move { #call }, |env, #receiver_ret_name| {
napi::bindgen_prelude::execute_tokio_future(env, async move { #call }, move |env, #receiver_ret_name| {
_args_ref.drop(env);
#ret
})
}
};

let function_call_inner = quote! {
napi::bindgen_prelude::CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| {
#build_ref_container
#(#arg_conversions)*
#native_call
})
};

let function_call = if args_len == 0
&& self.fn_self.is_none()
&& self.kind != FnKind::Constructor
&& self.kind != FnKind::Factory
&& !self.is_async
{
quote! { #native_call }
} else if self.kind == FnKind::Constructor {
Expand All @@ -55,18 +131,10 @@ impl TryToTokens for NapiFn {
if inner.load(std::sync::atomic::Ordering::Relaxed) {
return std::ptr::null_mut();
}
napi::bindgen_prelude::CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| {
#(#arg_conversions)*
#native_call
})
#function_call_inner
}
} else {
quote! {
napi::bindgen_prelude::CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| {
#(#arg_conversions)*
#native_call
})
}
function_call_inner
};

let function_call = if self.catch_unwind {
Expand Down Expand Up @@ -109,20 +177,30 @@ impl TryToTokens for NapiFn {
}

impl NapiFn {
fn gen_arg_conversions(&self) -> BindgenResult<(Vec<TokenStream>, Vec<TokenStream>)> {
fn gen_arg_conversions(&self) -> BindgenResult<ArgConversions> {
let mut arg_conversions = vec![];
let mut args = vec![];
let mut refs = vec![];
let mut mut_ref_spans = vec![];
let make_ref = |input| {
quote! {
_args_array[_arg_write_index] = _make_ref(::std::ptr::NonNull::new(#input).expect("ref ptr was null"));
_arg_write_index += 1;
}
};

// fetch this
if let Some(parent) = &self.parent {
match self.fn_self {
Some(FnSelf::Ref) => {
refs.push(make_ref(quote! { cb.this }));
arg_conversions.push(quote! {
let this_ptr = unsafe { cb.unwrap_raw::<#parent>()? };
let this: &#parent = Box::leak(Box::from_raw(this_ptr));
});
}
Some(FnSelf::MutRef) => {
refs.push(make_ref(quote! { cb.this }));
arg_conversions.push(quote! {
let this_ptr = unsafe { cb.unwrap_raw::<#parent>()? };
let this: &mut #parent = Box::leak(Box::from_raw(this_ptr));
Expand Down Expand Up @@ -215,7 +293,9 @@ impl NapiFn {
}) = elem.as_ref()
{
if let Some(syn::PathSegment { ident, .. }) = segments.first() {
refs.push(make_ref(quote! { cb.this }));
let token = if mutability.is_some() {
mut_ref_spans.push(generic_type.span());
quote! { <#ident as napi::bindgen_prelude::FromNapiMutRef>::from_napi_mut_ref(env, cb.this)? }
} else {
quote! { <#ident as napi::bindgen_prelude::FromNapiRef>::from_napi_ref(env, cb.this)? }
Expand All @@ -228,15 +308,21 @@ impl NapiFn {
}
}
}
args.push(
quote! { <napi::bindgen_prelude::This as napi::NapiValue>::from_raw_unchecked(env, cb.this) },
);
refs.push(make_ref(quote! { cb.this }));
args.push(quote! { <napi::bindgen_prelude::This as napi::NapiValue>::from_raw_unchecked(env, cb.this) });
skipped_arg_count += 1;
continue;
}
}
}
arg_conversions.push(self.gen_ty_arg_conversion(&ident, i, path));
let (arg_conversion, arg_type) = self.gen_ty_arg_conversion(&ident, i, path);
if NapiArgType::MutRef == arg_type {
mut_ref_spans.push(path.ty.span());
}
if arg_type.is_ref() {
refs.push(make_ref(quote! { cb.get_arg(#i) }));
}
arg_conversions.push(arg_conversion);
args.push(quote! { #ident });
}
}
Expand All @@ -247,17 +333,24 @@ impl NapiFn {
}
}

Ok((arg_conversions, args))
Ok(ArgConversions {
arg_conversions,
args,
refs,
mut_ref_spans,
unsafe_: self.unsafe_,
})
}

/// Returns a type conversion, and a boolean indicating whether this value needs to have a reference created to extend the lifetime
/// for async functions.
fn gen_ty_arg_conversion(
&self,
arg_name: &Ident,
index: usize,
path: &syn::PatType,
) -> TokenStream {
) -> (TokenStream, NapiArgType) {
let ty = &*path.ty;

let type_check = if self.return_if_invalid {
quote! {
if let Ok(maybe_promise) = <#ty as napi::bindgen_prelude::ValidateNapiValue>::validate(env, cb.get_arg(#index)) {
Expand Down Expand Up @@ -285,28 +378,31 @@ impl NapiFn {
elem,
..
}) => {
quote! {
let q = quote! {
let #arg_name = {
#type_check
<#elem as napi::bindgen_prelude::FromNapiMutRef>::from_napi_mut_ref(env, cb.get_arg(#index))?
};
}
};
(q, NapiArgType::MutRef)
}
syn::Type::Reference(syn::TypeReference { elem, .. }) => {
quote! {
let q = quote! {
let #arg_name = {
#type_check
<#elem as napi::bindgen_prelude::FromNapiRef>::from_napi_ref(env, cb.get_arg(#index))?
};
}
};
(q, NapiArgType::Ref)
}
_ => {
quote! {
let q = quote! {
let #arg_name = {
#type_check
<#ty as napi::bindgen_prelude::FromNapiValue>::from_napi_value(env, cb.get_arg(#index))?
};
}
};
(q, NapiArgType::Value)
}
}
}
Expand Down Expand Up @@ -482,3 +578,24 @@ impl NapiFn {
}
}
}

struct ArgConversions {
pub args: Vec<TokenStream>,
pub arg_conversions: Vec<TokenStream>,
pub refs: Vec<TokenStream>,
pub mut_ref_spans: Vec<Span>,
pub unsafe_: bool,
}

#[derive(Debug, PartialEq, Eq)]
enum NapiArgType {
Ref,
MutRef,
Value,
}

impl NapiArgType {
fn is_ref(&self) -> bool {
matches!(self, NapiArgType::Ref | NapiArgType::MutRef)
}
}
1 change: 1 addition & 0 deletions crates/macro/src/parser/mod.rs
Expand Up @@ -696,6 +696,7 @@ fn napi_fn_from_decl(
enumerable: opts.enumerable(),
configurable: opts.configurable(),
catch_unwind: opts.catch_unwind().is_some(),
unsafe_: sig.unsafety.is_some(),
}
})
}
Expand Down
1 change: 1 addition & 0 deletions examples/napi/__test__/typegen.spec.ts.md
Expand Up @@ -247,6 +247,7 @@ Generated by [AVA](https://avajs.dev).
name: string␊
constructor(name: string)␊
getCount(): number␊
getNameAsync(): Promise<string>␊
}␊
export type Blake2bHasher = Blake2BHasher␊
/** Smoking test for type generation */␊
Expand Down
Binary file modified examples/napi/__test__/typegen.spec.ts.snap
Binary file not shown.
5 changes: 5 additions & 0 deletions examples/napi/__test__/values.spec.ts
Expand Up @@ -206,6 +206,11 @@ test('class', (t) => {
})
})

test('async self in class', async (t) => {
const b = new Bird('foo')
t.is(await b.getNameAsync(), 'foo')
})

test('class factory', (t) => {
const duck = ClassWithFactory.withName('Default')
t.is(duck.name, 'Default')
Expand Down
1 change: 1 addition & 0 deletions examples/napi/index.d.ts
Expand Up @@ -237,6 +237,7 @@ export class Bird {
name: string
constructor(name: string)
getCount(): number
getNameAsync(): Promise<string>
}
export type Blake2bHasher = Blake2BHasher
/** Smoking test for type generation */
Expand Down
6 changes: 6 additions & 0 deletions examples/napi/src/class.rs
Expand Up @@ -123,6 +123,12 @@ impl Bird {
pub fn get_count(&self) -> u32 {
1234
}

#[napi]
pub async fn get_name_async(&self) -> &str {
tokio::time::sleep(std::time::Duration::new(1, 0)).await;
self.name.as_str()
}
}

/// Smoking test for type generation
Expand Down

0 comments on commit 618d0f8

Please sign in to comment.