diff --git a/elliptic-curve/src/ops.rs b/elliptic-curve/src/ops.rs index fb69b28fa..ffc4849a1 100644 --- a/elliptic-curve/src/ops.rs +++ b/elliptic-curve/src/ops.rs @@ -40,3 +40,21 @@ pub trait Reduce: Sized { Self::from_uint_reduced(UInt::from_le_byte_array(bytes)) } } + +/// Modular reduction to a non-zero output. +pub trait ReduceNonZero: Sized { + /// Perform a modular reduction, returning a field element. + fn from_uint_reduced_non_zero(n: UInt) -> Self; + + /// Interpret the given byte array as a big endian integer and perform a + /// modular reduction. + fn from_be_bytes_reduced(bytes: ByteArray) -> Self { + Self::from_uint_reduced_non_zero(UInt::from_be_byte_array(bytes)) + } + + /// Interpret the given byte array as a big endian integer and perform a + /// modular reduction. + fn from_le_bytes_reduced(bytes: ByteArray) -> Self { + Self::from_uint_reduced_non_zero(UInt::from_le_byte_array(bytes)) + } +} diff --git a/elliptic-curve/src/scalar/non_zero.rs b/elliptic-curve/src/scalar/non_zero.rs index 5e54f942d..ab5f9084a 100644 --- a/elliptic-curve/src/scalar/non_zero.rs +++ b/elliptic-curve/src/scalar/non_zero.rs @@ -3,7 +3,7 @@ use crate::{ bigint::Encoding as _, hex, - ops::Invert, + ops::{Invert, Reduce, ReduceNonZero}, rand_core::{CryptoRng, RngCore}, Curve, Error, FieldBytes, IsHigh, Result, Scalar, ScalarArithmetic, ScalarCore, SecretKey, }; @@ -12,6 +12,7 @@ use core::{ ops::{Deref, Neg}, str, }; +use crypto_bigint::{ArrayEncoding, Integer}; use ff::{Field, PrimeField}; use generic_array::GenericArray; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; @@ -198,6 +199,30 @@ where } } +impl Reduce for NonZeroScalar +where + C: Curve + ScalarArithmetic, + I: Integer + ArrayEncoding, + Scalar: ReduceNonZero, +{ + fn from_uint_reduced(n: I) -> Self { + Self::from_uint_reduced_non_zero(n) + } +} + +impl ReduceNonZero for NonZeroScalar +where + C: Curve + ScalarArithmetic, + I: Integer + ArrayEncoding, + Scalar: ReduceNonZero, +{ + fn from_uint_reduced_non_zero(n: I) -> Self { + let scalar = Scalar::::from_uint_reduced_non_zero(n); + debug_assert!(!bool::from(scalar.is_zero())); + Self::new(scalar).unwrap() + } +} + impl TryFrom<&[u8]> for NonZeroScalar where C: Curve + ScalarArithmetic,