diff --git a/p384/src/arithmetic/field.rs b/p384/src/arithmetic/field.rs index a10172df..03eef3eb 100644 --- a/p384/src/arithmetic/field.rs +++ b/p384/src/arithmetic/field.rs @@ -28,7 +28,7 @@ use self::field_impl::*; use crate::FieldBytes; use core::ops::{AddAssign, MulAssign, Neg, SubAssign}; use elliptic_curve::{ - bigint::{ArrayEncoding, Encoding, Integer, U384}, + bigint::{self, ArrayEncoding, Encoding, Integer, Limb, U384}, subtle::{Choice, ConstantTimeEq, ConstantTimeLess, CtOption}, }; @@ -52,10 +52,7 @@ impl_field_element!( fiat_p384_sub, fiat_p384_mul, fiat_p384_opp, - fiat_p384_square, - fiat_p384_divstep_precomp, - fiat_p384_divstep, - fiat_p384_msat, + fiat_p384_square ); impl FieldElement { @@ -72,6 +69,23 @@ impl FieldElement { self.to_be_bytes() } + /// Compute [`FieldElement`] inversion: `1 / self`. + pub fn invert(&self) -> CtOption { + let ret = impl_field_invert!( + self.to_canonical().to_uint_array(), + Self::ONE.0.to_uint_array(), + Limb::BIT_SIZE, + bigint::nlimbs!(U384::BIT_SIZE), + fiat_p384_mul, + fiat_p384_opp, + fiat_p384_divstep_precomp, + fiat_p384_divstep, + fiat_p384_msat, + fiat_p384_selectznz, + ); + CtOption::new(Self(ret.into()), !self.is_zero()) + } + /// Returns the square root of self mod p, or `None` if no square root /// exists. pub fn sqrt(&self) -> CtOption { diff --git a/p384/src/arithmetic/macros.rs b/p384/src/arithmetic/macros.rs index f4d28dda..57b8bb47 100644 --- a/p384/src/arithmetic/macros.rs +++ b/p384/src/arithmetic/macros.rs @@ -20,10 +20,13 @@ /// - `pub fn is_odd` /// - `pub fn is_zero` /// - `pub fn double` -/// - `pub fn invert` /// -/// NOTE: field implementations must provide their own inherent `pub fn sqrt` -/// method in order for the code generated by this macro to compile. +/// NOTE: field implementations must provide their own inherent impls of +/// the following methods in order for the code generated by this macro to +/// compile: +/// +/// - `pub fn invert` +/// - `pub fn sqrt` /// /// # Trait impls /// - `AsRef<$arr>` @@ -58,10 +61,7 @@ macro_rules! impl_field_element { $sub:ident, $mul:ident, $neg:ident, - $square:ident, - $divstep_precomp:ident, - $divstep:ident, - $msat:ident, + $square:ident ) => { impl $fe { /// Zero element. @@ -237,55 +237,6 @@ macro_rules! impl_field_element { Self(<$uint>::from_uint_array($neg(self.0.as_uint_array()))) } - /// Compute [` - #[doc = stringify!($fe)] - /// `] inversion: `1 / self`. - pub fn invert(&self) -> ::elliptic_curve::subtle::CtOption { - use ::elliptic_curve::{ - bigint::{Limb, LimbUInt as Word}, - subtle::ConditionallySelectable, - }; - - const LIMBS: usize = ::elliptic_curve::bigint::nlimbs!(<$uint>::BIT_SIZE); - const ITERATIONS: usize = (49 * <$uint>::BIT_SIZE + 57) / 17; - - let mut d = 1; - let mut f = $msat(); - let mut g: [Word; LIMBS + 1] = Default::default(); - g[..LIMBS].copy_from_slice(&$from_mont(self.as_ref())); - let mut r = <$arr>::from(Self::ONE.0); - let mut v = <$arr>::default(); - let mut i: usize = 0; - - while i < ITERATIONS - ITERATIONS % 2 { - let (out1, out2, out3, out4, out5) = $divstep(d, &f, &g, &v, &r); - let (out1, out2, out3, out4, out5) = $divstep(out1, &out2, &out3, &out4, &out5); - d = out1; - f = out2; - g = out3; - v = out4; - r = out5; - i += 2; - } - - if ITERATIONS % 2 != 0 { - let (_out1, out2, _out3, out4, _out5) = $divstep(d, &f, &g, &v, &r); - v = out4; - f = out2; - } - - let v_opp = <$uint>::from($neg(&v)); - let v = <$uint>::from(v); - - let s = ::elliptic_curve::subtle::Choice::from( - ((f[f.len() - 1] >> Limb::BIT_SIZE - 1) & 1) as u8, - ); - - let v = <$uint>::conditional_select(&v, &v_opp, s); - let ret = $mul(v.as_ref(), &$divstep_precomp()); - ::elliptic_curve::subtle::CtOption::new(Self(ret.into()), !self.is_zero()) - } - /// Compute modular square. #[must_use] pub const fn square(&self) -> Self { @@ -472,3 +423,51 @@ macro_rules! impl_field_op { } }; } + +/// Implement field element inversion. +macro_rules! impl_field_invert { + ( + $a:expr, + $one:expr, + $word_bits:expr, + $nlimbs:expr, + $mul:ident, + $neg:ident, + $divstep_precomp:ident, + $divstep:ident, + $msat:ident, + $selectznz:ident, + ) => {{ + const ITERATIONS: usize = (49 * $nlimbs * $word_bits + 57) / 17; + + let mut d = 1; + let mut f = $msat(); + let mut g = [0; $nlimbs + 1]; + let mut v = Default::default(); + let mut r = $one; + let mut i = 0; + + g[..$nlimbs].copy_from_slice($a.as_ref()); + + while i < ITERATIONS - ITERATIONS % 2 { + let (out1, out2, out3, out4, out5) = $divstep(d, &f, &g, &v, &r); + let (out1, out2, out3, out4, out5) = $divstep(out1, &out2, &out3, &out4, &out5); + d = out1; + f = out2; + g = out3; + v = out4; + r = out5; + i += 2; + } + + if ITERATIONS % 2 != 0 { + let (_out1, out2, _out3, out4, _out5) = $divstep(d, &f, &g, &v, &r); + v = out4; + f = out2; + } + + let s = ((f[f.len() - 1] >> $word_bits - 1) & 1) as u8; + let v = $selectznz(s, &v, &$neg(&v)); + $mul(&v, &$divstep_precomp()) + }}; +} diff --git a/p384/src/arithmetic/scalar.rs b/p384/src/arithmetic/scalar.rs index aa0b0ac2..50d850a4 100644 --- a/p384/src/arithmetic/scalar.rs +++ b/p384/src/arithmetic/scalar.rs @@ -15,7 +15,7 @@ use self::scalar_impl::*; use crate::{FieldBytes, NistP384, SecretKey, U384}; use core::ops::{AddAssign, MulAssign, Neg, SubAssign}; use elliptic_curve::{ - bigint::{ArrayEncoding, Encoding, Integer, Limb}, + bigint::{self, ArrayEncoding, Encoding, Integer, Limb}, ff::PrimeField, ops::Reduce, subtle::{ @@ -49,16 +49,30 @@ impl_field_element!( fiat_p384_scalar_sub, fiat_p384_scalar_mul, fiat_p384_scalar_opp, - fiat_p384_scalar_square, - fiat_p384_scalar_divstep_precomp, - fiat_p384_scalar_divstep, - fiat_p384_scalar_msat, + fiat_p384_scalar_square ); impl Scalar { /// `2^s` root of unity. pub const ROOT_OF_UNITY: Self = Self::from_be_hex("ffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52972"); + /// Compute [`Scalar`] inversion: `1 / self`. + pub fn invert(&self) -> CtOption { + let ret = impl_field_invert!( + self.to_canonical().to_uint_array(), + Self::ONE.0.to_uint_array(), + Limb::BIT_SIZE, + bigint::nlimbs!(U384::BIT_SIZE), + fiat_p384_scalar_mul, + fiat_p384_scalar_opp, + fiat_p384_scalar_divstep_precomp, + fiat_p384_scalar_divstep, + fiat_p384_scalar_msat, + fiat_p384_scalar_selectznz, + ); + CtOption::new(Self(ret.into()), !self.is_zero()) + } + /// Compute modular square root. pub fn sqrt(&self) -> CtOption { // p mod 4 = 3 -> compute sqrt(x) using x^((p+1)/4) =