From 6a3e1e7339fa2a7e87c3769b8dc940fd62b46578 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sun, 31 Oct 2021 09:16:13 +0000 Subject: [PATCH] macros: clean up protocol argument extraction a bit --- pyo3-macros-backend/src/params.rs | 37 +----- pyo3-macros-backend/src/pymethod.rs | 186 +++++++++++++--------------- pyo3-macros-backend/src/utils.rs | 22 ++++ 3 files changed, 113 insertions(+), 132 deletions(-) diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index de39f1bce78..1bbbb4727d8 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -4,7 +4,7 @@ use crate::{ attributes::FromPyWithAttribute, method::{FnArg, FnSpec}, pyfunction::Argument, - utils::unwrap_ty_group, + utils::{remove_lifetime, replace_self, unwrap_ty_group}, }; use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; @@ -267,7 +267,11 @@ fn impl_arg_param( }; return if let syn::Type::Reference(tref) = unwrap_ty_group(arg.optional.unwrap_or(ty)) { - let (tref, mut_) = preprocess_tref(tref, self_); + let mut tref = remove_lifetime(tref); + if let Some(cls) = self_ { + replace_self(&mut tref.elem, cls); + } + let mut_ = tref.mutability; let (target_ty, borrow_tmp) = if arg.optional.is_some() { // Get Option<&T> from Option> ( @@ -295,33 +299,4 @@ fn impl_arg_param( let #arg_name = #arg_value_or_default; }) }; - - /// Replace `Self`, remove lifetime and get mutability from the type - fn preprocess_tref( - tref: &syn::TypeReference, - self_: Option<&syn::Type>, - ) -> (syn::TypeReference, Option) { - let mut tref = tref.to_owned(); - if let Some(syn::Type::Path(tpath)) = self_ { - replace_self(&mut tref, &tpath.path); - } - tref.lifetime = None; - let mut_ = tref.mutability; - (tref, mut_) - } - - /// Replace `Self` with the exact type name since it is used out of the impl block - fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) { - match &mut *tref.elem { - syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path), - syn::Type::Path(tpath) => { - if let Some(ident) = tpath.path.get_ident() { - if ident == "Self" { - tpath.path = self_path.to_owned(); - } - } - } - _ => {} - } - } } diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index e9bcde2e8eb..39758a909c0 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -4,7 +4,9 @@ use std::borrow::Cow; use crate::attributes::NameAttribute; use crate::method::{CallingConvention, ExtractErrorMode}; -use crate::utils::{ensure_not_async_fn, unwrap_ty_group, PythonDoc}; +use crate::utils::{ + ensure_not_async_fn, remove_lifetime, replace_self, unwrap_ty_group, PythonDoc, +}; use crate::{deprecations::Deprecations, utils}; use crate::{ method::{FnArg, FnSpec, FnType, SelfType}, @@ -424,7 +426,7 @@ const __HASH__: SlotDef = SlotDef::new("Py_tp_hash", "hashfunc") )); const __RICHCMP__: SlotDef = SlotDef::new("Py_tp_richcompare", "richcmpfunc") .extract_error_mode(ExtractErrorMode::NotImplemented) - .arguments(&[Ty::ObjectOrNotImplemented, Ty::CompareOp]); + .arguments(&[Ty::Object, Ty::CompareOp]); const __GET__: SlotDef = SlotDef::new("Py_tp_descr_get", "descrgetfunc").arguments(&[Ty::Object, Ty::Object]); const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc"); @@ -452,55 +454,55 @@ const __FLOAT__: SlotDef = SlotDef::new("Py_nb_float", "unaryfunc"); const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int); const __IADD__: SlotDef = SlotDef::new("Py_nb_inplace_add", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __ISUB__: SlotDef = SlotDef::new("Py_nb_inplace_subtract", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IMUL__: SlotDef = SlotDef::new("Py_nb_inplace_multiply", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IMATMUL__: SlotDef = SlotDef::new("Py_nb_inplace_matrix_multiply", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __ITRUEDIV__: SlotDef = SlotDef::new("Py_nb_inplace_true_divide", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IFLOORDIV__: SlotDef = SlotDef::new("Py_nb_inplace_floor_divide", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IMOD__: SlotDef = SlotDef::new("Py_nb_inplace_remainder", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IPOW__: SlotDef = SlotDef::new("Py_nb_inplace_power", "ternaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object, Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __ILSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_lshift", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IRSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_rshift", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IAND__: SlotDef = SlotDef::new("Py_nb_inplace_and", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IXOR__: SlotDef = SlotDef::new("Py_nb_inplace_xor", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IOR__: SlotDef = SlotDef::new("Py_nb_inplace_or", "binaryfunc") - .arguments(&[Ty::ObjectOrNotImplemented]) + .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); @@ -548,7 +550,6 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> { #[derive(Clone, Copy)] enum Ty { Object, - ObjectOrNotImplemented, NonNullObject, CompareOp, Int, @@ -560,7 +561,7 @@ enum Ty { impl Ty { fn ffi_type(self) -> TokenStream { match self { - Ty::Object | Ty::ObjectOrNotImplemented => quote! { *mut ::pyo3::ffi::PyObject }, + Ty::Object => quote! { *mut ::pyo3::ffi::PyObject }, Ty::NonNullObject => quote! { ::std::ptr::NonNull<::pyo3::ffi::PyObject> }, Ty::Int | Ty::CompareOp => quote! { ::std::os::raw::c_int }, Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t }, @@ -574,95 +575,82 @@ impl Ty { cls: &syn::Type, py: &syn::Ident, ident: &syn::Ident, - target: &syn::Type, + arg: &FnArg, + extract_error_mode: ExtractErrorMode, ) -> TokenStream { match self { Ty::Object => { - let extract = extract_from_any(cls, target, ident); - quote! { - let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident); - #extract - } - } - Ty::ObjectOrNotImplemented => { - let extract = if let syn::Type::Reference(tref) = unwrap_ty_group(target) { - let (tref, mut_) = preprocess_tref(tref, cls); + let extract = handle_error( + extract_error_mode, + py, quote! { - let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = match #ident.extract() { - ::std::result::Result::Ok(#ident) => #ident, - ::std::result::Result::Err(_) => return ::pyo3::callback::convert(#py, #py.NotImplemented()), - }; - let #ident = &#mut_ *#ident; - } - } else { - quote! { - let #ident = match #ident.extract() { - ::std::result::Result::Ok(#ident) => #ident, - ::std::result::Result::Err(_) => return ::pyo3::callback::convert(#py, #py.NotImplemented()), - }; - } - }; - quote! { - let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident); - #extract - } + #py.from_borrowed_ptr::<::pyo3::PyAny>(#ident).extract() + }, + ); + extract_object(cls, arg.ty, ident, extract) } Ty::NonNullObject => { - let extract = extract_from_any(cls, target, ident); + let extract = handle_error( + extract_error_mode, + py, + quote! { + #py.from_borrowed_ptr::<::pyo3::PyAny>(#ident.as_ptr()).extract() + }, + ); + extract_object(cls, arg.ty, ident, extract) + } + Ty::CompareOp => { + let extract = handle_error( + extract_error_mode, + py, + quote! { + ::pyo3::class::basic::CompareOp::from_raw(#ident) + .ok_or_else(|| ::pyo3::exceptions::PyValueError::new_err("invalid comparison operator")) + }, + ); quote! { - let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident.as_ptr()); - #extract + let #ident = #extract; } } - Ty::CompareOp => quote! { - let #ident = ::pyo3::class::basic::CompareOp::from_raw(#ident) - .ok_or_else(|| ::pyo3::exceptions::PyValueError::new_err("invalid comparison operator"))?; - }, Ty::Int | Ty::PyHashT | Ty::PySsizeT | Ty::Void => todo!(), } } } -fn extract_from_any(self_: &syn::Type, target: &syn::Type, ident: &syn::Ident) -> TokenStream { - return if let syn::Type::Reference(tref) = unwrap_ty_group(target) { - let (tref, mut_) = preprocess_tref(tref, self_); +fn handle_error( + extract_error_mode: ExtractErrorMode, + py: &syn::Ident, + extract: TokenStream, +) -> TokenStream { + match extract_error_mode { + ExtractErrorMode::Raise => quote! { #extract? }, + ExtractErrorMode::NotImplemented => quote! { + match #extract { + ::std::result::Result::Ok(value) => value, + ::std::result::Result::Err(_) => { return ::pyo3::callback::convert(#py, #py.NotImplemented()); }, + } + }, + } +} + +fn extract_object( + cls: &syn::Type, + target: &syn::Type, + ident: &syn::Ident, + extract: TokenStream, +) -> TokenStream { + if let syn::Type::Reference(tref) = unwrap_ty_group(target) { + let mut tref = remove_lifetime(tref); + replace_self(&mut tref.elem, cls); + let mut_ = tref.mutability; quote! { - let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = #ident.extract()?; + let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = #extract; let #ident = &#mut_ *#ident; } } else { quote! { - let #ident = #ident.extract()?; - } - }; -} - -/// Replace `Self`, remove lifetime and get mutability from the type -fn preprocess_tref( - tref: &syn::TypeReference, - self_: &syn::Type, -) -> (syn::TypeReference, Option) { - let mut tref = tref.to_owned(); - if let syn::Type::Path(tpath) = self_ { - replace_self(&mut tref, &tpath.path); - } - tref.lifetime = None; - let mut_ = tref.mutability; - (tref, mut_) -} - -/// Replace `Self` with the exact type name since it is used out of the impl block -fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) { - match &mut *tref.elem { - syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path), - syn::Type::Path(tpath) => { - if let Some(ident) = tpath.path.get_ident() { - if ident == "Self" { - tpath.path = self_path.to_owned(); - } - } + let #ident = #extract; } - _ => {} } } @@ -800,7 +788,8 @@ fn generate_method_body( ) -> Result { let self_conversion = spec.tp.self_conversion(Some(cls), extract_error_mode); let rust_name = spec.name; - let (arg_idents, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments)?; + let (arg_idents, conversions) = + extract_proto_arguments(cls, py, &spec.args, arguments, extract_error_mode)?; let call = quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) }; let body = if let Some(return_mode) = return_mode { return_mode.return_call_output(py, call) @@ -883,7 +872,7 @@ const __DELITEM__: SlotFragmentDef = SlotFragmentDef::new("__delitem__", &[Ty::O macro_rules! binary_num_slot_fragment_def { ($ident:ident, $name:literal) => { - const $ident: SlotFragmentDef = SlotFragmentDef::new($name, &[Ty::ObjectOrNotImplemented]) + const $ident: SlotFragmentDef = SlotFragmentDef::new($name, &[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .ret_ty(Ty::Object); }; @@ -916,18 +905,12 @@ binary_num_slot_fragment_def!(__RXOR__, "__rxor__"); binary_num_slot_fragment_def!(__OR__, "__or__"); binary_num_slot_fragment_def!(__ROR__, "__ror__"); -const __POW__: SlotFragmentDef = SlotFragmentDef::new( - "__pow__", - &[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented], -) -.extract_error_mode(ExtractErrorMode::NotImplemented) -.ret_ty(Ty::Object); -const __RPOW__: SlotFragmentDef = SlotFragmentDef::new( - "__rpow__", - &[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented], -) -.extract_error_mode(ExtractErrorMode::NotImplemented) -.ret_ty(Ty::Object); +const __POW__: SlotFragmentDef = SlotFragmentDef::new("__pow__", &[Ty::Object, Ty::Object]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .ret_ty(Ty::Object); +const __RPOW__: SlotFragmentDef = SlotFragmentDef::new("__rpow__", &[Ty::Object, Ty::Object]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .ret_ty(Ty::Object); fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> { match method_name { @@ -974,6 +957,7 @@ fn extract_proto_arguments( py: &syn::Ident, method_args: &[FnArg], proto_args: &[Ty], + extract_error_mode: ExtractErrorMode, ) -> Result<(Vec, TokenStream)> { let mut arg_idents = Vec::with_capacity(method_args.len()); let mut non_python_args = 0; @@ -987,7 +971,7 @@ fn extract_proto_arguments( let ident = syn::Ident::new(&format!("arg{}", non_python_args), Span::call_site()); let conversions = proto_args.get(non_python_args) .ok_or_else(|| err_spanned!(arg.ty.span() => format!("Expected at most {} non-python arguments", proto_args.len())))? - .extract(cls, py, &ident, arg.ty); + .extract(cls, py, &ident, arg, extract_error_mode); non_python_args += 1; args_conversions.push(conversions); arg_idents.push(ident); diff --git a/pyo3-macros-backend/src/utils.rs b/pyo3-macros-backend/src/utils.rs index 33ba2c04f0f..b0353cdc16a 100644 --- a/pyo3-macros-backend/src/utils.rs +++ b/pyo3-macros-backend/src/utils.rs @@ -173,3 +173,25 @@ pub fn unwrap_ty_group(mut ty: &syn::Type) -> &syn::Type { } ty } + +/// Remove lifetime from reference +pub(crate) fn remove_lifetime(tref: &syn::TypeReference) -> syn::TypeReference { + let mut tref = tref.to_owned(); + tref.lifetime = None; + tref +} + +/// Replace `Self` keyword in type with `cls` +pub(crate) fn replace_self(ty: &mut syn::Type, cls: &syn::Type) { + match ty { + syn::Type::Reference(tref) => replace_self(&mut tref.elem, cls), + syn::Type::Path(tpath) => { + if let Some(ident) = tpath.path.get_ident() { + if ident == "Self" { + *ty = cls.to_owned(); + } + } + } + _ => {} + } +}