Skip to content

Commit

Permalink
opt: make argument extraction code smaller
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Dec 24, 2021
1 parent 947055d commit b0be6de
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 70 deletions.
12 changes: 6 additions & 6 deletions pyo3-macros-backend/src/params.rs
Expand Up @@ -213,8 +213,9 @@ fn impl_arg_param(

let ty = arg.ty;
let name = arg.name;
let name_str = name.to_string();
let transform_error = quote! {
|e| _pyo3::derive_utils::argument_extraction_error(#py, stringify!(#name), e)
|e| _pyo3::impl_::extract_argument::argument_extraction_error(#py, #name_str, e)
};

if is_args(&spec.attrs, name) {
Expand All @@ -223,17 +224,16 @@ fn impl_arg_param(
arg.name.span() => "args cannot be optional"
);
return Ok(quote_arg_span! {
let #arg_name = _args.unwrap().extract().map_err(#transform_error)?;
let #arg_name = _pyo3::impl_::extract_argument::extract_argument(_args.unwrap(), #name_str)?;
});
} else if is_kwargs(&spec.attrs, name) {
ensure_spanned!(
arg.optional.is_some(),
arg.name.span() => "kwargs must be Option<_>"
);
return Ok(quote_arg_span! {
let #arg_name = _kwargs.map(|kwargs| kwargs.extract())
.transpose()
.map_err(#transform_error)?;
let #arg_name = _kwargs.map(|kwargs| _pyo3::impl_::extract_argument::extract_argument(kwargs, #name_str))
.transpose()?;
});
}

Expand All @@ -243,7 +243,7 @@ fn impl_arg_param(
let extract = if let Some(FromPyWithAttribute(expr_path)) = &arg.attrs.from_py_with {
quote_arg_span! { #expr_path(_obj).map_err(#transform_error) }
} else {
quote_arg_span! { _obj.extract().map_err(#transform_error) }
quote_arg_span! { _pyo3::impl_::extract_argument::extract_argument(_obj, #name_str) }
};

let arg_value_or_default = match (spec.default_value(name), arg.optional.is_some()) {
Expand Down
3 changes: 3 additions & 0 deletions pytests/pyo3-benchmarks/tox.ini
Expand Up @@ -3,3 +3,6 @@ usedevelop = True
description = Run the unit tests under {basepython}
deps = -rrequirements-dev.txt
commands = pytest --benchmark-sort=name {posargs}
# Use recreate so that tox always rebuilds, otherwise changes to Rust are not
# picked up.
recreate = True
141 changes: 77 additions & 64 deletions src/derive_utils.rs
Expand Up @@ -102,100 +102,77 @@ impl FunctionDescription {
varkeywords
}
(Some(kwargs), false) => {
self.extract_keyword_arguments(kwargs, output, |name, _| {
Err(self.unexpected_keyword_argument(name))
})?;
self.extract_keyword_arguments(
kwargs,
output,
#[cold]
|name, _| Err(self.unexpected_keyword_argument(name)),
)?;
None
}
(None, _) => None,
};

// Check that there's sufficient positional arguments once keyword arguments are specified
if args_provided < self.required_positional_parameters {
let missing_positional_arguments: Vec<_> = self
.positional_parameter_names
.iter()
.take(self.required_positional_parameters)
.zip(output.iter())
.filter_map(|(param, out)| if out.is_none() { Some(*param) } else { None })
.collect();
if !missing_positional_arguments.is_empty() {
return Err(
self.missing_required_arguments("positional", &missing_positional_arguments)
);
for out in &output[..self.required_positional_parameters] {
if out.is_none() {
return Err(self.missing_required_positional_arguments(output));
}
}
}

// Check no missing required keyword arguments
let missing_keyword_only_arguments: Vec<_> = self
.keyword_only_parameters
.iter()
.zip(&output[num_positional_parameters..])
.filter_map(|(keyword_desc, out)| {
if keyword_desc.required && out.is_none() {
Some(keyword_desc.name)
} else {
None
}
})
.collect();

if !missing_keyword_only_arguments.is_empty() {
return Err(self.missing_required_arguments("keyword", &missing_keyword_only_arguments));
let keyword_output = &output[num_positional_parameters..];
for (param, out) in self.keyword_only_parameters.iter().zip(keyword_output) {
if param.required && out.is_none() {
return Err(self.missing_required_keyword_arguments(keyword_output));
}
}

Ok((varargs, varkeywords))
}

#[inline]
fn extract_keyword_arguments<'p>(
&self,
kwargs: impl Iterator<Item = (&'p PyAny, &'p PyAny)>,
output: &mut [Option<&'p PyAny>],
mut unexpected_keyword_handler: impl FnMut(&'p PyAny, &'p PyAny) -> PyResult<()>,
) -> PyResult<()> {
let (args_output, kwargs_output) =
output.split_at_mut(self.positional_parameter_names.len());
let positional_args_count = self.positional_parameter_names.len();
let mut positional_only_keyword_arguments = Vec::new();
for (kwarg_name, value) in kwargs {
let utf8_string = match kwarg_name.downcast::<PyString>()?.to_str() {
Ok(utf8_string) => utf8_string,
'for_each_kwarg: for (kwarg_name_py, value) in kwargs {
let kwarg_name = match kwarg_name_py.downcast::<PyString>()?.to_str() {
Ok(kwarg_name) => kwarg_name,
// This keyword is not a UTF8 string: all PyO3 argument names are guaranteed to be
// UTF8 by construction.
Err(_) => {
unexpected_keyword_handler(kwarg_name, value)?;
unexpected_keyword_handler(kwarg_name_py, value)?;
continue;
}
};

// Compare the keyword name against each parameter in turn. This is exactly the same method
// which CPython uses to map keyword names. Although it's O(num_parameters), the number of
// parameters is expected to be small so it's not worth constructing a mapping.
if let Some(i) = self
.keyword_only_parameters
.iter()
.position(|param| utf8_string == param.name)
{
kwargs_output[i] = Some(value);
continue;
for (i, param) in self.keyword_only_parameters.iter().enumerate() {
if param.name == kwarg_name {
output[positional_args_count + i] = Some(value);
continue 'for_each_kwarg;
}
}

// Repeat for positional parameters
if let Some((i, param)) = self
.positional_parameter_names
.iter()
.enumerate()
.find(|&(_, param)| utf8_string == *param)
{
if let Some(i) = self.find_keyword_parameter_in_positionals(kwarg_name) {
if i < self.positional_only_parameters {
positional_only_keyword_arguments.push(*param);
} else if args_output[i].replace(value).is_some() {
return Err(self.multiple_values_for_argument(param));
positional_only_keyword_arguments.push(kwarg_name);
} else if output[i].replace(value).is_some() {
return Err(self.multiple_values_for_argument(kwarg_name));
}
continue;
}

unexpected_keyword_handler(kwarg_name, value)?;
unexpected_keyword_handler(kwarg_name_py, value)?;
}

if positional_only_keyword_arguments.is_empty() {
Expand All @@ -205,6 +182,16 @@ impl FunctionDescription {
}
}

fn find_keyword_parameter_in_positionals(&self, kwarg_name: &str) -> Option<usize> {
for (i, param_name) in self.positional_parameter_names.iter().enumerate() {
if *param_name == kwarg_name {
return Some(i);
}
}
None
}

#[cold]
fn too_many_positional_arguments(&self, args_provided: usize) -> PyErr {
let was = if args_provided == 1 { "was" } else { "were" };
let msg = if self.required_positional_parameters != self.positional_parameter_names.len() {
Expand All @@ -228,6 +215,7 @@ impl FunctionDescription {
PyTypeError::new_err(msg)
}

#[cold]
fn multiple_values_for_argument(&self, argument: &str) -> PyErr {
PyTypeError::new_err(format!(
"{} got multiple values for argument '{}'",
Expand All @@ -236,6 +224,7 @@ impl FunctionDescription {
))
}

#[cold]
fn unexpected_keyword_argument(&self, argument: &PyAny) -> PyErr {
PyTypeError::new_err(format!(
"{} got an unexpected keyword argument '{}'",
Expand All @@ -244,6 +233,7 @@ impl FunctionDescription {
))
}

#[cold]
fn positional_only_keyword_arguments(&self, parameter_names: &[&str]) -> PyErr {
let mut msg = format!(
"{} got some positional-only arguments passed as keyword arguments: ",
Expand All @@ -253,6 +243,7 @@ impl FunctionDescription {
PyTypeError::new_err(msg)
}

#[cold]
fn missing_required_arguments(&self, argument_type: &str, parameter_names: &[&str]) -> PyErr {
let arguments = if parameter_names.len() == 1 {
"argument"
Expand All @@ -269,18 +260,40 @@ impl FunctionDescription {
push_parameter_list(&mut msg, parameter_names);
PyTypeError::new_err(msg)
}
}

/// Add the argument name to the error message of an error which occurred during argument extraction
pub fn argument_extraction_error(py: Python, arg_name: &str, error: PyErr) -> PyErr {
if error.is_instance_of::<PyTypeError>(py) {
let reason = error
.value(py)
.str()
.unwrap_or_else(|_| PyString::new(py, ""));
PyTypeError::new_err(format!("argument '{}': {}", arg_name, reason))
} else {
error
#[cold]
fn missing_required_keyword_arguments(&self, keyword_outputs: &[Option<&PyAny>]) -> PyErr {
debug_assert_eq!(self.keyword_only_parameters.len(), keyword_outputs.len());

let missing_keyword_only_arguments: Vec<_> = self
.keyword_only_parameters
.iter()
.zip(keyword_outputs)
.filter_map(|(keyword_desc, out)| {
if keyword_desc.required && out.is_none() {
Some(keyword_desc.name)
} else {
None
}
})
.collect();

debug_assert!(!missing_keyword_only_arguments.is_empty());
self.missing_required_arguments("keyword", &missing_keyword_only_arguments)
}

#[cold]
fn missing_required_positional_arguments(&self, output: &[Option<&PyAny>]) -> PyErr {
let missing_positional_arguments: Vec<_> = self
.positional_parameter_names
.iter()
.take(self.required_positional_parameters)
.zip(output)
.filter_map(|(param, out)| if out.is_none() { Some(*param) } else { None })
.collect();

debug_assert!(!missing_positional_arguments.is_empty());
self.missing_required_arguments("positional", &missing_positional_arguments)
}
}

Expand Down
1 change: 1 addition & 0 deletions src/impl_.rs
Expand Up @@ -5,6 +5,7 @@
//! breaking semver guarantees.

pub mod deprecations;
pub mod extract_argument;
pub mod freelist;
#[doc(hidden)]
pub mod frompyobject;
30 changes: 30 additions & 0 deletions src/impl_/extract_argument.rs
@@ -0,0 +1,30 @@
use crate::{
exceptions::PyTypeError, type_object::PyTypeObject, FromPyObject, PyAny, PyErr, PyResult,
Python,
};

#[doc(hidden)]
#[inline]
pub fn extract_argument<'py, T>(obj: &'py PyAny, arg_name: &str) -> PyResult<T>
where
T: FromPyObject<'py>,
{
match obj.extract() {
Ok(e) => Ok(e),
Err(e) => Err(argument_extraction_error(obj.py(), arg_name, e)),
}
}

/// Adds the argument name to the error message of an error which occurred during argument extraction.
///
/// Only modifies TypeError. (Cannot guarantee all exceptions have constructors from
/// single string.)
#[doc(hidden)]
#[cold]
pub fn argument_extraction_error(py: Python, arg_name: &str, error: PyErr) -> PyErr {
if error.get_type(py) == PyTypeError::type_object(py) {
PyTypeError::new_err(format!("argument '{}': {}", arg_name, error.value(py)))
} else {
error
}
}

0 comments on commit b0be6de

Please sign in to comment.