Skip to content

Commit

Permalink
Some more refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Psykopear committed Sep 8, 2022
1 parent 852df9c commit 4a8cac8
Showing 1 changed file with 50 additions and 54 deletions.
104 changes: 50 additions & 54 deletions src/pyo3_extensions/chrono_tz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,6 @@ use serde::{de::SeqAccess, ser::SerializeTupleStruct, Serialize};
#[derive(Clone, Debug, PartialEq, Copy, PartialOrd)]
pub struct ChronoDateTime(pub(crate) DateTime<chrono_tz::Tz>);

impl Serialize for ChronoDateTime {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// Serialize naive local datetime and timezone name separately
let mut seq = serializer.serialize_tuple_struct("ChronoDateTime", 2)?;
seq.serialize_field(&self.0.naive_local())?;
seq.serialize_field(&self.0.timezone())?;
seq.end()
}
}

/// Utilities for the wrapper DateTime
impl ChronoDateTime {
pub(crate) fn max_utc() -> Self {
Expand Down Expand Up @@ -125,7 +112,9 @@ impl FromPyObject<'_> for ChronoDateTime {
// Python assumes that naive datetimes are in the local timezone, so this
// is not like assuming it was in UTC.
// https://docs.python.org/3/library/datetime.html#datetime.datetime.astimezone
// TODO: Should we instead assume utc, and not convert?
// TODO: Should we instead assume utc, and not convert? Here we are converting
// to the local timezone of the system running the dataflow, but the date
// could come from a different place, so whatever we do is probably wrong.
warn!("Received naive datetime! Converting to utc assuming local datetime");
let utc = ob
.py()
Expand All @@ -152,49 +141,56 @@ impl FromPyObject<'_> for ChronoDateTime {
}
}

// Implement serde Deserialization manually here
pub(crate) struct ChronoDateTimeVisitor;

impl<'de> serde::de::Visitor<'de> for ChronoDateTimeVisitor {
type Value = ChronoDateTime;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a ChronoDateTime")
}

fn visit_seq<V>(self, mut seq: V) -> Result<ChronoDateTime, V::Error>
where
V: SeqAccess<'de>,
{
// Serialize naive date time and timezone name separately
let naive_local: NaiveDateTime = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
let tz: chrono_tz::Tz = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
Ok(ChronoDateTime(
tz.from_local_datetime(&naive_local)
.single()
.ok_or_else(|| {
serde::de::Error::invalid_value(
serde::de::Unexpected::Other("ambiguous datetime"),
&self,
)
})?,
))
}

fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
// Implement serde (de)serialization manually here, since there is no standard
// way to serialize datetimes with IANA timezone info (yet?).
impl Serialize for ChronoDateTime {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
D: serde::Deserializer<'de>,
S: serde::Serializer,
{
deserializer.deserialize_newtype_struct("ChronoDateTime", Self)
// Serialize naive local datetime and timezone name separately
let mut seq = serializer.serialize_tuple_struct("ChronoDateTime", 2)?;
seq.serialize_field(&self.0.naive_local())?;
seq.serialize_field(&self.0.timezone())?;
seq.end()
}
}

impl<'de> serde::Deserialize<'de> for ChronoDateTime {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
struct ChronoDateTimeVisitor;

impl<'de> serde::de::Visitor<'de> for ChronoDateTimeVisitor {
type Value = ChronoDateTime;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a ChronoDateTime")
}

fn visit_seq<V>(self, mut seq: V) -> Result<ChronoDateTime, V::Error>
where
V: SeqAccess<'de>,
{
// Serialize naive date time and timezone name separately
let naive_local: NaiveDateTime = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
let tz: chrono_tz::Tz = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
Ok(ChronoDateTime(
tz.from_local_datetime(&naive_local)
.single()
.ok_or_else(|| {
serde::de::Error::invalid_value(
serde::de::Unexpected::Other("ambiguous datetime"),
&self,
)
})?,
))
}
}

deserializer.deserialize_tuple_struct("ChronoDateTime", 2, ChronoDateTimeVisitor)
}
}
Expand All @@ -205,10 +201,13 @@ fn test_serde_chronodatetime() {
use serde_test::assert_tokens;
use serde_test::Token;

// Go from `String` to `&'static str` by leaking String's memory,
// needed for the assertion.
fn string_to_static_str(s: String) -> &'static str {
Box::leak(s.into_boxed_str())
}

// Test (de)serialization of a specific date for all possible timezones in chrono_tz
for tz in chrono_tz::TZ_VARIANTS {
// First instantiate a specific datetime in each timezone
let dt = tz.ymd(2022, 1, 1).and_hms(20, 10, 0);
Expand All @@ -227,10 +226,7 @@ fn test_serde_chronodatetime() {
Token::TupleStructEnd,
];

// Now make a ChronoDateTime
let dt = ChronoDateTime(dt);

// This does a round-trip.
assert_tokens(&dt, &expected);
assert_tokens(&ChronoDateTime(dt), &expected);
}
}

0 comments on commit 4a8cac8

Please sign in to comment.