Skip to content

Commit

Permalink
Remove uses of unwrap_u8
Browse files Browse the repository at this point in the history
Leverages the `From<Choice>` impl for `bool` where applicable instead,
which results in clearer logic which more closely matches `bool`.
  • Loading branch information
tarcieri committed May 29, 2023
1 parent 6c2233b commit d0fc864
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/constants.rs
Expand Up @@ -144,14 +144,14 @@ mod test {
let minus_one = FieldElement::MINUS_ONE;
let sqrt_m1_sq = &constants::SQRT_M1 * &constants::SQRT_M1;
assert_eq!(minus_one, sqrt_m1_sq);
assert_eq!(constants::SQRT_M1.is_negative().unwrap_u8(), 0);
assert!(bool::from(!constants::SQRT_M1.is_negative()));
}

#[test]
fn test_sqrt_constants_sign() {
let minus_one = FieldElement::MINUS_ONE;
let (was_nonzero_square, invsqrt_m1) = minus_one.invsqrt();
assert_eq!(was_nonzero_square.unwrap_u8(), 1u8);
assert!(bool::from(was_nonzero_square));
let sign_test_sqrt = &invsqrt_m1 * &constants::SQRT_M1;
assert_eq!(sign_test_sqrt, minus_one);
}
Expand Down
8 changes: 5 additions & 3 deletions src/edwards.rs
Expand Up @@ -203,7 +203,9 @@ impl CompressedEdwardsY {
let v = &(&YY * &constants::EDWARDS_D) + &Z; // v = dy²+1
let (is_valid_y_coord, mut X) = FieldElement::sqrt_ratio_i(&u, &v);

if is_valid_y_coord.unwrap_u8() != 1u8 { return None; }
if (!is_valid_y_coord).into() {
return None;
}

// FieldElement::sqrt_ratio_i always returns the nonnegative square root,
// so we negate according to the supplied sign bit.
Expand Down Expand Up @@ -466,7 +468,7 @@ impl ConstantTimeEq for EdwardsPoint {

impl PartialEq for EdwardsPoint {
fn eq(&self, other: &EdwardsPoint) -> bool {
self.ct_eq(other).unwrap_u8() == 1u8
self.ct_eq(other).into()
}
}

Expand Down Expand Up @@ -1406,7 +1408,7 @@ mod test {
Z: FieldElement::from_bytes(&two_bytes),
T: FieldElement::ZERO,
};
assert_eq!(id1.ct_eq(&id2).unwrap_u8(), 1u8);
assert!(bool::from(id1.ct_eq(&id2)));
}

/// Sanity check for conversion to precomputed points
Expand Down
24 changes: 12 additions & 12 deletions src/field.rs
Expand Up @@ -86,7 +86,7 @@ impl Eq for FieldElement {}

impl PartialEq for FieldElement {
fn eq(&self, other: &FieldElement) -> bool {
self.ct_eq(other).unwrap_u8() == 1u8
self.ct_eq(other).into()
}
}

Expand Down Expand Up @@ -187,7 +187,7 @@ impl FieldElement {
}

// acc is nonzero because we skipped zeros in inputs
assert_eq!(acc.is_zero().unwrap_u8(), 0);
assert!(bool::from(!acc.is_zero()));

// Compute the inverse of all products
acc = acc.invert();
Expand Down Expand Up @@ -406,33 +406,33 @@ mod test {

// 0/0 should return (1, 0) since u is 0
let (choice, sqrt) = FieldElement::sqrt_ratio_i(&zero, &zero);
assert_eq!(choice.unwrap_u8(), 1);
assert!(bool::from(choice));
assert_eq!(sqrt, zero);
assert_eq!(sqrt.is_negative().unwrap_u8(), 0);
assert!(bool::from(!sqrt.is_negative()));

// 1/0 should return (0, 0) since v is 0, u is nonzero
let (choice, sqrt) = FieldElement::sqrt_ratio_i(&one, &zero);
assert_eq!(choice.unwrap_u8(), 0);
assert!(bool::from(!choice));
assert_eq!(sqrt, zero);
assert_eq!(sqrt.is_negative().unwrap_u8(), 0);
assert!(bool::from(!sqrt.is_negative()));

// 2/1 is nonsquare, so we expect (0, sqrt(i*2))
let (choice, sqrt) = FieldElement::sqrt_ratio_i(&two, &one);
assert_eq!(choice.unwrap_u8(), 0);
assert!(bool::from(!choice));
assert_eq!(sqrt.square(), &two * &i);
assert_eq!(sqrt.is_negative().unwrap_u8(), 0);
assert!(bool::from(!sqrt.is_negative()));

// 4/1 is square, so we expect (1, sqrt(4))
let (choice, sqrt) = FieldElement::sqrt_ratio_i(&four, &one);
assert_eq!(choice.unwrap_u8(), 1);
assert!(bool::from(choice));
assert_eq!(sqrt.square(), four);
assert_eq!(sqrt.is_negative().unwrap_u8(), 0);
assert!(bool::from(!sqrt.is_negative()));

// 1/4 is square, so we expect (1, 1/sqrt(4))
let (choice, sqrt) = FieldElement::sqrt_ratio_i(&one, &four);
assert_eq!(choice.unwrap_u8(), 1);
assert!(bool::from(choice));
assert_eq!(&sqrt.square() * &four, one);
assert_eq!(sqrt.is_negative().unwrap_u8(), 0);
assert!(bool::from(!sqrt.is_negative()));
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion src/montgomery.rs
Expand Up @@ -86,7 +86,7 @@ impl ConstantTimeEq for MontgomeryPoint {

impl PartialEq for MontgomeryPoint {
fn eq(&self, other: &MontgomeryPoint) -> bool {
self.ct_eq(other).unwrap_u8() == 1u8
self.ct_eq(other).into()
}
}

Expand Down
12 changes: 6 additions & 6 deletions src/ristretto.rs
Expand Up @@ -274,10 +274,10 @@ impl CompressedRistretto {

let s = FieldElement::from_bytes(self.as_bytes());
let s_bytes_check = s.as_bytes();
let s_encoding_is_canonical = &s_bytes_check[..].ct_eq(self.as_bytes());
let s_encoding_is_canonical = s_bytes_check[..].ct_eq(self.as_bytes());
let s_is_negative = s.is_negative();

if s_encoding_is_canonical.unwrap_u8() == 0u8 || s_is_negative.unwrap_u8() == 1u8 {
if (!s_encoding_is_canonical).into() || s_is_negative.into() {
return None;
}

Expand Down Expand Up @@ -307,9 +307,9 @@ impl CompressedRistretto {
// t == ((1+as²) sqrt(4s²/(ad(1+as²)² - (1-as²)²)))/(1-as²)
let t = &x * &y;

if ok.unwrap_u8() == 0u8
|| t.is_negative().unwrap_u8() == 1u8
|| y.is_zero().unwrap_u8() == 1u8
if (!ok).into()
|| t.is_negative().into()
|| y.is_zero().into()
{
None
} else {
Expand Down Expand Up @@ -809,7 +809,7 @@ impl Default for RistrettoPoint {

impl PartialEq for RistrettoPoint {
fn eq(&self, other: &RistrettoPoint) -> bool {
self.ct_eq(other).unwrap_u8() == 1u8
self.ct_eq(other).into()
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/scalar.rs
Expand Up @@ -287,7 +287,7 @@ impl Debug for Scalar {
impl Eq for Scalar {}
impl PartialEq for Scalar {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).unwrap_u8() == 1u8
self.ct_eq(other).into()
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/traits.rs
Expand Up @@ -44,7 +44,7 @@ where
T: subtle::ConstantTimeEq + Identity,
{
fn is_identity(&self) -> bool {
self.ct_eq(&T::identity()).unwrap_u8() == 1u8
self.ct_eq(&T::identity()).into()
}
}

Expand Down

0 comments on commit d0fc864

Please sign in to comment.