diff --git a/crates/backend/src/ast.rs b/crates/backend/src/ast.rs index 544c95888c..2070fc7f38 100644 --- a/crates/backend/src/ast.rs +++ b/crates/backend/src/ast.rs @@ -27,6 +27,7 @@ pub struct NapiFn { pub enumerable: bool, pub configurable: bool, pub catch_unwind: bool, + pub unsafe_: bool, } #[derive(Debug, Clone)] @@ -64,7 +65,7 @@ pub enum FnKind { Setter, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum FnSelf { Value, Ref, diff --git a/crates/backend/src/codegen/fn.rs b/crates/backend/src/codegen/fn.rs index 43d7e274e0..00f268f0e5 100644 --- a/crates/backend/src/codegen/fn.rs +++ b/crates/backend/src/codegen/fn.rs @@ -1,9 +1,10 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::ToTokens; +use syn::spanned::Spanned; use crate::{ codegen::{get_intermediate_ident, get_register_ident, js_mod_to_token_stream}, - BindgenResult, CallbackArg, FnKind, FnSelf, NapiFn, NapiFnArgKind, TryToTokens, + BindgenResult, CallbackArg, Diagnostic, FnKind, FnSelf, NapiFn, NapiFnArgKind, TryToTokens, }; impl TryToTokens for NapiFn { @@ -12,13 +13,78 @@ impl TryToTokens for NapiFn { let intermediate_ident = get_intermediate_ident(&name_str); let args_len = self.args.len(); - let (arg_conversions, arg_names) = self.gen_arg_conversions()?; + let ArgConversions { + arg_conversions, + args: arg_names, + refs, + mut_ref_spans, + unsafe_, + } = self.gen_arg_conversions()?; + // The JS engine can't properly track mutability in an async context, so refuse to compile + // code that tries to use async and mutability together without `unsafe` mark. + if self.is_async && !mut_ref_spans.is_empty() && !unsafe_ { + return Diagnostic::from_vec( + mut_ref_spans + .into_iter() + .map(|s| Diagnostic::span_error(s, "mutable reference is unsafe with async")) + .collect(), + ); + } + if Some(FnSelf::MutRef) == self.fn_self && self.is_async { + return Err(Diagnostic::span_error( + self.name.span(), + "&mut self is incompatible with async napi methods", + )); + } + let arg_ref_count = refs.len(); let receiver = self.gen_fn_receiver(); let receiver_ret_name = Ident::new("_ret", Span::call_site()); let ret = self.gen_fn_return(&receiver_ret_name); let register = self.gen_fn_register(); let attrs = &self.attrs; + let build_ref_container = if self.is_async { + quote! { + struct NapiRefContainer([napi::sys::napi_ref; #arg_ref_count]); + impl NapiRefContainer { + fn drop(self, env: napi::sys::napi_env) { + for r in self.0.into_iter() { + assert_eq!( + unsafe { napi::sys::napi_delete_reference(env, r) }, + napi::sys::Status::napi_ok, + "failed to delete napi ref" + ); + } + } + } + unsafe impl Send for NapiRefContainer {} + unsafe impl Sync for NapiRefContainer {} + let _make_ref = |a: ::std::ptr::NonNull| { + let mut node_ref = ::std::mem::MaybeUninit::uninit(); + assert_eq!(unsafe { + napi::bindgen_prelude::sys::napi_create_reference(env, a.as_ptr(), 1, node_ref.as_mut_ptr()) + }, + napi::bindgen_prelude::sys::Status::napi_ok, + "failed to create napi ref" + ); + unsafe { node_ref.assume_init() } + }; + let mut _args_array = [::std::ptr::null_mut::(); #arg_ref_count]; + let mut _arg_write_index = 0; + + #(#refs)* + + #[cfg(debug_assert)] + { + for a in &_args_array { + assert!(!a.is_null(), "failed to initialize napi ref"); + } + } + let _args_ref = NapiRefContainer(_args_array); + } + } else { + quote! {} + }; let native_call = if !self.is_async { quote! { napi::bindgen_prelude::within_runtime_if_available(move || { @@ -35,16 +101,26 @@ impl TryToTokens for NapiFn { quote! { Ok(#receiver(#(#arg_names),*).await) } }; quote! { - napi::bindgen_prelude::execute_tokio_future(env, async move { #call }, |env, #receiver_ret_name| { + napi::bindgen_prelude::execute_tokio_future(env, async move { #call }, move |env, #receiver_ret_name| { + _args_ref.drop(env); #ret }) } }; + let function_call_inner = quote! { + napi::bindgen_prelude::CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| { + #build_ref_container + #(#arg_conversions)* + #native_call + }) + }; + let function_call = if args_len == 0 && self.fn_self.is_none() && self.kind != FnKind::Constructor && self.kind != FnKind::Factory + && !self.is_async { quote! { #native_call } } else if self.kind == FnKind::Constructor { @@ -55,18 +131,10 @@ impl TryToTokens for NapiFn { if inner.load(std::sync::atomic::Ordering::Relaxed) { return std::ptr::null_mut(); } - napi::bindgen_prelude::CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| { - #(#arg_conversions)* - #native_call - }) + #function_call_inner } } else { - quote! { - napi::bindgen_prelude::CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| { - #(#arg_conversions)* - #native_call - }) - } + function_call_inner }; let function_call = if self.catch_unwind { @@ -109,20 +177,30 @@ impl TryToTokens for NapiFn { } impl NapiFn { - fn gen_arg_conversions(&self) -> BindgenResult<(Vec, Vec)> { + fn gen_arg_conversions(&self) -> BindgenResult { let mut arg_conversions = vec![]; let mut args = vec![]; + let mut refs = vec![]; + let mut mut_ref_spans = vec![]; + let make_ref = |input| { + quote! { + _args_array[_arg_write_index] = _make_ref(::std::ptr::NonNull::new(#input).expect("ref ptr was null")); + _arg_write_index += 1; + } + }; // fetch this if let Some(parent) = &self.parent { match self.fn_self { Some(FnSelf::Ref) => { + refs.push(make_ref(quote! { cb.this })); arg_conversions.push(quote! { let this_ptr = unsafe { cb.unwrap_raw::<#parent>()? }; let this: &#parent = Box::leak(Box::from_raw(this_ptr)); }); } Some(FnSelf::MutRef) => { + refs.push(make_ref(quote! { cb.this })); arg_conversions.push(quote! { let this_ptr = unsafe { cb.unwrap_raw::<#parent>()? }; let this: &mut #parent = Box::leak(Box::from_raw(this_ptr)); @@ -215,7 +293,9 @@ impl NapiFn { }) = elem.as_ref() { if let Some(syn::PathSegment { ident, .. }) = segments.first() { + refs.push(make_ref(quote! { cb.this })); let token = if mutability.is_some() { + mut_ref_spans.push(generic_type.span()); quote! { <#ident as napi::bindgen_prelude::FromNapiMutRef>::from_napi_mut_ref(env, cb.this)? } } else { quote! { <#ident as napi::bindgen_prelude::FromNapiRef>::from_napi_ref(env, cb.this)? } @@ -228,15 +308,21 @@ impl NapiFn { } } } - args.push( - quote! { ::from_raw_unchecked(env, cb.this) }, - ); + refs.push(make_ref(quote! { cb.this })); + args.push(quote! { ::from_raw_unchecked(env, cb.this) }); skipped_arg_count += 1; continue; } } } - arg_conversions.push(self.gen_ty_arg_conversion(&ident, i, path)); + let (arg_conversion, arg_type) = self.gen_ty_arg_conversion(&ident, i, path); + if NapiArgType::MutRef == arg_type { + mut_ref_spans.push(path.ty.span()); + } + if arg_type.is_ref() { + refs.push(make_ref(quote! { cb.get_arg(#i) })); + } + arg_conversions.push(arg_conversion); args.push(quote! { #ident }); } } @@ -247,17 +333,24 @@ impl NapiFn { } } - Ok((arg_conversions, args)) + Ok(ArgConversions { + arg_conversions, + args, + refs, + mut_ref_spans, + unsafe_: self.unsafe_, + }) } + /// Returns a type conversion, and a boolean indicating whether this value needs to have a reference created to extend the lifetime + /// for async functions. fn gen_ty_arg_conversion( &self, arg_name: &Ident, index: usize, path: &syn::PatType, - ) -> TokenStream { + ) -> (TokenStream, NapiArgType) { let ty = &*path.ty; - let type_check = if self.return_if_invalid { quote! { if let Ok(maybe_promise) = <#ty as napi::bindgen_prelude::ValidateNapiValue>::validate(env, cb.get_arg(#index)) { @@ -285,28 +378,31 @@ impl NapiFn { elem, .. }) => { - quote! { + let q = quote! { let #arg_name = { #type_check <#elem as napi::bindgen_prelude::FromNapiMutRef>::from_napi_mut_ref(env, cb.get_arg(#index))? }; - } + }; + (q, NapiArgType::MutRef) } syn::Type::Reference(syn::TypeReference { elem, .. }) => { - quote! { + let q = quote! { let #arg_name = { #type_check <#elem as napi::bindgen_prelude::FromNapiRef>::from_napi_ref(env, cb.get_arg(#index))? }; - } + }; + (q, NapiArgType::Ref) } _ => { - quote! { + let q = quote! { let #arg_name = { #type_check <#ty as napi::bindgen_prelude::FromNapiValue>::from_napi_value(env, cb.get_arg(#index))? }; - } + }; + (q, NapiArgType::Value) } } } @@ -482,3 +578,24 @@ impl NapiFn { } } } + +struct ArgConversions { + pub args: Vec, + pub arg_conversions: Vec, + pub refs: Vec, + pub mut_ref_spans: Vec, + pub unsafe_: bool, +} + +#[derive(Debug, PartialEq, Eq)] +enum NapiArgType { + Ref, + MutRef, + Value, +} + +impl NapiArgType { + fn is_ref(&self) -> bool { + matches!(self, NapiArgType::Ref | NapiArgType::MutRef) + } +} diff --git a/crates/macro/src/parser/mod.rs b/crates/macro/src/parser/mod.rs index 2c068d1b6f..cb659923d2 100644 --- a/crates/macro/src/parser/mod.rs +++ b/crates/macro/src/parser/mod.rs @@ -696,6 +696,7 @@ fn napi_fn_from_decl( enumerable: opts.enumerable(), configurable: opts.configurable(), catch_unwind: opts.catch_unwind().is_some(), + unsafe_: sig.unsafety.is_some(), } }) } diff --git a/examples/napi/__test__/typegen.spec.ts.md b/examples/napi/__test__/typegen.spec.ts.md index ce2521b9ea..71a27146a4 100644 --- a/examples/napi/__test__/typegen.spec.ts.md +++ b/examples/napi/__test__/typegen.spec.ts.md @@ -247,6 +247,7 @@ Generated by [AVA](https://avajs.dev). name: string␊ constructor(name: string)␊ getCount(): number␊ + getNameAsync(): Promise␊ }␊ export type Blake2bHasher = Blake2BHasher␊ /** Smoking test for type generation */␊ diff --git a/examples/napi/__test__/typegen.spec.ts.snap b/examples/napi/__test__/typegen.spec.ts.snap index 88f7a84c7d..2bcb9a7bd4 100644 Binary files a/examples/napi/__test__/typegen.spec.ts.snap and b/examples/napi/__test__/typegen.spec.ts.snap differ diff --git a/examples/napi/__test__/values.spec.ts b/examples/napi/__test__/values.spec.ts index 6dcbedf3bf..4465ad986f 100644 --- a/examples/napi/__test__/values.spec.ts +++ b/examples/napi/__test__/values.spec.ts @@ -206,6 +206,11 @@ test('class', (t) => { }) }) +test('async self in class', async (t) => { + const b = new Bird('foo') + t.is(await b.getNameAsync(), 'foo') +}) + test('class factory', (t) => { const duck = ClassWithFactory.withName('Default') t.is(duck.name, 'Default') diff --git a/examples/napi/index.d.ts b/examples/napi/index.d.ts index 67d8b6c70b..dacaa6b9b2 100644 --- a/examples/napi/index.d.ts +++ b/examples/napi/index.d.ts @@ -237,6 +237,7 @@ export class Bird { name: string constructor(name: string) getCount(): number + getNameAsync(): Promise } export type Blake2bHasher = Blake2BHasher /** Smoking test for type generation */ diff --git a/examples/napi/src/class.rs b/examples/napi/src/class.rs index 72c3a5ff3f..29a2b4b37f 100644 --- a/examples/napi/src/class.rs +++ b/examples/napi/src/class.rs @@ -123,6 +123,12 @@ impl Bird { pub fn get_count(&self) -> u32 { 1234 } + + #[napi] + pub async fn get_name_async(&self) -> &str { + tokio::time::sleep(std::time::Duration::new(1, 0)).await; + self.name.as_str() + } } /// Smoking test for type generation