From e140b729fc1eaa60b5aa5c8356e04b5a7cc3e624 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 23 Jun 2020 15:19:20 +0100 Subject: [PATCH 1/2] Allow #[getter] and #[setter] functions to take PyRef --- CHANGELOG.md | 1 + pyo3-derive-backend/src/method.rs | 362 ++++++++++++---------- pyo3-derive-backend/src/module.rs | 11 +- pyo3-derive-backend/src/pyclass.rs | 16 +- pyo3-derive-backend/src/pymethod.rs | 89 ++---- pyo3-derive-backend/src/pyproto.rs | 13 +- pyo3-derive-backend/src/utils.rs | 13 - tests/test_compile_error.rs | 2 + tests/test_getter_setter.rs | 30 ++ tests/ui/invalid_pymethod_names.stderr | 4 +- tests/ui/invalid_pymethod_receiver.rs | 11 + tests/ui/invalid_pymethod_receiver.stderr | 14 + 12 files changed, 321 insertions(+), 245 deletions(-) create mode 100644 tests/ui/invalid_pymethod_receiver.rs create mode 100644 tests/ui/invalid_pymethod_receiver.stderr diff --git a/CHANGELOG.md b/CHANGELOG.md index 44ca236c613..b9521b3af1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Fix passing explicit `None` to `Option` argument `#[pyfunction]` with a default value. [#936](https://github.com/PyO3/pyo3/pull/936) - Fix `PyClass.__new__`'s not respecting subclasses when inherited by a Python class. [#990](https://github.com/PyO3/pyo3/pull/990) - Fix returning `Option` from `#[pyproto]` methods. [#996](https://github.com/PyO3/pyo3/pull/996) +- Fix accepting `PyRef` and `PyRefMut` to `#[getter]` and `#[setter]` methods. [#999](https://github.com/PyO3/pyo3/pull/999) ## [0.10.1] - 2020-05-14 ### Fixed diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index 5cb9606efe3..4c659b60581 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -4,8 +4,8 @@ use crate::pyfunction::Argument; use crate::pyfunction::{parse_name_attribute, PyFunctionAttr}; use crate::utils; use proc_macro2::TokenStream; -use quote::quote; use quote::ToTokens; +use quote::{quote, quote_spanned}; use syn::ext::IdentExt; use syn::spanned::Spanned; @@ -20,26 +20,72 @@ pub struct FnArg<'a> { pub reference: bool, } -#[derive(Clone, PartialEq, Debug)] -pub enum FnType { +#[derive(Clone, PartialEq, Debug, Copy, Eq)] +pub enum MethodTypeAttribute { + /// #[new] + New, + /// #[call] + Call, + /// #[classmethod] + ClassMethod, + /// #[classattr] + ClassAttribute, + /// #[staticmethod] + StaticMethod, + /// #[getter] Getter, + /// #[setter] Setter, - Fn, +} + +#[derive(Clone, PartialEq, Debug)] +pub enum FnType { + Getter(SelfType), + Setter(SelfType), + Fn(SelfType), + FnCall(SelfType), FnNew, - FnCall, FnClass, FnStatic, ClassAttribute, - /// For methods taht have `self_: &PyCell` instead of self receiver - PySelfRef(syn::TypeReference), - /// For methods taht have `self_: PyRef` or `PyRefMut` instead of self receiver - PySelfPath(syn::TypePath), +} + +#[derive(Clone, PartialEq, Debug)] +pub enum SelfType { + Receiver { mutable: bool }, + TryFromPyCell(syn::Type), +} + +impl SelfType { + pub fn receiver(&self, cls: &syn::Type) -> TokenStream { + match self { + SelfType::Receiver { mutable: false } => { + quote! { + let _cell = _py.from_borrowed_ptr::>(_slf); + let _ref = _cell.try_borrow()?; + let _slf = &_ref; + } + } + SelfType::Receiver { mutable: true } => { + quote! { + let _cell = _py.from_borrowed_ptr::>(_slf); + let mut _ref = _cell.try_borrow_mut()?; + let _slf = &mut _ref; + } + } + SelfType::TryFromPyCell(ty) => { + quote_spanned! { ty.span() => + let _cell = _py.from_borrowed_ptr::>(_slf); + let _slf = std::convert::TryFrom::try_from(_cell)?; + } + } + } + } } #[derive(Clone, PartialEq, Debug)] pub struct FnSpec<'a> { pub tp: FnType, - pub self_: Option, // Rust function name pub name: &'a syn::Ident, // Wrapped python name. This should not have any leading r#. @@ -58,15 +104,18 @@ pub fn get_return_info(output: &syn::ReturnType) -> syn::Type { } } -impl<'a> FnSpec<'a> { - /// Generate the code for borrowing self - pub(crate) fn borrow_self(&self) -> TokenStream { - let is_mut = self - .self_ - .expect("impl_borrow_self is called for non-self fn"); - crate::utils::borrow_self(is_mut) +pub fn parse_method_receiver(arg: &syn::FnArg) -> syn::Result { + match arg { + syn::FnArg::Receiver(recv) => Ok(SelfType::Receiver { + mutable: recv.mutability.is_some(), + }), + syn::FnArg::Typed(syn::PatType { ref ty, .. }) => { + Ok(SelfType::TryFromPyCell(ty.as_ref().clone())) + } } +} +impl<'a> FnSpec<'a> { /// Parser function signature and function attributes pub fn parse( sig: &'a syn::Signature, @@ -75,27 +124,107 @@ impl<'a> FnSpec<'a> { ) -> syn::Result> { let name = &sig.ident; let MethodAttributes { - ty: mut fn_type, + ty: fn_type_attr, args: fn_attrs, mut python_name, } = parse_method_attributes(meth_attrs, allow_custom_name)?; - let mut self_ = None; let mut arguments = Vec::new(); - for input in sig.inputs.iter() { + let mut inputs_iter = sig.inputs.iter(); + + // Parse receiver & function type for various method types + let fn_type = match fn_type_attr { + Some(MethodTypeAttribute::StaticMethod) => FnType::FnStatic, + Some(MethodTypeAttribute::ClassAttribute) => { + if !sig.inputs.is_empty() { + return Err(syn::Error::new_spanned( + name, + "Class attribute methods cannot take arguments", + )); + } + FnType::ClassAttribute + } + Some(MethodTypeAttribute::New) => FnType::FnNew, + Some(MethodTypeAttribute::ClassMethod) => { + // Skip first argument for classmethod - always &PyType + let _ = inputs_iter.next(); + FnType::FnClass + } + Some(MethodTypeAttribute::Call) => FnType::FnCall( + inputs_iter + .next() + .ok_or_else(|| syn::Error::new_spanned(sig, "expected receiver for #[call]")) + .and_then(parse_method_receiver)?, + ), + Some(MethodTypeAttribute::Getter) => { + // Strip off "get_" prefix if needed + if python_name.is_none() { + const PREFIX: &str = "get_"; + + let ident = sig.ident.unraw().to_string(); + if ident.starts_with(PREFIX) { + python_name = Some(syn::Ident::new(&ident[PREFIX.len()..], ident.span())) + } + } + + FnType::Getter( + inputs_iter + .next() + .ok_or_else(|| { + syn::Error::new_spanned(sig, "expected receiver for #[getter]") + }) + .and_then(parse_method_receiver)?, + ) + } + Some(MethodTypeAttribute::Setter) => { + if python_name.is_none() { + const PREFIX: &str = "set_"; + + let ident = sig.ident.unraw().to_string(); + if ident.starts_with(PREFIX) { + python_name = Some(syn::Ident::new(&ident[PREFIX.len()..], ident.span())) + } + } + + FnType::Setter( + inputs_iter + .next() + .ok_or_else(|| { + syn::Error::new_spanned(sig, "expected receiver for #[setter]") + }) + .and_then(parse_method_receiver)?, + ) + } + None => { + FnType::Fn( + inputs_iter + .next() + .ok_or_else( + // No arguments - might be a static method? + || { + syn::Error::new_spanned( + sig, + "Static method needs #[staticmethod] attribute", + ) + }, + ) + .and_then(parse_method_receiver)?, + ) + } + }; + + // parse rest of arguments + for input in inputs_iter { match input { syn::FnArg::Receiver(recv) => { - self_ = Some(recv.mutability.is_some()); + return Err(syn::Error::new_spanned( + recv, + "Unexpected receiver for method", + )); } syn::FnArg::Typed(syn::PatType { ref pat, ref ty, .. }) => { - // skip first argument (cls) - if fn_type == FnType::FnClass && self_.is_none() { - self_ = Some(false); - continue; - } - let (ident, by_ref, mutability) = match **pat { syn::Pat::Ident(syn::PatIdent { ref ident, @@ -125,46 +254,6 @@ impl<'a> FnSpec<'a> { } let ty = get_return_info(&sig.output); - - if fn_type == FnType::Fn && self_.is_none() { - if arguments.is_empty() { - return Err(syn::Error::new_spanned( - name, - "Static method needs #[staticmethod] attribute", - )); - } - fn_type = match arguments.remove(0).ty { - syn::Type::Reference(r) => FnType::PySelfRef(replace_self_in_ref(r)?), - syn::Type::Path(p) => FnType::PySelfPath(replace_self_in_path(p)), - x => return Err(syn::Error::new_spanned(x, "Invalid type as custom self")), - }; - } - - if let FnType::ClassAttribute = &fn_type { - if self_.is_some() || !arguments.is_empty() { - return Err(syn::Error::new_spanned( - name, - "Class attribute methods cannot take arguments", - )); - } - } - - // "Tweak" getter / setter names: strip off set_ and get_ if needed - if let FnType::Getter | FnType::Setter = &fn_type { - if python_name.is_none() { - let prefix = match &fn_type { - FnType::Getter => "get_", - FnType::Setter => "set_", - _ => unreachable!(), - }; - - let ident = sig.ident.unraw().to_string(); - if ident.starts_with(prefix) { - python_name = Some(syn::Ident::new(&ident[prefix.len()..], ident.span())) - } - } - } - let python_name = python_name.unwrap_or_else(|| name.unraw()); let mut parse_erroneous_text_signature = |error_msg: &str| { @@ -179,16 +268,14 @@ impl<'a> FnSpec<'a> { }; let text_signature = match &fn_type { - FnType::Fn - | FnType::PySelfRef(_) - | FnType::PySelfPath(_) - | FnType::FnClass - | FnType::FnStatic => utils::parse_text_signature_attrs(&mut *meth_attrs, name)?, + FnType::Fn(_) | FnType::FnClass | FnType::FnStatic => { + utils::parse_text_signature_attrs(&mut *meth_attrs, name)? + } FnType::FnNew => parse_erroneous_text_signature( "text_signature not allowed on __new__; if you want to add a signature on \ __new__, put it on the struct definition instead", )?, - FnType::FnCall | FnType::Getter | FnType::Setter | FnType::ClassAttribute => { + FnType::FnCall(_) | FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => { parse_erroneous_text_signature("text_signature not allowed with this attribute")? } }; @@ -197,7 +284,6 @@ impl<'a> FnSpec<'a> { Ok(FnSpec { tp: fn_type, - self_, name, python_name, attrs: fn_attrs, @@ -311,7 +397,7 @@ pub(crate) fn check_ty_optional(ty: &syn::Type) -> Option<&syn::Type> { #[derive(Clone, PartialEq, Debug)] struct MethodAttributes { - ty: FnType, + ty: Option, args: Vec, python_name: Option, } @@ -322,27 +408,40 @@ fn parse_method_attributes( ) -> syn::Result { let mut new_attrs = Vec::new(); let mut args = Vec::new(); - let mut res: Option = None; + let mut ty: Option = None; let mut property_name = None; + macro_rules! set_ty { + ($new_ty:expr, $ident:expr) => { + if ty.is_some() { + return Err(syn::Error::new_spanned( + $ident, + "Cannot specify a second method type", + )); + } else { + ty = Some($new_ty); + } + }; + } + for attr in attrs.iter() { match attr.parse_meta()? { syn::Meta::Path(ref name) => { if name.is_ident("new") || name.is_ident("__new__") { - res = Some(FnType::FnNew) + set_ty!(MethodTypeAttribute::New, name); } else if name.is_ident("init") || name.is_ident("__init__") { return Err(syn::Error::new_spanned( name, "#[init] is disabled since PyO3 0.9.0", )); } else if name.is_ident("call") || name.is_ident("__call__") { - res = Some(FnType::FnCall) + set_ty!(MethodTypeAttribute::Call, name); } else if name.is_ident("classmethod") { - res = Some(FnType::FnClass) + set_ty!(MethodTypeAttribute::ClassMethod, name); } else if name.is_ident("staticmethod") { - res = Some(FnType::FnStatic) + set_ty!(MethodTypeAttribute::StaticMethod, name); } else if name.is_ident("classattr") { - res = Some(FnType::ClassAttribute) + set_ty!(MethodTypeAttribute::ClassAttribute, name); } else if name.is_ident("setter") || name.is_ident("getter") { if let syn::AttrStyle::Inner(_) = attr.style { return Err(syn::Error::new_spanned( @@ -350,16 +449,10 @@ fn parse_method_attributes( "Inner style attribute is not supported for setter and getter", )); } - if res != None { - return Err(syn::Error::new_spanned( - attr, - "setter/getter attribute can not be used mutiple times", - )); - } if name.is_ident("setter") { - res = Some(FnType::Setter) + set_ty!(MethodTypeAttribute::Setter, name); } else { - res = Some(FnType::Getter) + set_ty!(MethodTypeAttribute::Getter, name); } } else { new_attrs.push(attr.clone()) @@ -371,14 +464,14 @@ fn parse_method_attributes( .. }) => { if path.is_ident("new") { - res = Some(FnType::FnNew) + set_ty!(MethodTypeAttribute::New, path); } else if path.is_ident("init") { return Err(syn::Error::new_spanned( path, "#[init] is disabled since PyO3 0.9.0", )); } else if path.is_ident("call") { - res = Some(FnType::FnCall) + set_ty!(MethodTypeAttribute::Call, path); } else if path.is_ident("setter") || path.is_ident("getter") { if let syn::AttrStyle::Inner(_) = attr.style { return Err(syn::Error::new_spanned( @@ -386,12 +479,6 @@ fn parse_method_attributes( "Inner style attribute is not supported for setter and getter", )); } - if res != None { - return Err(syn::Error::new_spanned( - attr, - "setter/getter attribute can not be used mutiple times", - )); - } if nested.len() != 1 { return Err(syn::Error::new_spanned( attr, @@ -399,10 +486,10 @@ fn parse_method_attributes( )); } - res = if path.is_ident("setter") { - Some(FnType::Setter) + if path.is_ident("setter") { + set_ty!(MethodTypeAttribute::Setter, path); } else { - Some(FnType::Getter) + set_ty!(MethodTypeAttribute::Getter, path); }; property_name = match nested.first().unwrap() { @@ -439,9 +526,8 @@ fn parse_method_attributes( attrs.clear(); attrs.extend(new_attrs); - let ty = res.unwrap_or(FnType::Fn); let python_name = if allow_custom_name { - parse_method_name_attribute(&ty, attrs, property_name)? + parse_method_name_attribute(ty.as_ref(), attrs, property_name)? } else { property_name }; @@ -454,77 +540,33 @@ fn parse_method_attributes( } fn parse_method_name_attribute( - ty: &FnType, + ty: Option<&MethodTypeAttribute>, attrs: &mut Vec, property_name: Option, ) -> syn::Result> { + use MethodTypeAttribute::*; let name = parse_name_attribute(attrs)?; // Reject some invalid combinations if let Some(name) = &name { - match ty { - FnType::FnNew | FnType::FnCall | FnType::Getter | FnType::Setter => { - return Err(syn::Error::new_spanned( - name, - "name not allowed with this attribute", - )) + if let Some(ty) = ty { + match ty { + New | Call | Getter | Setter => { + return Err(syn::Error::new_spanned( + name, + "name not allowed with this method type", + )) + } + _ => {} } - _ => {} } } // Thanks to check above we can be sure that this generates the right python name Ok(match ty { - FnType::FnNew => Some(syn::Ident::new("__new__", proc_macro2::Span::call_site())), - FnType::FnCall => Some(syn::Ident::new("__call__", proc_macro2::Span::call_site())), - FnType::Getter | FnType::Setter => property_name, + Some(New) => Some(syn::Ident::new("__new__", proc_macro2::Span::call_site())), + Some(Call) => Some(syn::Ident::new("__call__", proc_macro2::Span::call_site())), + Some(Getter) | Some(Setter) => property_name, _ => name, }) } - -// Replace &A with &A<_> -fn replace_self_in_ref(refn: &syn::TypeReference) -> syn::Result { - let mut res = refn.to_owned(); - let tp = match &mut *res.elem { - syn::Type::Path(p) => p, - _ => return Err(syn::Error::new_spanned(refn, "unsupported argument")), - }; - replace_self_impl(tp); - res.lifetime = None; - Ok(res) -} - -fn replace_self_in_path(tp: &syn::TypePath) -> syn::TypePath { - let mut res = tp.to_owned(); - replace_self_impl(&mut res); - res -} - -fn replace_self_impl(tp: &mut syn::TypePath) { - for seg in &mut tp.path.segments { - if let syn::PathArguments::AngleBracketed(ref mut g) = seg.arguments { - let mut args = syn::punctuated::Punctuated::new(); - for arg in &g.args { - let mut add_arg = true; - if let syn::GenericArgument::Lifetime(_) = arg { - add_arg = false; - } - if let syn::GenericArgument::Type(syn::Type::Path(p)) = arg { - if p.path.segments.len() == 1 && p.path.segments[0].ident == "Self" { - args.push(infer(p.span())); - add_arg = false; - } - } - if add_arg { - args.push(arg.clone()); - } - } - g.args = args; - } - } - fn infer(span: proc_macro2::Span) -> syn::GenericArgument { - syn::GenericArgument::Type(syn::Type::Infer(syn::TypeInfer { - underscore_token: syn::token::Underscore { spans: [span] }, - })) - } -} diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 54abda71708..e6d29f92e7f 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -140,12 +140,14 @@ pub fn add_fn_to_module( pyfn_attrs: Vec, ) -> syn::Result { let mut arguments = Vec::new(); - let mut self_ = None; for input in func.sig.inputs.iter() { match input { - syn::FnArg::Receiver(recv) => { - self_ = Some(recv.mutability.is_some()); + syn::FnArg::Receiver(_) => { + return Err(syn::Error::new_spanned( + input, + "Unexpected receiver for #[pyfn]", + )) } syn::FnArg::Typed(ref cap) => { arguments.push(wrap_fn_argument(cap, &func.sig.ident)?); @@ -161,8 +163,7 @@ pub fn add_fn_to_module( let function_wrapper_ident = function_wrapper_ident(&func.sig.ident); let spec = method::FnSpec { - tp: method::FnType::Fn, - self_, + tp: method::FnType::FnStatic, name: &function_wrapper_ident, python_name, attrs: pyfn_attrs, diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index f1e4100b6e8..cf52f334f1c 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -1,6 +1,6 @@ // Copyright (c) 2017-present PyO3 Project and Contributors -use crate::method::FnType; +use crate::method::{FnType, SelfType}; use crate::pymethod::{ impl_py_getter_def, impl_py_setter_def, impl_wrap_getter, impl_wrap_setter, PropertyType, }; @@ -185,9 +185,9 @@ fn parse_descriptors(item: &mut syn::Field) -> syn::Result> { for meta in list.nested.iter() { if let syn::NestedMeta::Meta(ref metaitem) = meta { if metaitem.path().is_ident("get") { - descs.push(FnType::Getter); + descs.push(FnType::Getter(SelfType::Receiver { mutable: false })); } else if metaitem.path().is_ident("set") { - descs.push(FnType::Setter); + descs.push(FnType::Setter(SelfType::Receiver { mutable: true })); } else { return Err(syn::Error::new_spanned( metaitem, @@ -450,16 +450,16 @@ fn impl_descriptors( let doc = utils::get_doc(&field.attrs, None, true) .unwrap_or_else(|_| syn::LitStr::new(&name.to_string(), name.span())); - match *desc { - FnType::Getter => Ok(impl_py_getter_def( + match desc { + FnType::Getter(self_ty) => Ok(impl_py_getter_def( &name, &doc, - &impl_wrap_getter(&cls, PropertyType::Descriptor(&field))?, + &impl_wrap_getter(&cls, PropertyType::Descriptor(&field), &self_ty)?, )), - FnType::Setter => Ok(impl_py_setter_def( + FnType::Setter(self_ty) => Ok(impl_py_setter_def( &name, &doc, - &impl_wrap_setter(&cls, PropertyType::Descriptor(&field))?, + &impl_wrap_setter(&cls, PropertyType::Descriptor(&field), &self_ty)?, )), _ => unreachable!(), } diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index da08d75d953..dd4f845ba0c 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -1,6 +1,6 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use crate::konst::ConstSpec; -use crate::method::{FnArg, FnSpec, FnType}; +use crate::method::{FnArg, FnSpec, FnType, SelfType}; use crate::utils; use proc_macro2::{Span, TokenStream}; use quote::quote; @@ -19,30 +19,26 @@ pub fn gen_py_method( check_generic(sig)?; let spec = FnSpec::parse(sig, &mut *meth_attrs, true)?; - Ok(match spec.tp { - FnType::Fn => impl_py_method_def(&spec, &impl_wrap(cls, &spec, true)), - FnType::PySelfRef(ref self_ty) => { - impl_py_method_def(&spec, &impl_wrap_pyslf(cls, &spec, self_ty, true)) - } - FnType::PySelfPath(ref self_ty) => { - impl_py_method_def(&spec, &impl_wrap_pyslf(cls, &spec, self_ty, true)) - } + Ok(match &spec.tp { + FnType::Fn(self_ty) => impl_py_method_def(&spec, &impl_wrap(cls, &spec, self_ty, true)), FnType::FnNew => impl_py_method_def_new(&spec, &impl_wrap_new(cls, &spec)), - FnType::FnCall => impl_py_method_def_call(&spec, &impl_wrap(cls, &spec, false)), + FnType::FnCall(self_ty) => { + impl_py_method_def_call(&spec, &impl_wrap(cls, &spec, self_ty, false)) + } FnType::FnClass => impl_py_method_def_class(&spec, &impl_wrap_class(cls, &spec)), FnType::FnStatic => impl_py_method_def_static(&spec, &impl_wrap_static(cls, &spec)), FnType::ClassAttribute => { impl_py_method_class_attribute(&spec, &impl_wrap_class_attribute(cls, &spec)) } - FnType::Getter => impl_py_getter_def( + FnType::Getter(self_ty) => impl_py_getter_def( &spec.python_name, &spec.doc, - &impl_wrap_getter(cls, PropertyType::Function(&spec))?, + &impl_wrap_getter(cls, PropertyType::Function(&spec), self_ty)?, ), - FnType::Setter => impl_py_setter_def( + FnType::Setter(self_ty) => impl_py_setter_def( &spec.python_name, &spec.doc, - &impl_wrap_setter(cls, PropertyType::Function(&spec))?, + &impl_wrap_setter(cls, PropertyType::Function(&spec), self_ty)?, ), }) } @@ -81,31 +77,14 @@ pub fn gen_py_const( } /// Generate function wrapper (PyCFunction, PyCFunctionWithKeywords) -pub fn impl_wrap(cls: &syn::Type, spec: &FnSpec<'_>, noargs: bool) -> TokenStream { - let body = impl_call(cls, &spec); - let borrow_self = spec.borrow_self(); - let slf = quote! { - let _slf = _py.from_borrowed_ptr::>(_slf); - #borrow_self - }; - impl_wrap_common(cls, spec, noargs, slf, body) -} - -pub fn impl_wrap_pyslf( +pub fn impl_wrap( cls: &syn::Type, spec: &FnSpec<'_>, - self_ty: impl quote::ToTokens, + self_ty: &SelfType, noargs: bool, ) -> TokenStream { - let names = get_arg_names(spec); - let name = &spec.name; - let body = quote! { - #cls::#name(_slf, #(#names),*) - }; - let slf = quote! { - let _cell = _py.from_borrowed_ptr::>(_slf); - let _slf: #self_ty = std::convert::TryFrom::try_from(_cell)?; - }; + let body = impl_call(cls, &spec); + let slf = self_ty.receiver(cls); impl_wrap_common(cls, spec, noargs, slf, body) } @@ -156,11 +135,11 @@ fn impl_wrap_common( } /// Generate function wrapper for protocol method (PyCFunction, PyCFunctionWithKeywords) -pub fn impl_proto_wrap(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { +pub fn impl_proto_wrap(cls: &syn::Type, spec: &FnSpec<'_>, self_ty: &SelfType) -> TokenStream { let python_name = &spec.python_name; let cb = impl_call(cls, &spec); let body = impl_arg_params(&spec, cb); - let borrow_self = spec.borrow_self(); + let slf = self_ty.receiver(cls); quote! { #[allow(unused_mut)] @@ -171,8 +150,7 @@ pub fn impl_proto_wrap(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { { const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#python_name),"()"); pyo3::callback_body_without_convert!(_py, { - let _slf = _py.from_borrowed_ptr::>(_slf); - #borrow_self + #slf let _args = _py.from_borrowed_ptr::(_args); let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs); @@ -282,7 +260,7 @@ pub fn impl_wrap_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStr } } -fn impl_call_getter(spec: &FnSpec) -> syn::Result { +fn impl_call_getter(cls: &syn::Type, spec: &FnSpec) -> syn::Result { let (py_arg, args) = split_off_python_arg(&spec.args); if !args.is_empty() { return Err(syn::Error::new_spanned( @@ -293,10 +271,11 @@ fn impl_call_getter(spec: &FnSpec) -> syn::Result { let name = &spec.name; let fncall = if py_arg.is_some() { - quote! { _slf.#name(_py) } + quote!(#cls::#name(_slf, _py)) } else { - quote! { _slf.#name() } + quote!(#cls::#name(_slf)) }; + Ok(fncall) } @@ -304,6 +283,7 @@ fn impl_call_getter(spec: &FnSpec) -> syn::Result { pub(crate) fn impl_wrap_getter( cls: &syn::Type, property_type: PropertyType, + self_ty: &SelfType, ) -> syn::Result { let (python_name, getter_impl) = match property_type { PropertyType::Descriptor(field) => { @@ -315,25 +295,24 @@ pub(crate) fn impl_wrap_getter( }), ) } - PropertyType::Function(spec) => (spec.python_name.clone(), impl_call_getter(&spec)?), + PropertyType::Function(spec) => (spec.python_name.clone(), impl_call_getter(cls, spec)?), }; - let borrow_self = crate::utils::borrow_self(false); + let slf = self_ty.receiver(cls); Ok(quote! { unsafe extern "C" fn __wrap( _slf: *mut pyo3::ffi::PyObject, _: *mut ::std::os::raw::c_void) -> *mut pyo3::ffi::PyObject { const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#python_name),"()"); pyo3::callback_body_without_convert!(_py, { - let _slf = _py.from_borrowed_ptr::>(_slf); - #borrow_self + #slf pyo3::callback::convert(_py, #getter_impl) }) } }) } -fn impl_call_setter(spec: &FnSpec) -> syn::Result { +fn impl_call_setter(cls: &syn::Type, spec: &FnSpec) -> syn::Result { let (py_arg, args) = split_off_python_arg(&spec.args); if args.is_empty() { @@ -350,9 +329,9 @@ fn impl_call_setter(spec: &FnSpec) -> syn::Result { let name = &spec.name; let fncall = if py_arg.is_some() { - quote!(_slf.#name(_py, _val)) + quote!(#cls::#name(_slf, _py, _val)) } else { - quote!(_slf.#name(_val)) + quote!(#cls::#name(_slf, _val)) }; Ok(fncall) @@ -362,16 +341,17 @@ fn impl_call_setter(spec: &FnSpec) -> syn::Result { pub(crate) fn impl_wrap_setter( cls: &syn::Type, property_type: PropertyType, + self_ty: &SelfType, ) -> syn::Result { let (python_name, setter_impl) = match property_type { PropertyType::Descriptor(field) => { let name = field.ident.as_ref().unwrap(); (name.unraw(), quote!({ _slf.#name = _val; })) } - PropertyType::Function(spec) => (spec.python_name.clone(), impl_call_setter(&spec)?), + PropertyType::Function(spec) => (spec.python_name.clone(), impl_call_setter(cls, spec)?), }; - let borrow_self = crate::utils::borrow_self(true); + let slf = self_ty.receiver(cls); Ok(quote! { #[allow(unused_mut)] unsafe extern "C" fn __wrap( @@ -380,8 +360,7 @@ pub(crate) fn impl_wrap_setter( { const _LOCATION: &'static str = concat!(stringify!(#cls),".",stringify!(#python_name),"()"); pyo3::callback_body_without_convert!(_py, { - let _slf = _py.from_borrowed_ptr::>(_slf); - #borrow_self + #slf let _value = _py.from_borrowed_ptr::(_value); let _val = pyo3::FromPyObject::extract(_value)?; @@ -398,10 +377,10 @@ pub fn get_arg_names(spec: &FnSpec) -> Vec { .collect() } -fn impl_call(_cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { +fn impl_call(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { let fname = &spec.name; let names = get_arg_names(spec); - quote! { _slf.#fname(#(#names),*) } + quote! { #cls::#fname(_slf, #(#names),*) } } pub fn impl_arg_params(spec: &FnSpec<'_>, body: TokenStream) -> TokenStream { diff --git a/pyo3-derive-backend/src/pyproto.rs b/pyo3-derive-backend/src/pyproto.rs index 1cded509c3e..308bcbe3428 100644 --- a/pyo3-derive-backend/src/pyproto.rs +++ b/pyo3-derive-backend/src/pyproto.rs @@ -2,7 +2,7 @@ use crate::defs; use crate::func::impl_method_proto; -use crate::method::FnSpec; +use crate::method::{FnSpec, FnType}; use crate::pymethod; use proc_macro2::{Span, TokenStream}; use quote::quote; @@ -75,7 +75,16 @@ fn impl_proto_impl( if let Some(m) = proto.get_method(&met.sig.ident) { let name = &met.sig.ident; let fn_spec = FnSpec::parse(&met.sig, &mut met.attrs, false)?; - let method = pymethod::impl_proto_wrap(ty, &fn_spec); + + let method = if let FnType::Fn(self_ty) = &fn_spec.tp { + pymethod::impl_proto_wrap(ty, &fn_spec, &self_ty) + } else { + return Err(syn::Error::new_spanned( + &met.sig, + "Expected method with receiver for #[pyproto] method", + )); + }; + let coexist = if m.can_coexist { // We need METH_COEXIST here to prevent __add__ from overriding __radd__ quote!(pyo3::ffi::METH_COEXIST) diff --git a/pyo3-derive-backend/src/utils.rs b/pyo3-derive-backend/src/utils.rs index d051a63409e..6bd78570495 100644 --- a/pyo3-derive-backend/src/utils.rs +++ b/pyo3-derive-backend/src/utils.rs @@ -1,21 +1,8 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use proc_macro2::Span; use proc_macro2::TokenStream; -use quote::quote; use std::fmt::Display; -pub(crate) fn borrow_self(is_mut: bool) -> TokenStream { - if is_mut { - quote! { - let mut _slf = _slf.try_borrow_mut()?; - } - } else { - quote! { - let _slf = _slf.try_borrow()?; - } - } -} - pub fn print_err(msg: String, t: TokenStream) { println!("Error: {} in '{}'", msg, t.to_string()); } diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index f9a18eda3f8..1f7a94c4dad 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -8,6 +8,8 @@ fn test_compile_errors() { t.compile_fail("tests/ui/missing_clone.rs"); t.compile_fail("tests/ui/reject_generics.rs"); t.compile_fail("tests/ui/wrong_aspyref_lifetimes.rs"); + t.compile_fail("tests/ui/invalid_pymethod_names.rs"); + t.compile_fail("tests/ui/invalid_pymethod_receiver.rs"); skip_min_stable(&t); diff --git a/tests/test_getter_setter.rs b/tests/test_getter_setter.rs index 7610e81fa5b..747a9014c4b 100644 --- a/tests/test_getter_setter.rs +++ b/tests/test_getter_setter.rs @@ -106,3 +106,33 @@ fn getter_setter_autogen() { "assert inst.text == 'Hello'; inst.text = 'There'; assert inst.text == 'There'" ); } + +#[pyclass] +struct RefGetterSetter { + num: i32, +} + +#[pymethods] +impl RefGetterSetter { + #[getter] + fn get_num(slf: PyRef) -> i32 { + slf.num + } + + #[setter] + fn set_num(mut slf: PyRefMut, value: i32) { + slf.num = value; + } +} + +#[test] +fn ref_getter_setter() { + // Regression test for #837 + let gil = Python::acquire_gil(); + let py = gil.python(); + + let inst = Py::new(py, RefGetterSetter { num: 10 }).unwrap(); + + py_run!(py, inst, "assert inst.num == 10"); + py_run!(py, inst, "inst.num = 20; assert inst.num == 20"); +} diff --git a/tests/ui/invalid_pymethod_names.stderr b/tests/ui/invalid_pymethod_names.stderr index 17f3595b0b1..0a56f9e6435 100644 --- a/tests/ui/invalid_pymethod_names.stderr +++ b/tests/ui/invalid_pymethod_names.stderr @@ -1,4 +1,4 @@ -error: name not allowed with this attribute +error: name not allowed with this method type --> $DIR/invalid_pymethod_names.rs:10:5 | 10 | #[name = "num"] @@ -10,7 +10,7 @@ error: #[name] can not be specified multiple times 17 | #[name = "foo"] | ^ -error: name not allowed with this attribute +error: name not allowed with this method type --> $DIR/invalid_pymethod_names.rs:24:5 | 24 | #[name = "makenew"] diff --git a/tests/ui/invalid_pymethod_receiver.rs b/tests/ui/invalid_pymethod_receiver.rs new file mode 100644 index 00000000000..f85d30d5a0b --- /dev/null +++ b/tests/ui/invalid_pymethod_receiver.rs @@ -0,0 +1,11 @@ +use pyo3::prelude::*; + +#[pyclass] +struct MyClass {} + +#[pymethods] +impl MyClass { + fn method_with_invalid_self_type(slf: i32, py: Python, index: u32) {} +} + +fn main() {} diff --git a/tests/ui/invalid_pymethod_receiver.stderr b/tests/ui/invalid_pymethod_receiver.stderr new file mode 100644 index 00000000000..8ae6a4cd66d --- /dev/null +++ b/tests/ui/invalid_pymethod_receiver.stderr @@ -0,0 +1,14 @@ +error[E0277]: the trait bound `i32: std::convert::From<&pyo3::pycell::PyCell>` is not satisfied + --> $DIR/invalid_pymethod_receiver.rs:8:43 + | +8 | fn method_with_invalid_self_type(slf: i32, py: Python, index: u32) {} + | ^^^ the trait `std::convert::From<&pyo3::pycell::PyCell>` is not implemented for `i32` + | + = help: the following implementations were found: + > + > + > + > + and 2 others + = note: required because of the requirements on the impl of `std::convert::Into` for `&pyo3::pycell::PyCell` + = note: required because of the requirements on the impl of `std::convert::TryFrom<&pyo3::pycell::PyCell>` for `i32` From c3e993e5a6c5617fe86668dd47cde288351f5ab3 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sat, 27 Jun 2020 14:49:46 +0100 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Yuji Kanagawa --- pyo3-derive-backend/src/method.rs | 100 +++++++++++------------------- tests/test_compile_error.rs | 4 +- 2 files changed, 39 insertions(+), 65 deletions(-) diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index 4c659b60581..26a18a8503d 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -132,6 +132,23 @@ impl<'a> FnSpec<'a> { let mut arguments = Vec::new(); let mut inputs_iter = sig.inputs.iter(); + let mut parse_receiver = |msg: &'static str| { + inputs_iter + .next() + .ok_or_else(|| syn::Error::new_spanned(sig, msg)) + .and_then(parse_method_receiver) + }; + + // strip get_ or set_ + let strip_fn_name = |prefix: &'static str| { + let ident = sig.ident.unraw().to_string(); + if ident.starts_with(prefix) { + Some(syn::Ident::new(&ident[prefix.len()..], ident.span())) + } else { + None + } + }; + // Parse receiver & function type for various method types let fn_type = match fn_type_attr { Some(MethodTypeAttribute::StaticMethod) => FnType::FnStatic, @@ -150,67 +167,28 @@ impl<'a> FnSpec<'a> { let _ = inputs_iter.next(); FnType::FnClass } - Some(MethodTypeAttribute::Call) => FnType::FnCall( - inputs_iter - .next() - .ok_or_else(|| syn::Error::new_spanned(sig, "expected receiver for #[call]")) - .and_then(parse_method_receiver)?, - ), + Some(MethodTypeAttribute::Call) => { + FnType::FnCall(parse_receiver("expected receiver for #[call]")?) + } Some(MethodTypeAttribute::Getter) => { // Strip off "get_" prefix if needed if python_name.is_none() { - const PREFIX: &str = "get_"; - - let ident = sig.ident.unraw().to_string(); - if ident.starts_with(PREFIX) { - python_name = Some(syn::Ident::new(&ident[PREFIX.len()..], ident.span())) - } + python_name = strip_fn_name("get_"); } - FnType::Getter( - inputs_iter - .next() - .ok_or_else(|| { - syn::Error::new_spanned(sig, "expected receiver for #[getter]") - }) - .and_then(parse_method_receiver)?, - ) + FnType::Getter(parse_receiver("expected receiver for #[getter]")?) } Some(MethodTypeAttribute::Setter) => { + // Strip off "set_" prefix if needed if python_name.is_none() { - const PREFIX: &str = "set_"; - - let ident = sig.ident.unraw().to_string(); - if ident.starts_with(PREFIX) { - python_name = Some(syn::Ident::new(&ident[PREFIX.len()..], ident.span())) - } + python_name = strip_fn_name("set_"); } - FnType::Setter( - inputs_iter - .next() - .ok_or_else(|| { - syn::Error::new_spanned(sig, "expected receiver for #[setter]") - }) - .and_then(parse_method_receiver)?, - ) - } - None => { - FnType::Fn( - inputs_iter - .next() - .ok_or_else( - // No arguments - might be a static method? - || { - syn::Error::new_spanned( - sig, - "Static method needs #[staticmethod] attribute", - ) - }, - ) - .and_then(parse_method_receiver)?, - ) + FnType::Setter(parse_receiver("expected receiver for #[setter]")?) } + None => FnType::Fn(parse_receiver( + "Static method needs #[staticmethod] attribute", + )?), }; // parse rest of arguments @@ -413,13 +391,11 @@ fn parse_method_attributes( macro_rules! set_ty { ($new_ty:expr, $ident:expr) => { - if ty.is_some() { + if ty.replace($new_ty).is_some() { return Err(syn::Error::new_spanned( $ident, "Cannot specify a second method type", )); - } else { - ty = Some($new_ty); } }; } @@ -548,17 +524,15 @@ fn parse_method_name_attribute( let name = parse_name_attribute(attrs)?; // Reject some invalid combinations - if let Some(name) = &name { - if let Some(ty) = ty { - match ty { - New | Call | Getter | Setter => { - return Err(syn::Error::new_spanned( - name, - "name not allowed with this method type", - )) - } - _ => {} + if let (Some(name), Some(ty)) = (&name, ty) { + match ty { + New | Call | Getter | Setter => { + return Err(syn::Error::new_spanned( + name, + "name not allowed with this method type", + )) } + _ => {} } } diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 1f7a94c4dad..5d839fa5c23 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -5,11 +5,11 @@ fn test_compile_errors() { t.compile_fail("tests/ui/invalid_macro_args.rs"); t.compile_fail("tests/ui/invalid_property_args.rs"); t.compile_fail("tests/ui/invalid_pyclass_args.rs"); + t.compile_fail("tests/ui/invalid_pymethod_names.rs"); + t.compile_fail("tests/ui/invalid_pymethod_receiver.rs"); t.compile_fail("tests/ui/missing_clone.rs"); t.compile_fail("tests/ui/reject_generics.rs"); t.compile_fail("tests/ui/wrong_aspyref_lifetimes.rs"); - t.compile_fail("tests/ui/invalid_pymethod_names.rs"); - t.compile_fail("tests/ui/invalid_pymethod_receiver.rs"); skip_min_stable(&t);