diff --git a/src/biguint/division.rs b/src/biguint/division.rs index 030b185f..343705e1 100644 --- a/src/biguint/division.rs +++ b/src/biguint/division.rs @@ -41,6 +41,10 @@ fn div_half(rem: BigDigit, digit: BigDigit, divisor: BigDigit) -> (BigDigit, Big #[inline] pub(super) fn div_rem_digit(mut a: BigUint, b: BigDigit) -> (BigUint, BigDigit) { + if b == 0 { + panic!("attempt to divide by zero") + } + let mut rem = 0; if b <= big_digit::HALF { @@ -62,6 +66,10 @@ pub(super) fn div_rem_digit(mut a: BigUint, b: BigDigit) -> (BigUint, BigDigit) #[inline] fn rem_digit(a: &BigUint, b: BigDigit) -> BigDigit { + if b == 0 { + panic!("attempt to divide by zero") + } + let mut rem = 0; if b <= big_digit::HALF { diff --git a/tests/bigint_scalar.rs b/tests/bigint_scalar.rs index 485f2c5b..2a19fafb 100644 --- a/tests/bigint_scalar.rs +++ b/tests/bigint_scalar.rs @@ -1,8 +1,9 @@ use num_bigint::BigInt; use num_bigint::Sign::Plus; -use num_traits::{Signed, ToPrimitive, Zero}; +use num_traits::{One, Signed, ToPrimitive, Zero}; use std::ops::Neg; +use std::panic::catch_unwind; mod consts; use crate::consts::*; @@ -146,3 +147,11 @@ fn test_scalar_div_rem() { } } } + +#[test] +fn test_scalar_div_rem_zero() { + catch_unwind(|| BigInt::zero() / 0u32).unwrap_err(); + catch_unwind(|| BigInt::zero() % 0u32).unwrap_err(); + catch_unwind(|| BigInt::one() / 0u32).unwrap_err(); + catch_unwind(|| BigInt::one() % 0u32).unwrap_err(); +} diff --git a/tests/biguint_scalar.rs b/tests/biguint_scalar.rs index b6eadd9e..7c34f7ef 100644 --- a/tests/biguint_scalar.rs +++ b/tests/biguint_scalar.rs @@ -1,5 +1,7 @@ use num_bigint::BigUint; -use num_traits::{ToPrimitive, Zero}; +use num_traits::{One, ToPrimitive, Zero}; + +use std::panic::catch_unwind; mod consts; use crate::consts::*; @@ -111,3 +113,11 @@ fn test_scalar_div_rem() { } } } + +#[test] +fn test_scalar_div_rem_zero() { + catch_unwind(|| BigUint::zero() / 0u32).unwrap_err(); + catch_unwind(|| BigUint::zero() % 0u32).unwrap_err(); + catch_unwind(|| BigUint::one() / 0u32).unwrap_err(); + catch_unwind(|| BigUint::one() % 0u32).unwrap_err(); +}