Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

skip type check of method receivers where unnecessary #4026

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
72 changes: 54 additions & 18 deletions pyo3-macros-backend/src/method.rs
Expand Up @@ -91,6 +91,22 @@ pub enum FnType {
ClassAttribute,
}

/// Whether method receiver type should be assumed correct.
#[derive(Clone, Copy)]
pub enum AssumeCorrectReceiverType {
/// Use this only when the Python interpreter is expected to guarantee the
/// correct receiver type.
///
/// Methods, getters, setters, and many slots will have the correct receiver type
/// enforced by the Python interpreter, so this is safe to use in those cases.
Yes,
/// Use this when the receiver type must be checked.
///
/// Reversible numeric operators are a good example where the receiver type
/// must be checked.
No,
}

impl FnType {
pub fn skip_first_rust_argument_in_python_signature(&self) -> bool {
match self {
Expand All @@ -109,6 +125,7 @@ impl FnType {
cls: Option<&syn::Type>,
error_mode: ExtractErrorMode,
holders: &mut Holders,
assume_correct_receiver_type: AssumeCorrectReceiverType,
ctx: &Ctx,
) -> TokenStream {
let Ctx { pyo3_path } = ctx;
Expand All @@ -118,6 +135,7 @@ impl FnType {
cls.expect("no class given for Fn with a \"self\" receiver"),
error_mode,
holders,
assume_correct_receiver_type,
ctx,
);
syn::Token![,](Span::call_site()).to_tokens(&mut receiver);
Expand Down Expand Up @@ -187,6 +205,7 @@ impl SelfType {
cls: &syn::Type,
error_mode: ExtractErrorMode,
holders: &mut Holders,
assume_correct_receiver_type: AssumeCorrectReceiverType,
ctx: &Ctx,
) -> TokenStream {
// Due to use of quote_spanned in this function, need to bind these idents to the
Expand All @@ -196,36 +215,41 @@ impl SelfType {
let Ctx { pyo3_path } = ctx;
match self {
SelfType::Receiver { span, mutable } => {
let method = if *mutable {
syn::Ident::new("extract_pyclass_ref_mut", *span)
} else {
syn::Ident::new("extract_pyclass_ref", *span)
let method_name = match (mutable, assume_correct_receiver_type) {
(true, AssumeCorrectReceiverType::Yes) => "receive_pyclass_mut",
(true, AssumeCorrectReceiverType::No) => "receive_pyclass_mut_checked_downcast",
(false, AssumeCorrectReceiverType::Yes) => "receive_pyclass",
(false, AssumeCorrectReceiverType::No) => "receive_pyclass_checked_downcast",
};
let method = syn::Ident::new(method_name, *span);
let holder = holders.push_holder(*span);
let pyo3_path = pyo3_path.to_tokens_spanned(*span);
error_mode.handle_error(
quote_spanned! { *span =>
#pyo3_path::impl_::extract_argument::#method::<#cls>(
#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf).0,
#pyo3_path::impl_::pymethods::#method::<#cls>(
#py,
&#slf,
&mut #holder,
)
},
ctx,
)
}
SelfType::TryFromBoundRef(span) => {
let method_name = match assume_correct_receiver_type {
AssumeCorrectReceiverType::Yes => "receive_pyclass_try_into",
AssumeCorrectReceiverType::No => "receive_pyclass_try_into_checked_downcast",
};
let method = syn::Ident::new(method_name, *span);
let pyo3_path = pyo3_path.to_tokens_spanned(*span);
error_mode.handle_error(
quote_spanned! { *span =>
#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf).downcast::<#cls>()
.map_err(::std::convert::Into::<#pyo3_path::PyErr>::into)
.and_then(
#[allow(unknown_lints, clippy::unnecessary_fallible_conversions)] // In case slf is Py<Self> (unknown_lints can be removed when MSRV is 1.75+)
|bound| ::std::convert::TryFrom::try_from(bound).map_err(::std::convert::Into::into)
)

#pyo3_path::impl_::pymethods::#method::<#cls, _>(
#py,
&#slf,
)
},
ctx
ctx,
)
}
}
Expand Down Expand Up @@ -505,7 +529,15 @@ impl<'a> FnSpec<'a> {
}

let rust_call = |args: Vec<TokenStream>, holders: &mut Holders| {
let mut self_arg = || self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);
let mut self_arg = || {
self.tp.self_arg(
cls,
ExtractErrorMode::Raise,
holders,
AssumeCorrectReceiverType::Yes,
ctx,
)
};

let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
Expand Down Expand Up @@ -692,9 +724,13 @@ impl<'a> FnSpec<'a> {
CallingConvention::TpNew => {
let mut holders = Holders::new();
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx)?;
let self_arg = self
.tp
.self_arg(cls, ExtractErrorMode::Raise, &mut holders, ctx);
let self_arg = self.tp.self_arg(
cls,
ExtractErrorMode::Raise,
&mut holders,
AssumeCorrectReceiverType::Yes,
ctx,
);
let call = quote! { #rust_name(#self_arg #(#args),*) };
let init_holders = holders.init_holders(ctx);
let check_gil_refs = holders.check_gil_refs();
Expand Down
64 changes: 56 additions & 8 deletions pyo3-macros-backend/src/pymethod.rs
@@ -1,7 +1,7 @@
use std::borrow::Cow;

use crate::attributes::{NameAttribute, RenamingRule};
use crate::method::{CallingConvention, ExtractErrorMode};
use crate::method::{AssumeCorrectReceiverType, CallingConvention, ExtractErrorMode};
use crate::params::Holders;
use crate::utils::Ctx;
use crate::utils::PythonDoc;
Expand Down Expand Up @@ -515,7 +515,13 @@ fn impl_call_setter(
ctx: &Ctx,
) -> syn::Result<TokenStream> {
let (py_arg, args) = split_off_python_arg(&spec.signature.arguments);
let slf = self_type.receiver(cls, ExtractErrorMode::Raise, holders, ctx);
let slf = self_type.receiver(
cls,
ExtractErrorMode::Raise,
holders,
AssumeCorrectReceiverType::Yes,
ctx,
);

if args.is_empty() {
bail_spanned!(spec.name.span() => "setter function expected to have one argument");
Expand Down Expand Up @@ -554,7 +560,13 @@ pub fn impl_py_setter_def(
mutable: true,
span: Span::call_site(),
}
.receiver(cls, ExtractErrorMode::Raise, &mut holders, ctx);
.receiver(
cls,
ExtractErrorMode::Raise,
&mut holders,
AssumeCorrectReceiverType::Yes,
ctx,
);
if let Some(ident) = &field.ident {
// named struct field
quote!({ #slf.#ident = _val; })
Expand Down Expand Up @@ -687,7 +699,13 @@ fn impl_call_getter(
ctx: &Ctx,
) -> syn::Result<TokenStream> {
let (py_arg, args) = split_off_python_arg(&spec.signature.arguments);
let slf = self_type.receiver(cls, ExtractErrorMode::Raise, holders, ctx);
let slf = self_type.receiver(
cls,
ExtractErrorMode::Raise,
holders,
AssumeCorrectReceiverType::Yes,
ctx,
);
ensure_spanned!(
args.is_empty(),
args[0].ty.span() => "getter function can only have one argument (of type pyo3::Python)"
Expand Down Expand Up @@ -722,7 +740,13 @@ pub fn impl_py_getter_def(
mutable: false,
span: Span::call_site(),
}
.receiver(cls, ExtractErrorMode::Raise, &mut holders, ctx);
.receiver(
cls,
ExtractErrorMode::Raise,
&mut holders,
AssumeCorrectReceiverType::Yes,
ctx,
);
let field_token = if let Some(ident) = &field.ident {
// named struct field
ident.to_token_stream()
Expand Down Expand Up @@ -1250,6 +1274,7 @@ impl SlotDef {
*extract_error_mode,
&mut holders,
return_mode.as_ref(),
AssumeCorrectReceiverType::Yes,
ctx,
)?;
let name = spec.name;
Expand Down Expand Up @@ -1298,12 +1323,17 @@ fn generate_method_body(
extract_error_mode: ExtractErrorMode,
holders: &mut Holders,
return_mode: Option<&ReturnMode>,
assume_correct_receiver_type: AssumeCorrectReceiverType,
ctx: &Ctx,
) -> Result<TokenStream> {
let Ctx { pyo3_path } = ctx;
let self_arg = spec
.tp
.self_arg(Some(cls), extract_error_mode, holders, ctx);
let self_arg = spec.tp.self_arg(
Some(cls),
extract_error_mode,
holders,
assume_correct_receiver_type,
ctx,
);
let rust_name = spec.name;
let args = extract_proto_arguments(spec, arguments, extract_error_mode, holders, ctx)?;
let call = quote! { #cls::#rust_name(#self_arg #(#args),*) };
Expand All @@ -1323,6 +1353,7 @@ struct SlotFragmentDef {
fragment: &'static str,
arguments: &'static [Ty],
extract_error_mode: ExtractErrorMode,
assume_correct_receiver_type: AssumeCorrectReceiverType,
ret_ty: Ty,
}

Expand All @@ -1332,6 +1363,7 @@ impl SlotFragmentDef {
fragment,
arguments,
extract_error_mode: ExtractErrorMode::Raise,
assume_correct_receiver_type: AssumeCorrectReceiverType::Yes,
ret_ty: Ty::Void,
}
}
Expand All @@ -1341,6 +1373,11 @@ impl SlotFragmentDef {
self
}

const fn no_assume_correct_receiver_type(mut self) -> Self {
self.assume_correct_receiver_type = AssumeCorrectReceiverType::No;
self
}

const fn ret_ty(mut self, ret_ty: Ty) -> Self {
self.ret_ty = ret_ty;
self
Expand All @@ -1357,6 +1394,7 @@ impl SlotFragmentDef {
fragment,
arguments,
extract_error_mode,
assume_correct_receiver_type,
ret_ty,
} = self;
let fragment_trait = format_ident!("PyClass{}SlotFragment", fragment);
Expand All @@ -1374,6 +1412,7 @@ impl SlotFragmentDef {
*extract_error_mode,
&mut holders,
None,
*assume_correct_receiver_type,
ctx,
)?;
let ret_ty = ret_ty.ffi_type(ctx);
Expand Down Expand Up @@ -1424,6 +1463,7 @@ macro_rules! binary_num_slot_fragment_def {
($ident:ident, $name:literal) => {
const $ident: SlotFragmentDef = SlotFragmentDef::new($name, &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.no_assume_correct_receiver_type()
.ret_ty(Ty::Object);
};
}
Expand Down Expand Up @@ -1457,28 +1497,36 @@ binary_num_slot_fragment_def!(__ROR__, "__ror__");

const __POW__: SlotFragmentDef = SlotFragmentDef::new("__pow__", &[Ty::Object, Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.no_assume_correct_receiver_type()
.ret_ty(Ty::Object);
const __RPOW__: SlotFragmentDef = SlotFragmentDef::new("__rpow__", &[Ty::Object, Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.no_assume_correct_receiver_type()
.ret_ty(Ty::Object);

const __LT__: SlotFragmentDef = SlotFragmentDef::new("__lt__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.no_assume_correct_receiver_type()
.ret_ty(Ty::Object);
const __LE__: SlotFragmentDef = SlotFragmentDef::new("__le__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.no_assume_correct_receiver_type()
.ret_ty(Ty::Object);
const __EQ__: SlotFragmentDef = SlotFragmentDef::new("__eq__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.no_assume_correct_receiver_type()
.ret_ty(Ty::Object);
const __NE__: SlotFragmentDef = SlotFragmentDef::new("__ne__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.no_assume_correct_receiver_type()
.ret_ty(Ty::Object);
const __GT__: SlotFragmentDef = SlotFragmentDef::new("__gt__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.no_assume_correct_receiver_type()
.ret_ty(Ty::Object);
const __GE__: SlotFragmentDef = SlotFragmentDef::new("__ge__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.no_assume_correct_receiver_type()
.ret_ty(Ty::Object);

fn extract_proto_arguments(
Expand Down