Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #438: Check for overflow in Duration and Timestamp processing #439

Merged
280 changes: 264 additions & 16 deletions prost-types/src/lib.rs
Expand Up @@ -26,6 +26,7 @@ pub mod compiler {
// are defined in both directions.

const NANOS_PER_SECOND: i32 = 1_000_000_000;
const NANOS_MAX: i32 = NANOS_PER_SECOND - 1;

impl Duration {
/// Normalizes the duration to a canonical format.
Expand All @@ -35,17 +36,42 @@ impl Duration {
pub fn normalize(&mut self) {
// Make sure nanos is in the range.
if self.nanos <= -NANOS_PER_SECOND || self.nanos >= NANOS_PER_SECOND {
self.seconds += (self.nanos / NANOS_PER_SECOND) as i64;
self.nanos %= NANOS_PER_SECOND;
if let Some(seconds) = self
.seconds
.checked_add((self.nanos / NANOS_PER_SECOND) as i64)
{
self.seconds = seconds;
self.nanos %= NANOS_PER_SECOND;
} else if self.nanos < 0 {
// Negative overflow! Set to the least normal value.
self.seconds = i64::MIN;
self.nanos = -NANOS_MAX;
} else {
// Positive overflow! Set to the greatest normal value.
self.seconds = i64::MAX;
self.nanos = NANOS_MAX;
}
}

// nanos should have the same sign as seconds.
if self.seconds < 0 && self.nanos > 0 {
self.seconds += 1;
self.nanos -= NANOS_PER_SECOND;
if let Some(seconds) = self.seconds.checked_add(1) {
self.seconds = seconds;
self.nanos -= NANOS_PER_SECOND;
} else {
// Positive overflow! Set to the greatest normal value.
debug_assert_eq!(self.seconds, i64::MAX);
self.nanos = NANOS_MAX;
}
} else if self.seconds > 0 && self.nanos < 0 {
self.seconds -= 1;
self.nanos += NANOS_PER_SECOND;
if let Some(seconds) = self.seconds.checked_sub(1) {
self.seconds = seconds;
self.nanos += NANOS_PER_SECOND;
} else {
// Negative overflow! Set to the least normal value.
debug_assert_eq!(self.seconds, i64::MIN);
self.nanos = -NANOS_MAX;
}
}
// TODO: should this be checked?
// debug_assert!(self.seconds >= -315_576_000_000 && self.seconds <= 315_576_000_000,
Expand Down Expand Up @@ -104,14 +130,33 @@ impl Timestamp {
pub fn normalize(&mut self) {
// Make sure nanos is in the range.
if self.nanos <= -NANOS_PER_SECOND || self.nanos >= NANOS_PER_SECOND {
self.seconds += (self.nanos / NANOS_PER_SECOND) as i64;
self.nanos %= NANOS_PER_SECOND;
if let Some(seconds) = self
.seconds
.checked_add((self.nanos / NANOS_PER_SECOND) as i64)
{
self.seconds = seconds;
self.nanos %= NANOS_PER_SECOND;
} else if self.nanos < 0 {
// Negative overflow! Set to the earliest normal value.
self.seconds = i64::MIN;
self.nanos = 0;
} else {
// Positive overflow! Set to the latest normal value.
self.seconds = i64::MAX;
self.nanos = 999_999_999;
}
}

// For Timestamp nanos should be in the range [0, 999999999].
if self.nanos < 0 {
self.seconds -= 1;
self.nanos += NANOS_PER_SECOND;
if let Some(seconds) = self.seconds.checked_sub(1) {
self.seconds = seconds;
self.nanos += NANOS_PER_SECOND;
} else {
// Negative overflow! Set to the earliest normal value.
debug_assert_eq!(self.seconds, i64::MIN);
self.nanos = 0;
}
}

// TODO: should this be checked?
Expand Down Expand Up @@ -143,17 +188,56 @@ impl From<std::time::SystemTime> for Timestamp {
}
}

/// Indicates that a [`Timestamp`] could not be converted to
/// [`SystemTime`][std::time::SystemTime] because it is out of range.
///
/// The range of times that can be represented by `SystemTime` depends on the platform.
/// All `Timestamp`s are likely representable on 64-bit Unix-like platforms, but
/// other platforms, such as Windows and 32-bit Linux, may not be able to represent
/// the full range of `Timestamp`s.
#[cfg(feature = "std")]
#[derive(Debug)]
#[non_exhaustive]
pub struct TimestampOutOfSystemRangeError {
pub timestamp: Timestamp,
}

#[cfg(feature = "std")]
impl core::fmt::Display for TimestampOutOfSystemRangeError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"{:?} is not representable as a `SystemTime` because it is out of range",
self
)
}
}

#[cfg(feature = "std")]
impl std::error::Error for TimestampOutOfSystemRangeError {}

#[cfg(feature = "std")]
impl From<Timestamp> for std::time::SystemTime {
fn from(mut timestamp: Timestamp) -> std::time::SystemTime {
impl TryFrom<Timestamp> for std::time::SystemTime {
type Error = TimestampOutOfSystemRangeError;

fn try_from(mut timestamp: Timestamp) -> Result<std::time::SystemTime, Self::Error> {
let orig_timestamp = timestamp.clone();
timestamp.normalize();

let system_time = if timestamp.seconds >= 0 {
std::time::UNIX_EPOCH + time::Duration::from_secs(timestamp.seconds as u64)
std::time::UNIX_EPOCH.checked_add(time::Duration::from_secs(timestamp.seconds as u64))
} else {
std::time::UNIX_EPOCH - time::Duration::from_secs((-timestamp.seconds) as u64)
std::time::UNIX_EPOCH
.checked_sub(time::Duration::from_secs((-timestamp.seconds) as u64))
};

system_time + time::Duration::from_nanos(timestamp.nanos as u64)
let system_time = system_time.and_then(|system_time| {
system_time.checked_add(time::Duration::from_nanos(timestamp.nanos as u64))
});

system_time.ok_or(TimestampOutOfSystemRangeError {
timestamp: orig_timestamp,
})
}
}

Expand All @@ -171,7 +255,19 @@ mod tests {
fn check_system_time_roundtrip(
system_time in SystemTime::arbitrary(),
) {
prop_assert_eq!(SystemTime::from(Timestamp::from(system_time)), system_time);
prop_assert_eq!(SystemTime::try_from(Timestamp::from(system_time)).unwrap(), system_time);
}

#[test]
fn check_timestamp_roundtrip_via_system_time(
seconds in i64::arbitrary(),
nanos in i32::arbitrary(),
) {
let mut timestamp = Timestamp { seconds, nanos };
timestamp.normalize();
if let Ok(system_time) = SystemTime::try_from(timestamp.clone()) {
prop_assert_eq!(Timestamp::from(system_time), timestamp);
}
}
}

Expand Down Expand Up @@ -243,4 +339,156 @@ mod tests {
}
);
}

#[test]
fn check_duration_normalize() {
#[rustfmt::skip] // Don't mangle the table formatting.
let cases = [
// --- Table of test cases ---
// test seconds test nanos expected seconds expected nanos
(line!(), 0, 0, 0, 0),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've never seen this line!() trick before - it is absolutely brilliant.

(line!(), 1, 1, 1, 1),
(line!(), -1, -1, -1, -1),
(line!(), 0, 999_999_999, 0, 999_999_999),
(line!(), 0, -999_999_999, 0, -999_999_999),
(line!(), 0, 1_000_000_000, 1, 0),
(line!(), 0, -1_000_000_000, -1, 0),
(line!(), 0, 1_000_000_001, 1, 1),
(line!(), 0, -1_000_000_001, -1, -1),
(line!(), -1, 1, 0, -999_999_999),
(line!(), 1, -1, 0, 999_999_999),
(line!(), -1, 1_000_000_000, 0, 0),
(line!(), 1, -1_000_000_000, 0, 0),
(line!(), i64::MIN , 0, i64::MIN , 0),
(line!(), i64::MIN + 1, 0, i64::MIN + 1, 0),
(line!(), i64::MIN , 1, i64::MIN + 1, -999_999_999),
(line!(), i64::MIN , 1_000_000_000, i64::MIN + 1, 0),
(line!(), i64::MIN , -1_000_000_000, i64::MIN , -999_999_999),
(line!(), i64::MIN + 1, -1_000_000_000, i64::MIN , 0),
(line!(), i64::MIN + 2, -1_000_000_000, i64::MIN + 1, 0),
(line!(), i64::MIN , -1_999_999_998, i64::MIN , -999_999_999),
(line!(), i64::MIN + 1, -1_999_999_998, i64::MIN , -999_999_998),
(line!(), i64::MIN + 2, -1_999_999_998, i64::MIN + 1, -999_999_998),
(line!(), i64::MIN , -1_999_999_999, i64::MIN , -999_999_999),
(line!(), i64::MIN + 1, -1_999_999_999, i64::MIN , -999_999_999),
(line!(), i64::MIN + 2, -1_999_999_999, i64::MIN + 1, -999_999_999),
(line!(), i64::MIN , -2_000_000_000, i64::MIN , -999_999_999),
(line!(), i64::MIN + 1, -2_000_000_000, i64::MIN , -999_999_999),
(line!(), i64::MIN + 2, -2_000_000_000, i64::MIN , 0),
(line!(), i64::MIN , -999_999_998, i64::MIN , -999_999_998),
(line!(), i64::MIN + 1, -999_999_998, i64::MIN + 1, -999_999_998),
(line!(), i64::MAX , 0, i64::MAX , 0),
(line!(), i64::MAX - 1, 0, i64::MAX - 1, 0),
(line!(), i64::MAX , -1, i64::MAX - 1, 999_999_999),
(line!(), i64::MAX , 1_000_000_000, i64::MAX , 999_999_999),
(line!(), i64::MAX - 1, 1_000_000_000, i64::MAX , 0),
(line!(), i64::MAX - 2, 1_000_000_000, i64::MAX - 1, 0),
(line!(), i64::MAX , 1_999_999_998, i64::MAX , 999_999_999),
(line!(), i64::MAX - 1, 1_999_999_998, i64::MAX , 999_999_998),
(line!(), i64::MAX - 2, 1_999_999_998, i64::MAX - 1, 999_999_998),
(line!(), i64::MAX , 1_999_999_999, i64::MAX , 999_999_999),
(line!(), i64::MAX - 1, 1_999_999_999, i64::MAX , 999_999_999),
(line!(), i64::MAX - 2, 1_999_999_999, i64::MAX - 1, 999_999_999),
(line!(), i64::MAX , 2_000_000_000, i64::MAX , 999_999_999),
(line!(), i64::MAX - 1, 2_000_000_000, i64::MAX , 999_999_999),
(line!(), i64::MAX - 2, 2_000_000_000, i64::MAX , 0),
(line!(), i64::MAX , 999_999_998, i64::MAX , 999_999_998),
(line!(), i64::MAX - 1, 999_999_998, i64::MAX - 1, 999_999_998),
];

for case in cases.iter() {
let mut test_duration = crate::Duration {
seconds: case.1,
nanos: case.2,
};
test_duration.normalize();

assert_eq!(
test_duration,
crate::Duration {
seconds: case.3,
nanos: case.4,
},
"test case on line {} doesn't match",
case.0,
);
}
}

#[cfg(feature = "std")]
#[test]
fn check_timestamp_normalize() {
// Make sure that `Timestamp::normalize` behaves correctly on and near overflow.
#[rustfmt::skip] // Don't mangle the table formatting.
let cases = [
// --- Table of test cases ---
// test seconds test nanos expected seconds expected nanos
(line!(), 0, 0, 0, 0),
(line!(), 1, 1, 1, 1),
(line!(), -1, -1, -2, 999_999_999),
(line!(), 0, 999_999_999, 0, 999_999_999),
(line!(), 0, -999_999_999, -1, 1),
(line!(), 0, 1_000_000_000, 1, 0),
(line!(), 0, -1_000_000_000, -1, 0),
(line!(), 0, 1_000_000_001, 1, 1),
(line!(), 0, -1_000_000_001, -2, 999_999_999),
(line!(), -1, 1, -1, 1),
(line!(), 1, -1, 0, 999_999_999),
(line!(), -1, 1_000_000_000, 0, 0),
(line!(), 1, -1_000_000_000, 0, 0),
(line!(), i64::MIN , 0, i64::MIN , 0),
(line!(), i64::MIN + 1, 0, i64::MIN + 1, 0),
(line!(), i64::MIN , 1, i64::MIN , 1),
(line!(), i64::MIN , 1_000_000_000, i64::MIN + 1, 0),
(line!(), i64::MIN , -1_000_000_000, i64::MIN , 0),
(line!(), i64::MIN + 1, -1_000_000_000, i64::MIN , 0),
(line!(), i64::MIN + 2, -1_000_000_000, i64::MIN + 1, 0),
(line!(), i64::MIN , -1_999_999_998, i64::MIN , 0),
(line!(), i64::MIN + 1, -1_999_999_998, i64::MIN , 0),
(line!(), i64::MIN + 2, -1_999_999_998, i64::MIN , 2),
(line!(), i64::MIN , -1_999_999_999, i64::MIN , 0),
(line!(), i64::MIN + 1, -1_999_999_999, i64::MIN , 0),
(line!(), i64::MIN + 2, -1_999_999_999, i64::MIN , 1),
(line!(), i64::MIN , -2_000_000_000, i64::MIN , 0),
(line!(), i64::MIN + 1, -2_000_000_000, i64::MIN , 0),
(line!(), i64::MIN + 2, -2_000_000_000, i64::MIN , 0),
(line!(), i64::MIN , -999_999_998, i64::MIN , 0),
(line!(), i64::MIN + 1, -999_999_998, i64::MIN , 2),
(line!(), i64::MAX , 0, i64::MAX , 0),
(line!(), i64::MAX - 1, 0, i64::MAX - 1, 0),
(line!(), i64::MAX , -1, i64::MAX - 1, 999_999_999),
(line!(), i64::MAX , 1_000_000_000, i64::MAX , 999_999_999),
(line!(), i64::MAX - 1, 1_000_000_000, i64::MAX , 0),
(line!(), i64::MAX - 2, 1_000_000_000, i64::MAX - 1, 0),
(line!(), i64::MAX , 1_999_999_998, i64::MAX , 999_999_999),
(line!(), i64::MAX - 1, 1_999_999_998, i64::MAX , 999_999_998),
(line!(), i64::MAX - 2, 1_999_999_998, i64::MAX - 1, 999_999_998),
(line!(), i64::MAX , 1_999_999_999, i64::MAX , 999_999_999),
(line!(), i64::MAX - 1, 1_999_999_999, i64::MAX , 999_999_999),
(line!(), i64::MAX - 2, 1_999_999_999, i64::MAX - 1, 999_999_999),
(line!(), i64::MAX , 2_000_000_000, i64::MAX , 999_999_999),
(line!(), i64::MAX - 1, 2_000_000_000, i64::MAX , 999_999_999),
(line!(), i64::MAX - 2, 2_000_000_000, i64::MAX , 0),
(line!(), i64::MAX , 999_999_998, i64::MAX , 999_999_998),
(line!(), i64::MAX - 1, 999_999_998, i64::MAX - 1, 999_999_998),
];

for case in cases.iter() {
let mut test_timestamp = crate::Timestamp {
seconds: case.1,
nanos: case.2,
};
test_timestamp.normalize();

assert_eq!(
test_timestamp,
crate::Timestamp {
seconds: case.3,
nanos: case.4,
},
"test case on line {} doesn't match",
case.0,
);
}
}
}