diff --git a/arrow/src/util/decimal.rs b/arrow/src/util/decimal.rs index e5fa70d877a..67e68496326 100644 --- a/arrow/src/util/decimal.rs +++ b/arrow/src/util/decimal.rs @@ -19,7 +19,7 @@ use crate::error::{ArrowError, Result}; use num::bigint::BigInt; -use std::cmp::Ordering; +use std::cmp::{min, Ordering}; pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { /// The bit-width of the internal representation. @@ -36,6 +36,13 @@ pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { where Self: Sized, { + if precision < scale { + return Err(ArrowError::InvalidArgumentError(format!( + "Precision {} is less than scale {}", + precision, scale + ))); + } + if bytes.len() == Self::BIT_WIDTH / 8 { Ok(Self::new(precision, scale, bytes)) } else { @@ -64,6 +71,8 @@ pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { fn scale(&self) -> usize; /// Returns the string representation of the decimal. + /// If the string representation cannot be fitted with the precision of the decimal, + /// the string will be truncated. fn as_string(&self) -> String { let raw_bytes = self.raw_value(); let integer = BigInt::from_signed_bytes_le(raw_bytes); @@ -77,6 +86,11 @@ pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { if rest.len() > self.scale() { // Decimal separator is in the middle of the string + let mut bound = min(self.precision(), rest.len()); + if sign.len() == 1 { + bound += 1; + } + let value_str = value_str[0..bound].to_string(); let (whole, decimal) = value_str.split_at(value_str.len() - self.scale()); format!("{}.{}", whole, decimal) } else { @@ -99,7 +113,7 @@ pub struct Decimal128 { impl Decimal128 { /// Creates `Decimal128` from an `i128` value. - pub fn new_from_i128(precision: usize, scale: usize, value: i128) -> Self { + pub(crate) fn new_from_i128(precision: usize, scale: usize, value: i128) -> Self { Decimal128 { precision, scale, @@ -208,6 +222,13 @@ mod tests { assert_eq!(value.as_string(), "0.100"); } + #[test] + fn decimal_invalid_precision_scale() { + let bytes = 100_i128.to_le_bytes(); + let err = Decimal128::try_new_from_bytes(5, 6, &bytes); + assert!(err.is_err()); + } + #[test] fn decimal_128_from_bytes() { let mut bytes = 100_i128.to_le_bytes(); @@ -219,18 +240,24 @@ mod tests { assert_eq!(value.as_string(), "-0.01"); bytes = i128::MAX.to_le_bytes(); - let value = Decimal128::try_new_from_bytes(5, 2, &bytes).unwrap(); - assert_eq!( - value.as_string(), - "1701411834604692317316873037158841057.27" - ); + let value = Decimal128::try_new_from_bytes(38, 2, &bytes).unwrap(); + assert_eq!(value.as_string(), "170141183460469231731687303715884105.72"); bytes = i128::MIN.to_le_bytes(); - let value = Decimal128::try_new_from_bytes(5, 2, &bytes).unwrap(); + let value = Decimal128::try_new_from_bytes(38, 2, &bytes).unwrap(); assert_eq!( value.as_string(), - "-1701411834604692317316873037158841057.28" + "-170141183460469231731687303715884105.72" ); + + // Truncated + bytes = 12345_i128.to_le_bytes(); + let value = Decimal128::try_new_from_bytes(3, 2, &bytes).unwrap(); + assert_eq!(value.as_string(), "1.23"); + + bytes = (-12345_i128).to_le_bytes(); + let value = Decimal128::try_new_from_bytes(3, 2, &bytes).unwrap(); + assert_eq!(value.as_string(), "-1.23"); } #[test] @@ -241,7 +268,7 @@ mod tests { assert_eq!(value.as_string(), "1.00"); bytes[0..16].clone_from_slice(&i128::MAX.to_le_bytes()); - let value = Decimal256::try_new_from_bytes(5, 4, &bytes).unwrap(); + let value = Decimal256::try_new_from_bytes(40, 4, &bytes).unwrap(); assert_eq!( value.as_string(), "17014118346046923173168730371588410.5727"