diff --git a/serial_test_derive/src/lib.rs b/serial_test_derive/src/lib.rs index 47dfcdd..c8773e2 100644 --- a/serial_test_derive/src/lib.rs +++ b/serial_test_derive/src/lib.rs @@ -377,6 +377,7 @@ where if asyncness.is_some() && cfg!(not(feature = "async")) { panic!("async testing attempted with async feature disabled in serial_test!"); } + let vis = ast.vis; let name = ast.sig.ident; let return_type = match ast.sig.output { syn::ReturnType::Default => None, @@ -414,7 +415,7 @@ where quote! { #(#attrs) * - async fn #name () -> #ret { + #vis async fn #name () -> #ret { serial_test::#fnname(#(#args ),*, || async #block ).await; } } @@ -424,7 +425,7 @@ where quote! { #(#attrs) * - fn #name () -> #ret { + #vis fn #name () -> #ret { serial_test::#fnname(#(#args ),*, || #block ) } } @@ -437,7 +438,7 @@ where quote! { #(#attrs) * - async fn #name () { + #vis async fn #name () { serial_test::#fnname(#(#args ),*, || async #block ).await; } } @@ -447,7 +448,7 @@ where quote! { #(#attrs) * - fn #name () { + #vis fn #name () { serial_test::#fnname(#(#args ),*, || #block ); } } @@ -502,6 +503,23 @@ mod tests { assert_eq!(format!("{}", compare), format!("{}", stream)); } + #[test] + fn test_serial_with_pub() { + let attrs = proc_macro2::TokenStream::new(); + let input = quote! { + #[test] + pub fn foo() {} + }; + let stream = local_serial_core(attrs.into(), input); + let compare = quote! { + #[test] + pub fn foo () { + serial_test::local_serial_core("", :: std :: option :: Option :: None, || {} ); + } + }; + assert_eq!(format!("{}", compare), format!("{}", stream)); + } + #[test] fn test_serial_with_timeout() { let attrs = vec![