diff --git a/arrow-buffer/src/bigint.rs b/arrow-buffer/src/bigint.rs index 0b586bbcfb5..444e3121c9f 100644 --- a/arrow-buffer/src/bigint.rs +++ b/arrow-buffer/src/bigint.rs @@ -147,18 +147,31 @@ impl i256 { /// Performs wrapping multiplication #[inline] pub fn wrapping_mul(self, other: Self) -> Self { - let l = BigInt::from_signed_bytes_le(&self.to_le_bytes()); - let r = BigInt::from_signed_bytes_le(&other.to_le_bytes()); - Self::from_bigint_with_overflow(l * r).0 + let (low, high) = mulx(self.low, other.low); + + // Compute the high multiples, only impacting the high 128-bits + let hl = self.high.wrapping_mul(other.low as i128); + let lh = (self.low as i128).wrapping_mul(other.high); + + Self { + low, + high: (high as i128).wrapping_add(hl).wrapping_add(lh), + } } /// Performs checked multiplication #[inline] pub fn checked_mul(self, other: Self) -> Option { - let l = BigInt::from_signed_bytes_le(&self.to_le_bytes()); - let r = BigInt::from_signed_bytes_le(&other.to_le_bytes()); - let (val, overflow) = Self::from_bigint_with_overflow(l * r); - (!overflow).then(|| val) + let (low, high) = mulx(self.low, other.low); + + // Compute the high multiples, only impacting the high 128-bits + let hl = self.high.checked_mul(other.low as i128)?; + let lh = (self.low as i128).checked_mul(other.high)?; + + Some(Self { + low, + high: (high as i128).checked_add(hl)?.checked_add(lh)?, + }) } /// Performs wrapping division @@ -200,6 +213,40 @@ impl i256 { } } +#[inline] +pub fn mulx(a: u128, b: u128) -> (u128, u128) { + let split = |a: u128| (a & (u64::MAX as u128), a >> 64); + + const MASK: u128 = u64::MAX as _; + + let (a_low, a_high) = split(a); + let (b_low, b_high) = split(b); + + // Carry stores the upper 64-bits of low and lower 64-bits of high + let (mut low, mut carry) = split(a_low * b_low); + carry += a_high * b_low; + + // Update low and high with corresponding parts of carry + low += carry << 64; + let mut high = carry >> 64; + + // Update carry with overflow from low + carry = low >> 64; + low &= MASK; + + // Perform multiply including overflow from low + carry += b_high * a_low; + + // Update low and high with values from carry + low += carry << 64; + high += carry >> 64; + + // Perform 4th multiplication + high += a_high * b_high; + + (low, high) +} + #[cfg(test)] mod tests { use super::*; @@ -262,6 +309,18 @@ mod tests { true => assert!(checked.is_none()), false => assert_eq!(checked.unwrap(), actual), } + + // Multiplication + let actual = il.wrapping_mul(ir); + let (expected, overflow) = + i256::from_bigint_with_overflow(bl.clone() * br.clone()); + assert_eq!(actual.to_string(), expected.to_string()); + + let checked = il.checked_mul(ir); + match overflow { + true => assert!(checked.is_none()), + false => assert_eq!(checked.unwrap(), actual), + } } } }