diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bc8d6c4958..5dc19810d34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Added `PyTzInfoAccess`. [#2263](https://github.com/PyO3/pyo3/pull/2263) + ### Changed - Default to "m" ABI tag when choosing `libpython` link name for CPython 3.7 on Unix. [#2288](https://github.com/PyO3/pyo3/pull/2288) diff --git a/pytests/src/datetime.rs b/pytests/src/datetime.rs index f526ae0acfc..fc67b24718b 100644 --- a/pytests/src/datetime.rs +++ b/pytests/src/datetime.rs @@ -3,7 +3,7 @@ use pyo3::prelude::*; use pyo3::types::{ PyDate, PyDateAccess, PyDateTime, PyDelta, PyDeltaAccess, PyTime, PyTimeAccess, PyTuple, - PyTzInfo, + PyTzInfo, PyTzInfoAccess, }; #[pyfunction] @@ -179,6 +179,16 @@ fn datetime_from_timestamp<'p>( PyDateTime::from_timestamp(py, ts, tz) } +#[pyfunction] +fn get_datetime_tzinfo(dt: &PyDateTime) -> Option<&PyTzInfo> { + dt.get_tzinfo() +} + +#[pyfunction] +fn get_time_tzinfo(dt: &PyTime) -> Option<&PyTzInfo> { + dt.get_tzinfo() +} + #[pyclass(extends=PyTzInfo)] pub struct TzClass {} diff --git a/pytests/tests/test_datetime.py b/pytests/tests/test_datetime.py index e70504e758a..d4c1b60ec61 100644 --- a/pytests/tests/test_datetime.py +++ b/pytests/tests/test_datetime.py @@ -114,6 +114,7 @@ def test_time(args, kwargs): assert act == exp assert act.tzinfo is exp.tzinfo + assert rdt.get_time_tzinfo(act) == exp.tzinfo @given(t=st.times()) @@ -194,6 +195,7 @@ def test_datetime(args, kwargs): assert act == exp assert act.tzinfo is exp.tzinfo + assert rdt.get_datetime_tzinfo(act) == exp.tzinfo @given(dt=st.datetimes()) diff --git a/src/types/datetime.rs b/src/types/datetime.rs index c405301118e..68eaa82980d 100644 --- a/src/types/datetime.rs +++ b/src/types/datetime.rs @@ -5,6 +5,8 @@ use crate::err::PyResult; use crate::ffi; +#[cfg(PyPy)] +use crate::ffi::datetime::{PyDateTime_FromTimestamp, PyDate_FromTimestamp}; use crate::ffi::{ PyDateTime_CAPI, PyDateTime_FromTimestamp, PyDateTime_IMPORT, PyDate_FromTimestamp, }; @@ -22,6 +24,7 @@ use crate::ffi::{ PyDateTime_TIME_GET_HOUR, PyDateTime_TIME_GET_MICROSECOND, PyDateTime_TIME_GET_MINUTE, PyDateTime_TIME_GET_SECOND, }; +use crate::instance::PyNativeType; use crate::types::PyTuple; use crate::{AsPyPointer, PyAny, PyObject, Python, ToPyObject}; use std::os::raw::c_int; @@ -160,6 +163,16 @@ pub trait PyTimeAccess { fn get_fold(&self) -> bool; } +/// Trait for accessing the components of a struct containing a tzinfo. +pub trait PyTzInfoAccess { + /// Returns the tzinfo (which may be None). + /// + /// Implementations should conform to the upstream documentation: + /// + /// + fn get_tzinfo(&self) -> Option<&PyTzInfo>; +} + /// Bindings around `datetime.date` #[repr(transparent)] pub struct PyDate(PyAny); @@ -354,6 +367,19 @@ impl PyTimeAccess for PyDateTime { } } +impl PyTzInfoAccess for PyDateTime { + fn get_tzinfo(&self) -> Option<&PyTzInfo> { + let ptr = self.as_ptr() as *mut ffi::PyDateTime_DateTime; + unsafe { + if (*ptr).hastzinfo != 0 { + Some(self.py().from_borrowed_ptr((*ptr).tzinfo)) + } else { + None + } + } + } +} + /// Bindings for `datetime.time` #[repr(transparent)] pub struct PyTime(PyAny); @@ -439,6 +465,19 @@ impl PyTimeAccess for PyTime { } } +impl PyTzInfoAccess for PyTime { + fn get_tzinfo(&self) -> Option<&PyTzInfo> { + let ptr = self.as_ptr() as *mut ffi::PyDateTime_Time; + unsafe { + if (*ptr).hastzinfo != 0 { + Some(self.py().from_borrowed_ptr((*ptr).tzinfo)) + } else { + None + } + } + } +} + /// Bindings for `datetime.tzinfo` /// /// This is an abstract base class and should not be constructed directly. @@ -524,4 +563,33 @@ mod tests { assert!(b.unwrap().get_fold()); }); } + + #[cfg(not(PyPy))] + #[test] + fn test_get_tzinfo() { + crate::Python::with_gil(|py| { + use crate::conversion::ToPyObject; + use crate::types::{PyDateTime, PyTime, PyTzInfoAccess}; + + let datetime = py.import("datetime").map_err(|e| e.print(py)).unwrap(); + let timezone = datetime.getattr("timezone").unwrap(); + let utc = timezone.getattr("utc").unwrap().to_object(py); + + let dt = PyDateTime::new(py, 2018, 1, 1, 0, 0, 0, 0, Some(&utc)).unwrap(); + + assert!(dt.get_tzinfo().unwrap().eq(&utc).unwrap()); + + let dt = PyDateTime::new(py, 2018, 1, 1, 0, 0, 0, 0, None).unwrap(); + + assert!(dt.get_tzinfo().is_none()); + + let t = PyTime::new(py, 0, 0, 0, 0, Some(&utc)).unwrap(); + + assert!(t.get_tzinfo().unwrap().eq(&utc).unwrap()); + + let t = PyTime::new(py, 0, 0, 0, 0, None).unwrap(); + + assert!(t.get_tzinfo().is_none()); + }); + } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 192fa80a8fa..ad79d062b9a 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -11,6 +11,7 @@ pub use self::complex::PyComplex; #[cfg(not(Py_LIMITED_API))] pub use self::datetime::{ PyDate, PyDateAccess, PyDateTime, PyDelta, PyDeltaAccess, PyTime, PyTimeAccess, PyTzInfo, + PyTzInfoAccess, }; pub use self::dict::{IntoPyDict, PyDict}; pub use self::floatob::PyFloat;