Skip to content

Commit

Permalink
Merge pull request #998 from davidhewitt/pyproto-optional-return-ty
Browse files Browse the repository at this point in the history
Allow omitting return type for `#[pyproto]`
  • Loading branch information
kngwyu committed Jun 23, 2020
2 parents 0c59b05 + a9c7e12 commit a5e3d4e
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 195 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- `#[pyproto]` is re-implemented without specialization. [#961](https://github.com/PyO3/pyo3/pull/961)
- `PyClassAlloc::alloc` is renamed to `PyClassAlloc::new`. [#990](https://github.com/PyO3/pyo3/pull/990)
- `#[pyproto]` methods can now have return value `T` or `PyResult<T>` (previously only `PyResult<T>` was supported). [#996](https://github.com/PyO3/pyo3/pull/996)
- `#[pyproto]` methods can now skip annotating the return type if it is `()`. [#998](https://github.com/PyO3/pyo3/pull/998)

### Removed
- Remove `ManagedPyRef` (unused, and needs specialization) [#930](https://github.com/PyO3/pyo3/pull/930)
Expand Down
2 changes: 1 addition & 1 deletion guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ Python object behavior, you need to implement the specific trait for your struct
each protocol implementation block has to be annotated with the `#[pyproto]` attribute.

All `#[pyproto]` methods which can be defined below can return `T` instead of `PyResult<T>` if the
method implementation is infallible.
method implementation is infallible. In addition, if the return type is `()`, it can be omitted altogether.

### Basic object customization

Expand Down
23 changes: 10 additions & 13 deletions pyo3-derive-backend/src/func.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) 2017-present PyO3 Project and Contributors
use crate::utils::print_err;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use quote::{quote, ToTokens};
use syn::Token;

// TODO:
Expand Down Expand Up @@ -75,21 +75,18 @@ pub(crate) fn impl_method_proto(
sig: &mut syn::Signature,
meth: &MethodProto,
) -> TokenStream {
if let MethodProto::Free { proto, .. } = meth {
let p: syn::Path = syn::parse_str(proto).unwrap();
return quote! {
impl<'p> #p<'p> for #cls {}
};
}

let ret_ty = &*if let syn::ReturnType::Type(_, ref ty) = sig.output {
ty.clone()
} else {
panic!("fn return type is not supported")
let ret_ty = match &sig.output {
syn::ReturnType::Default => quote! { () },
syn::ReturnType::Type(_, ty) => ty.to_token_stream(),
};

match *meth {
MethodProto::Free { .. } => unreachable!(),
MethodProto::Free { proto, .. } => {
let p: syn::Path = syn::parse_str(proto).unwrap();
quote! {
impl<'p> #p<'p> for #cls {}
}
}
MethodProto::Unary { proto, .. } => {
let p: syn::Path = syn::parse_str(proto).unwrap();

Expand Down
186 changes: 89 additions & 97 deletions tests/test_arithmetics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@ impl UnaryArithmetic {

#[pyproto]
impl PyObjectProtocol for UnaryArithmetic {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("UA({})", self.inner))
fn __repr__(&self) -> String {
format!("UA({})", self.inner)
}
}

#[pyproto]
impl PyNumberProtocol for UnaryArithmetic {
fn __neg__(&self) -> PyResult<Self> {
Ok(Self::new(-self.inner))
fn __neg__(&self) -> Self {
Self::new(-self.inner)
}

fn __pos__(&self) -> PyResult<Self> {
Ok(Self::new(self.inner))
fn __pos__(&self) -> Self {
Self::new(self.inner)
}

fn __abs__(&self) -> PyResult<Self> {
Ok(Self::new(self.inner.abs()))
fn __abs__(&self) -> Self {
Self::new(self.inner.abs())
}

fn __round__(&self, _ndigits: Option<u32>) -> PyResult<Self> {
Ok(Self::new(self.inner.round()))
fn __round__(&self, _ndigits: Option<u32>) -> Self {
Self::new(self.inner.round())
}
}

Expand All @@ -60,8 +60,8 @@ struct BinaryArithmetic {}

#[pyproto]
impl PyObjectProtocol for BinaryArithmetic {
fn __repr__(&self) -> PyResult<&'static str> {
Ok("BA")
fn __repr__(&self) -> &'static str {
"BA"
}
}

Expand All @@ -72,56 +72,47 @@ struct InPlaceOperations {

#[pyproto]
impl PyObjectProtocol for InPlaceOperations {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("IPO({:?})", self.value))
fn __repr__(&self) -> String {
format!("IPO({:?})", self.value)
}
}

#[pyproto]
impl PyNumberProtocol for InPlaceOperations {
fn __iadd__(&mut self, other: u32) -> PyResult<()> {
fn __iadd__(&mut self, other: u32) {
self.value += other;
Ok(())
}

fn __isub__(&mut self, other: u32) -> PyResult<()> {
fn __isub__(&mut self, other: u32) {
self.value -= other;
Ok(())
}

fn __imul__(&mut self, other: u32) -> PyResult<()> {
fn __imul__(&mut self, other: u32) {
self.value *= other;
Ok(())
}

fn __ilshift__(&mut self, other: u32) -> PyResult<()> {
fn __ilshift__(&mut self, other: u32) {
self.value <<= other;
Ok(())
}

fn __irshift__(&mut self, other: u32) -> PyResult<()> {
fn __irshift__(&mut self, other: u32) {
self.value >>= other;
Ok(())
}

fn __iand__(&mut self, other: u32) -> PyResult<()> {
fn __iand__(&mut self, other: u32) {
self.value &= other;
Ok(())
}

fn __ixor__(&mut self, other: u32) -> PyResult<()> {
fn __ixor__(&mut self, other: u32) {
self.value ^= other;
Ok(())
}

fn __ior__(&mut self, other: u32) -> PyResult<()> {
fn __ior__(&mut self, other: u32) {
self.value |= other;
Ok(())
}

fn __ipow__(&mut self, other: u32) -> PyResult<()> {
fn __ipow__(&mut self, other: u32) {
self.value = self.value.pow(other);
Ok(())
}
}

Expand Down Expand Up @@ -151,40 +142,40 @@ fn inplace_operations() {

#[pyproto]
impl PyNumberProtocol for BinaryArithmetic {
fn __add__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + {:?}", lhs, rhs))
fn __add__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}

fn __sub__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - {:?}", lhs, rhs))
fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}

fn __mul__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} * {:?}", lhs, rhs))
fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} * {:?}", lhs, rhs)
}

fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} << {:?}", lhs, rhs))
fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} << {:?}", lhs, rhs)
}

fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} >> {:?}", lhs, rhs))
fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} >> {:?}", lhs, rhs)
}

fn __and__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} & {:?}", lhs, rhs))
fn __and__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} & {:?}", lhs, rhs)
}

fn __xor__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} ^ {:?}", lhs, rhs))
fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} ^ {:?}", lhs, rhs)
}

fn __or__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} | {:?}", lhs, rhs))
fn __or__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} | {:?}", lhs, rhs)
}

fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option<u32>) -> PyResult<String> {
Ok(format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_))
fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option<u32>) -> String {
format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_)
}
}

Expand Down Expand Up @@ -224,40 +215,40 @@ struct RhsArithmetic {}

#[pyproto]
impl PyNumberProtocol for RhsArithmetic {
fn __radd__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + RA", other))
fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}

fn __rsub__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - RA", other))
fn __rsub__(&self, other: &PyAny) -> String {
format!("{:?} - RA", other)
}

fn __rmul__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} * RA", other))
fn __rmul__(&self, other: &PyAny) -> String {
format!("{:?} * RA", other)
}

fn __rlshift__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} << RA", other))
fn __rlshift__(&self, other: &PyAny) -> String {
format!("{:?} << RA", other)
}

fn __rrshift__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} >> RA", other))
fn __rrshift__(&self, other: &PyAny) -> String {
format!("{:?} >> RA", other)
}

fn __rand__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} & RA", other))
fn __rand__(&self, other: &PyAny) -> String {
format!("{:?} & RA", other)
}

fn __rxor__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} ^ RA", other))
fn __rxor__(&self, other: &PyAny) -> String {
format!("{:?} ^ RA", other)
}

fn __ror__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} | RA", other))
fn __ror__(&self, other: &PyAny) -> String {
format!("{:?} | RA", other)
}

fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> PyResult<String> {
Ok(format!("{:?} ** RA", other))
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
}
}

Expand Down Expand Up @@ -292,35 +283,35 @@ struct LhsAndRhsArithmetic {}

#[pyproto]
impl PyNumberProtocol for LhsAndRhsArithmetic {
fn __radd__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + RA", other))
fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}

fn __rsub__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - RA", other))
fn __rsub__(&self, other: &PyAny) -> String {
format!("{:?} - RA", other)
}

fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> PyResult<String> {
Ok(format!("{:?} ** RA", other))
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
}

fn __add__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + {:?}", lhs, rhs))
fn __add__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}

fn __sub__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - {:?}", lhs, rhs))
fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}

fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<u32>) -> PyResult<String> {
Ok(format!("{:?} ** {:?}", lhs, rhs))
fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<u32>) -> String {
format!("{:?} ** {:?}", lhs, rhs)
}
}

#[pyproto]
impl PyObjectProtocol for LhsAndRhsArithmetic {
fn __repr__(&self) -> PyResult<&'static str> {
Ok("BA")
fn __repr__(&self) -> &'static str {
"BA"
}
}

Expand All @@ -345,18 +336,18 @@ struct RichComparisons {}

#[pyproto]
impl PyObjectProtocol for RichComparisons {
fn __repr__(&self) -> PyResult<&'static str> {
Ok("RC")
fn __repr__(&self) -> &'static str {
"RC"
}

fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult<String> {
fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> String {
match op {
CompareOp::Lt => Ok(format!("{} < {:?}", self.__repr__().unwrap(), other)),
CompareOp::Le => Ok(format!("{} <= {:?}", self.__repr__().unwrap(), other)),
CompareOp::Eq => Ok(format!("{} == {:?}", self.__repr__().unwrap(), other)),
CompareOp::Ne => Ok(format!("{} != {:?}", self.__repr__().unwrap(), other)),
CompareOp::Gt => Ok(format!("{} > {:?}", self.__repr__().unwrap(), other)),
CompareOp::Ge => Ok(format!("{} >= {:?}", self.__repr__().unwrap(), other)),
CompareOp::Lt => format!("{} < {:?}", self.__repr__(), other),
CompareOp::Le => format!("{} <= {:?}", self.__repr__(), other),
CompareOp::Eq => format!("{} == {:?}", self.__repr__(), other),
CompareOp::Ne => format!("{} != {:?}", self.__repr__(), other),
CompareOp::Gt => format!("{} > {:?}", self.__repr__(), other),
CompareOp::Ge => format!("{} >= {:?}", self.__repr__(), other),
}
}
}
Expand All @@ -366,16 +357,17 @@ struct RichComparisons2 {}

#[pyproto]
impl PyObjectProtocol for RichComparisons2 {
fn __repr__(&self) -> PyResult<&'static str> {
Ok("RC2")
fn __repr__(&self) -> &'static str {
"RC2"
}

fn __richcmp__(&self, _other: &PyAny, op: CompareOp) -> PyResult<PyObject> {
fn __richcmp__(&self, _other: &PyAny, op: CompareOp) -> PyObject {
let gil = GILGuard::acquire();
let py = gil.python();
match op {
CompareOp::Eq => Ok(true.to_object(gil.python())),
CompareOp::Ne => Ok(false.to_object(gil.python())),
_ => Ok(gil.python().NotImplemented()),
CompareOp::Eq => true.into_py(py),
CompareOp::Ne => false.into_py(py),
_ => py.NotImplemented(),
}
}
}
Expand Down

0 comments on commit a5e3d4e

Please sign in to comment.