Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

opt: make argument extraction code smaller #2075

Merged
merged 1 commit into from Dec 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Expand Up @@ -41,8 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
accompanies your error type in your crate's documentation.
- Improve performance and error messages for `#[derive(FromPyObject)]` for enums. [#2068](https://github.com/PyO3/pyo3/pull/2068)
- Reduce generated LLVM code size (to improve compile times) for:
- internal `handle_panic` helper [#2073](https://github.com/PyO3/pyo3/pull/2073)
- `#[pyclass]` type object creation [#2075](https://github.com/PyO3/pyo3/pull/2075)
- internal `handle_panic` helper [#2074](https://github.com/PyO3/pyo3/pull/2074)
- `#[pyfunction]` and `#[pymethods]` argument extraction [#2075](https://github.com/PyO3/pyo3/pull/2075)
- `#[pyclass]` type object creation [#2076](https://github.com/PyO3/pyo3/pull/2076)

### Removed

Expand Down
12 changes: 6 additions & 6 deletions pyo3-macros-backend/src/params.rs
Expand Up @@ -213,8 +213,9 @@ fn impl_arg_param(

let ty = arg.ty;
let name = arg.name;
let name_str = name.to_string();
let transform_error = quote! {
|e| _pyo3::derive_utils::argument_extraction_error(#py, stringify!(#name), e)
|e| _pyo3::impl_::extract_argument::argument_extraction_error(#py, #name_str, e)
};

if is_args(&spec.attrs, name) {
Expand All @@ -223,17 +224,16 @@ fn impl_arg_param(
arg.name.span() => "args cannot be optional"
);
return Ok(quote_arg_span! {
let #arg_name = _args.unwrap().extract().map_err(#transform_error)?;
let #arg_name = _pyo3::impl_::extract_argument::extract_argument(_args.unwrap(), #name_str)?;
});
} else if is_kwargs(&spec.attrs, name) {
ensure_spanned!(
arg.optional.is_some(),
arg.name.span() => "kwargs must be Option<_>"
);
return Ok(quote_arg_span! {
let #arg_name = _kwargs.map(|kwargs| kwargs.extract())
.transpose()
.map_err(#transform_error)?;
let #arg_name = _kwargs.map(|kwargs| _pyo3::impl_::extract_argument::extract_argument(kwargs, #name_str))
.transpose()?;
});
}

Expand All @@ -243,7 +243,7 @@ fn impl_arg_param(
let extract = if let Some(FromPyWithAttribute(expr_path)) = &arg.attrs.from_py_with {
quote_arg_span! { #expr_path(_obj).map_err(#transform_error) }
} else {
quote_arg_span! { _obj.extract().map_err(#transform_error) }
quote_arg_span! { _pyo3::impl_::extract_argument::extract_argument(_obj, #name_str) }
};

let arg_value_or_default = match (spec.default_value(name), arg.optional.is_some()) {
Expand Down
3 changes: 3 additions & 0 deletions pytests/pyo3-benchmarks/tox.ini
Expand Up @@ -3,3 +3,6 @@ usedevelop = True
description = Run the unit tests under {basepython}
deps = -rrequirements-dev.txt
commands = pytest --benchmark-sort=name {posargs}
# Use recreate so that tox always rebuilds, otherwise changes to Rust are not
# picked up.
recreate = True
141 changes: 77 additions & 64 deletions src/derive_utils.rs
Expand Up @@ -102,100 +102,77 @@ impl FunctionDescription {
varkeywords
}
(Some(kwargs), false) => {
self.extract_keyword_arguments(kwargs, output, |name, _| {
Err(self.unexpected_keyword_argument(name))
})?;
self.extract_keyword_arguments(
kwargs,
output,
#[cold]
|name, _| Err(self.unexpected_keyword_argument(name)),
)?;
None
}
(None, _) => None,
};

// Check that there are sufficient positional arguments once keyword arguments are specified
if args_provided < self.required_positional_parameters {
let missing_positional_arguments: Vec<_> = self
.positional_parameter_names
.iter()
.take(self.required_positional_parameters)
.zip(output.iter())
.filter_map(|(param, out)| if out.is_none() { Some(*param) } else { None })
.collect();
if !missing_positional_arguments.is_empty() {
return Err(
self.missing_required_arguments("positional", &missing_positional_arguments)
);
for out in &output[..self.required_positional_parameters] {
if out.is_none() {
return Err(self.missing_required_positional_arguments(output));
}
}
}

// Check no missing required keyword arguments
let missing_keyword_only_arguments: Vec<_> = self
.keyword_only_parameters
.iter()
.zip(&output[num_positional_parameters..])
.filter_map(|(keyword_desc, out)| {
if keyword_desc.required && out.is_none() {
Some(keyword_desc.name)
} else {
None
}
})
.collect();

if !missing_keyword_only_arguments.is_empty() {
return Err(self.missing_required_arguments("keyword", &missing_keyword_only_arguments));
let keyword_output = &output[num_positional_parameters..];
for (param, out) in self.keyword_only_parameters.iter().zip(keyword_output) {
if param.required && out.is_none() {
return Err(self.missing_required_keyword_arguments(keyword_output));
}
}

Ok((varargs, varkeywords))
}

#[inline]
fn extract_keyword_arguments<'p>(
&self,
kwargs: impl Iterator<Item = (&'p PyAny, &'p PyAny)>,
output: &mut [Option<&'p PyAny>],
mut unexpected_keyword_handler: impl FnMut(&'p PyAny, &'p PyAny) -> PyResult<()>,
) -> PyResult<()> {
let (args_output, kwargs_output) =
output.split_at_mut(self.positional_parameter_names.len());
let positional_args_count = self.positional_parameter_names.len();
let mut positional_only_keyword_arguments = Vec::new();
for (kwarg_name, value) in kwargs {
let utf8_string = match kwarg_name.downcast::<PyString>()?.to_str() {
Ok(utf8_string) => utf8_string,
'for_each_kwarg: for (kwarg_name_py, value) in kwargs {
let kwarg_name = match kwarg_name_py.downcast::<PyString>()?.to_str() {
Ok(kwarg_name) => kwarg_name,
// This keyword is not a UTF8 string: all PyO3 argument names are guaranteed to be
// UTF8 by construction.
Err(_) => {
unexpected_keyword_handler(kwarg_name, value)?;
unexpected_keyword_handler(kwarg_name_py, value)?;
continue;
}
};

// Compare the keyword name against each parameter in turn. This is exactly the same method
// which CPython uses to map keyword names. Although it's O(num_parameters), the number of
// parameters is expected to be small so it's not worth constructing a mapping.
if let Some(i) = self
.keyword_only_parameters
.iter()
.position(|param| utf8_string == param.name)
{
kwargs_output[i] = Some(value);
continue;
for (i, param) in self.keyword_only_parameters.iter().enumerate() {
if param.name == kwarg_name {
output[positional_args_count + i] = Some(value);
continue 'for_each_kwarg;
}
}

// Repeat for positional parameters
if let Some((i, param)) = self
.positional_parameter_names
.iter()
.enumerate()
.find(|&(_, param)| utf8_string == *param)
{
if let Some(i) = self.find_keyword_parameter_in_positionals(kwarg_name) {
if i < self.positional_only_parameters {
positional_only_keyword_arguments.push(*param);
} else if args_output[i].replace(value).is_some() {
return Err(self.multiple_values_for_argument(param));
positional_only_keyword_arguments.push(kwarg_name);
} else if output[i].replace(value).is_some() {
return Err(self.multiple_values_for_argument(kwarg_name));
}
continue;
}

unexpected_keyword_handler(kwarg_name, value)?;
unexpected_keyword_handler(kwarg_name_py, value)?;
}

if positional_only_keyword_arguments.is_empty() {
Expand All @@ -205,6 +182,16 @@ impl FunctionDescription {
}
}

/// Returns the index of the positional parameter whose name equals `kwarg_name`,
/// or `None` when no positional parameter carries that name.
///
/// The names are compared in declaration order, so the first match wins.
fn find_keyword_parameter_in_positionals(&self, kwarg_name: &str) -> Option<usize> {
    self.positional_parameter_names
        .iter()
        .position(|candidate| *candidate == kwarg_name)
}

#[cold]
fn too_many_positional_arguments(&self, args_provided: usize) -> PyErr {
let was = if args_provided == 1 { "was" } else { "were" };
let msg = if self.required_positional_parameters != self.positional_parameter_names.len() {
Expand All @@ -228,6 +215,7 @@ impl FunctionDescription {
PyTypeError::new_err(msg)
}

#[cold]
fn multiple_values_for_argument(&self, argument: &str) -> PyErr {
PyTypeError::new_err(format!(
"{} got multiple values for argument '{}'",
Expand All @@ -236,6 +224,7 @@ impl FunctionDescription {
))
}

#[cold]
fn unexpected_keyword_argument(&self, argument: &PyAny) -> PyErr {
PyTypeError::new_err(format!(
"{} got an unexpected keyword argument '{}'",
Expand All @@ -244,6 +233,7 @@ impl FunctionDescription {
))
}

#[cold]
fn positional_only_keyword_arguments(&self, parameter_names: &[&str]) -> PyErr {
let mut msg = format!(
"{} got some positional-only arguments passed as keyword arguments: ",
Expand All @@ -253,6 +243,7 @@ impl FunctionDescription {
PyTypeError::new_err(msg)
}

#[cold]
fn missing_required_arguments(&self, argument_type: &str, parameter_names: &[&str]) -> PyErr {
let arguments = if parameter_names.len() == 1 {
"argument"
Expand All @@ -269,18 +260,40 @@ impl FunctionDescription {
push_parameter_list(&mut msg, parameter_names);
PyTypeError::new_err(msg)
}
}

/// Add the argument name to the error message of an error which occurred during argument extraction
pub fn argument_extraction_error(py: Python, arg_name: &str, error: PyErr) -> PyErr {
if error.is_instance_of::<PyTypeError>(py) {
let reason = error
.value(py)
.str()
.unwrap_or_else(|_| PyString::new(py, ""));
PyTypeError::new_err(format!("argument '{}': {}", arg_name, reason))
} else {
error
/// Builds the error reported when at least one required keyword-only argument
/// was not supplied.
///
/// `keyword_outputs` is the slice of extracted values for the keyword-only
/// parameters, in the same order as `self.keyword_only_parameters`; a `None`
/// slot on a required parameter marks that parameter as missing.
#[cold]
fn missing_required_keyword_arguments(&self, keyword_outputs: &[Option<&PyAny>]) -> PyErr {
    debug_assert_eq!(self.keyword_only_parameters.len(), keyword_outputs.len());

    // Collect the names of all required parameters whose slot is still empty.
    let absent_names: Vec<_> = self
        .keyword_only_parameters
        .iter()
        .zip(keyword_outputs)
        .filter(|(descriptor, slot)| descriptor.required && slot.is_none())
        .map(|(descriptor, _)| descriptor.name)
        .collect();

    // Callers only reach this path after finding a missing required argument.
    debug_assert!(!absent_names.is_empty());
    self.missing_required_arguments("keyword", &absent_names)
}

/// Builds the error reported when at least one required positional argument
/// was not supplied.
///
/// Only the first `self.required_positional_parameters` names are inspected;
/// each is paired with the slot at the same index in `output`, and a `None`
/// slot marks that parameter as missing.
#[cold]
fn missing_required_positional_arguments(&self, output: &[Option<&PyAny>]) -> PyErr {
    let required_names = self
        .positional_parameter_names
        .iter()
        .take(self.required_positional_parameters);

    let mut absent_names = Vec::new();
    for (name, slot) in required_names.zip(output) {
        if slot.is_none() {
            absent_names.push(*name);
        }
    }

    // Callers only reach this path after finding a missing required argument.
    debug_assert!(!absent_names.is_empty());
    self.missing_required_arguments("positional", &absent_names)
}
}

Expand Down
1 change: 1 addition & 0 deletions src/impl_.rs
Expand Up @@ -5,6 +5,7 @@
//! breaking semver guarantees.

pub mod deprecations;
pub mod extract_argument;
pub mod freelist;
#[doc(hidden)]
pub mod frompyobject;
Expand Down
30 changes: 30 additions & 0 deletions src/impl_/extract_argument.rs
@@ -0,0 +1,30 @@
use crate::{
exceptions::PyTypeError, type_object::PyTypeObject, FromPyObject, PyAny, PyErr, PyResult,
Python,
};

/// Extracts `obj` into a `T`, routing any extraction failure through
/// `argument_extraction_error` so the error can name the offending argument.
///
/// Written without closures or combinators: a plain `match` with an early
/// return on success, and the error path delegated to a separate function.
#[doc(hidden)]
#[inline]
pub fn extract_argument<'py, T>(obj: &'py PyAny, arg_name: &str) -> PyResult<T>
where
    T: FromPyObject<'py>,
{
    let failure = match obj.extract() {
        Ok(value) => return Ok(value),
        Err(failure) => failure,
    };
    Err(argument_extraction_error(obj.py(), arg_name, failure))
}

/// Adds the argument name to the error message of an error which occurred during argument extraction.
///
/// Only modifies TypeError. (Cannot guarantee all exceptions have constructors from
/// single string.)
#[doc(hidden)]
#[cold]
pub fn argument_extraction_error(py: Python, arg_name: &str, error: PyErr) -> PyErr {
if error.get_type(py) == PyTypeError::type_object(py) {
PyTypeError::new_err(format!("argument '{}': {}", arg_name, error.value(py)))
} else {
error
}
}