Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Secp256r1 / P256 verification in solidity #4881

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/odd-lobsters-wash.md
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`P256`: Add a library for verification/recovery of Secp256r1 (Aka P256) signatures.
2 changes: 2 additions & 0 deletions .github/workflows/checks.yml
Expand Up @@ -66,6 +66,8 @@ jobs:
run: bash scripts/upgradeable/transpile.sh
- name: Run tests
run: npm run test
env:
UNLIMITED: true
- name: Check linearisation of the inheritance graph
run: npm run test:inheritance
- name: Check storage layout
Expand Down
1 change: 1 addition & 0 deletions contracts/mocks/Stateless.sol
Expand Up @@ -24,6 +24,7 @@ import {ERC721Holder} from "../token/ERC721/utils/ERC721Holder.sol";
import {Math} from "../utils/math/Math.sol";
import {MerkleProof} from "../utils/cryptography/MerkleProof.sol";
import {MessageHashUtils} from "../utils/cryptography/MessageHashUtils.sol";
import {P256} from "../utils/cryptography/P256.sol";
import {SafeCast} from "../utils/math/SafeCast.sol";
import {SafeERC20} from "../token/ERC20/utils/SafeERC20.sol";
import {ShortStrings} from "../utils/ShortStrings.sol";
Expand Down
316 changes: 316 additions & 0 deletions contracts/utils/cryptography/P256.sol
@@ -0,0 +1,316 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.20;

import {Math} from "../math/Math.sol";

/**
* @dev Implementation of secp256r1 verification and recovery functions.
*
* Based on
* - https://github.com/itsobvioustech/A-passkeys-wallet/blob/main/src/Secp256r1.sol
* Which is heavily inspired from
* - https://github.com/maxrobot/elliptic-solidity/blob/master/contracts/Secp256r1.sol
* - https://github.com/tdrerup/elliptic-curve-solidity/blob/master/contracts/curves/EllipticCurve.sol
*/
library P256 {
struct JPoint {
uint256 x;
uint256 y;
uint256 z;
}

/// @dev Generator (x component)
uint256 internal constant GX = 0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296;
/// @dev Generator (y component)
uint256 internal constant GY = 0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5;
/// @dev P (size of the field)
uint256 internal constant P = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF;
/// @dev N (order of G)
uint256 internal constant N = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551;
/// @dev A parameter of the weierstrass equation
uint256 internal constant A = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC;
/// @dev B parameter of the weierstrass equation
uint256 internal constant B = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B;
Comment on lines +22 to +33
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These match with the parameters listed here


uint256 private constant P2 = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFD;
uint256 private constant N2 = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC63254F;
uint256 private constant P1DIV4 = 0x3fffffffc0000000400000000000000000000000400000000000000000000000;

/**
* @dev signature verification
* @param h - hashed message
* @param r - signature half R
* @param s - signature half S
* @param qx - public key coordinate X
* @param qy - public key coordinate Y
*/
function verify(uint256 h, uint256 r, uint256 s, uint256 qx, uint256 qy) internal view returns (bool) {
if (r == 0 || r >= N || s == 0 || s >= N || !isOnCurve(qx, qy)) return false;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most sources mention that if r == 0 or s == 0 during signature generation then the signature should be regenerated with a different nonce. However, I haven't found exactly why is it critical to check for both.

I get that operating with r == 0 or s == 0 breaks point addition, but I haven't found how that could be used maliciously. Would appreciate if you point me out to a source if you have one.

Copy link
Collaborator Author

@Amxx Amxx Apr 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not generating a signature, we are verifying one. As you saif, it should not be generated that way, and if it is, we should reject it.

I'm not exactly sure how that could be used maliciously. I think the point is more like: we know this cannot possibly be a valid input, so we should reject it. We should count on the caller doing the sanity check.

Can we prove that is r (or s) is 0, then the function will return false without reverting? Maybe.

  • If s = 0, we get w = 0, which is not actually an inverse. That gives u1 = 0.
  • If r = 0, we get u2 = 0.

In _jMultShamir, it is unclear to me that any of there value being 0 can be treated in a specific way. If both are 0, then the end point is (0, 0) and we might get "true".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested that. If you remove the zero check, then this test passes:

    /// forge-config: default.fuzz.runs = 512
    function testVerifyZero(uint256 seed, bytes32 digest) public {
        uint256 privateKey = bound(uint256(keccak256(abi.encode(seed))), 1, P256.N - 1);

        (uint256 x, uint256 y) = P256.getPublicKey(privateKey);
        assertTrue(P256.verify(uint256(digest), 0, 0, x, y));
    }

Said otherwize: without this check r=0, s=0 would be a valid signature that would be verified for any h and any qx, qy.


JPoint[16] memory points = _preComputeJacobianPoints(qx, qy);
uint256 w = _invModN(s);
uint256 u1 = mulmod(h, w, N);
uint256 u2 = mulmod(r, w, N);
(uint256 x, ) = _jMultShamir(points, u1, u2);
return (x == r);
}

/**
* @dev public key recovery
* @param h - hashed message
* @param v - signature recovery param
* @param r - signature half R
* @param s - signature half S
*/
function recovery(uint256 h, uint8 v, uint256 r, uint256 s) internal view returns (uint256, uint256) {
if (r == 0 || r >= N || s == 0 || s >= N || v > 1) return (0, 0);

uint256 rx = r;
uint256 ry2 = addmod(mulmod(addmod(mulmod(rx, rx, P), A, P), rx, P), B, P); // weierstrass equation y² = x³ + a.x + b
uint256 ry = Math.modExp(ry2, P1DIV4, P); // This formula for sqrt work because P ≡ 3 (mod 4)
if (mulmod(ry, ry, P) != ry2) return (0, 0); // Sanity check
if (ry % 2 != v % 2) ry = P - ry;

JPoint[16] memory points = _preComputeJacobianPoints(rx, ry);
uint256 w = _invModN(r);
uint256 u1 = mulmod(N - (h % N), w, N);
uint256 u2 = mulmod(s, w, N);
(uint256 x, uint256 y) = _jMultShamir(points, u1, u2);
return (x, y);
}

/**
* @dev address recovery
* @param h - hashed message
* @param v - signature recovery param
* @param r - signature half R
* @param s - signature half S
*/
function recoveryAddress(uint256 h, uint8 v, uint256 r, uint256 s) internal view returns (address) {
(uint256 qx, uint256 qy) = recovery(h, v, r, s);
return getAddress(qx, qy);
}

/**
* @dev derivate public key
* @param privateKey - private key
*/
function getPublicKey(uint256 privateKey) internal view returns (uint256, uint256) {
(uint256 x, uint256 y, uint256 z) = _jMult(GX, GY, 1, privateKey);
return _affineFromJacobian(x, y, z);
}

/**
* @dev Hash public key into an address
* @param qx - public key coordinate X
* @param qy - public key coordinate Y
*/
function getAddress(uint256 qx, uint256 qy) internal pure returns (address result) {
/// @solidity memory-safe-assembly
assembly {
mstore(0x00, qx)
mstore(0x20, qy)
result := keccak256(0x00, 0x40)
}
}

/**
* @dev check if a point is on the curve.
*/
function isOnCurve(uint256 x, uint256 y) internal pure returns (bool result) {
/// @solidity memory-safe-assembly
assembly {
let p := P
let lhs := mulmod(y, y, p)
let rhs := addmod(mulmod(addmod(mulmod(x, x, p), A, p), x, p), B, p)
result := eq(lhs, rhs)
}
}

/**
* @dev Reduce from jacobian to affine coordinates
* @param jx - jacobian coordinate x
* @param jy - jacobian coordinate y
* @param jz - jacobian coordinate z
* @return ax - affine coordinate x
* @return ay - affine coordinate y
*/
function _affineFromJacobian(uint256 jx, uint256 jy, uint256 jz) private view returns (uint256 ax, uint256 ay) {
if (jz == 0) return (0, 0);
uint256 zinv = _invModP(jz);
uint256 zzinv = mulmod(zinv, zinv, P);
uint256 zzzinv = mulmod(zzinv, zinv, P);
ax = mulmod(jx, zzinv, P);
ay = mulmod(jy, zzzinv, P);
}

/**
* @dev Point addition on the jacobian coordinates
* https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates
*/
function _jAdd(
uint256 x1,
uint256 y1,
uint256 z1,
uint256 x2,
uint256 y2,
uint256 z2
) private pure returns (uint256 x3, uint256 y3, uint256 z3) {
if (z1 == 0) {
return (x2, y2, z2);
}
if (z2 == 0) {
return (x1, y1, z1);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this because adding some coordinates to the point at infinity results in the same coordinates?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the record:

Here the check is so that (0,0,0) is a "neutral element" for the jacobian addition. We could make the check as

        if (x1 == 0 && y1 == 0 && z1 == 0) {
            return (x2, y2, z2);
        }
        if (x2 == 0 && y2 == 0 && z2 == 0) {
            return (x1, y1, z1);
        }

but that would be more expensive (and that function is called 140 times when you verify)

/// @solidity memory-safe-assembly
assembly {
let p := P
let zz1 := mulmod(z1, z1, p) // zz1 = z1²
let zz2 := mulmod(z2, z2, p) // zz2 = z2²
let u1 := mulmod(x1, zz2, p) // u1 = x1*z2²
let u2 := mulmod(x2, zz1, p) // u2 = x2*z1²
let s1 := mulmod(y1, mulmod(zz2, z2, p), p) // s1 = y1*z2³
let s2 := mulmod(y2, mulmod(zz1, z1, p), p) // s2 = y2*z1³
let h := addmod(u2, sub(p, u1), p) // h = u2-u1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the _jDouble function, I see a couple of operations not followed from the reference here:

...
 if (U1 == U2)
   if (S1 != S2)
     return POINT_AT_INFINITY
   else 
     return POINT_DOUBLE(X1, Y1, Z1)
...

I want to make sure I understand why this is being ignored

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that this is a similar case to #4881 (comment), but I'm not 100% sure, so lets explore that

Copy link
Collaborator Author

@Amxx Amxx Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if u1=u2, then h = u2-u1 = 0, and so h*z1*z2 = 0 ... so it appears we return a point at infinity

also, if u1=u2 then h = 0 and then:

  • x' = r²-h³-2*u1*h² = r²-0-0 = r²
  • y' = r*(u1*h²-x')-s1*h³ = r*(0-x')-0 = r*(-r²) = -r³

with r = s2-s1, which is 0 if s2=s1 (case where the ref says we should return double) and not 0 if s2 != s1 (case where the ref says we should return point at infinity, which we do)

let hh := mulmod(h, h, p) // h²
let hhh := mulmod(h, hh, p) // h³
let r := addmod(s2, sub(p, s1), p) // r = s2-s1

// x' = r²-h³-2*u1*h²
x3 := addmod(addmod(mulmod(r, r, p), sub(p, hhh), p), sub(p, mulmod(2, mulmod(u1, hh, p), p)), p)
// y' = r*(u1*h²-x')-s1*h³
y3 := addmod(mulmod(r, addmod(mulmod(u1, hh, p), sub(p, x3), p), p), sub(p, mulmod(s1, hhh, p)), p)
// z' = h*z1*z2
z3 := mulmod(h, mulmod(z1, z2, p), p)
}
}

/**
* @dev Point doubling on the jacobian coordinates
* https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates
*/
function _jDouble(uint256 x, uint256 y, uint256 z) private pure returns (uint256 x2, uint256 y2, uint256 z2) {
/// @solidity memory-safe-assembly
assembly {
let p := P
let yy := mulmod(y, y, p)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we return the point at infinity (0,1,0) if y == 0 at this point?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part should be discussed over a call, but here are my first thought

A jacobian point (x,y,z) represents the carthesian point (x/z², y/z³). Therefore, any jacobian point that has z=0 is at infinity (because x/z² and y/z³ are x/0 and y/0, which is not defined).

For the case of _jDouble, if y=0 we have z' = 2*y*z = 0

So my understanding is that, if y=0 we know the result is a point at infinity, and we can potentially skip some computation is we already know one. However, if we don't skip, and we do the computation "normaly", we get a point at infinity, which is correct. We could have skipped the computation, but given that in our case addmod and mulmod are "cheap native operation" (which is not the case in an x86 machine), its probably ok.

We can try implementing the skip, and see if its actually saving gas, but AFAIK, this is an optimisation issue and not a correctness issue.

let zz := mulmod(z, z, p)
let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y²
let m := addmod(mulmod(3, mulmod(x, x, p), p), mulmod(A, mulmod(zz, zz, p), p), p) // m = 3*x²+a*z⁴

// x' = m²-2*s
x2 := addmod(mulmod(m, m, p), sub(p, mulmod(2, s, p)), p)
// y' = m*(s-x')-8*y⁴
y2 := addmod(mulmod(m, addmod(s, sub(p, x2), p), p), sub(p, mulmod(8, mulmod(yy, yy, p), p)), p)
// z' = 2*y*z
z2 := mulmod(2, mulmod(y, z, p), p)
}
}

/**
* @dev Point multiplication on the jacobian coordinates
*/
function _jMult(
uint256 x,
uint256 y,
uint256 z,
uint256 k
) private pure returns (uint256 x2, uint256 y2, uint256 z2) {
unchecked {
for (uint256 i = 0; i < 256; ++i) {
if (z > 0) {
(x2, y2, z2) = _jDouble(x2, y2, z2);
}
if (k >> 255 > 0) {
(x2, y2, z2) = _jAdd(x2, y2, z2, x, y, z);
}
k <<= 1;
}
}
}

/**
* @dev Compute P·u1 + Q·u2 using the precomputed points for P and Q (see {_preComputeJacobianPoints}).
*
* Uses Strauss Shamir trick for EC multiplication
* https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method
* we optimise on this a bit to do with 2 bits at a time rather than a single bit
* the individual points for a single pass are precomputed
* overall this reduces the number of additions while keeping the same number of doublings
*/
function _jMultShamir(JPoint[16] memory points, uint256 u1, uint256 u2) private view returns (uint256, uint256) {
uint256 x = 0;
uint256 y = 0;
uint256 z = 0;
unchecked {
for (uint256 i = 0; i < 128; ++i) {
if (z > 0) {
(x, y, z) = _jDouble(x, y, z);
(x, y, z) = _jDouble(x, y, z);
}
// Read 2 bits of u1, and 2 bits of u2. Combining the two give a lookup index in the table.
uint256 pos = ((u1 >> 252) & 0xc) | ((u2 >> 254) & 0x3);
if (pos > 0) {
(x, y, z) = _jAdd(x, y, z, points[pos].x, points[pos].y, points[pos].z);
}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So here the if is optional.

If we remove the if, we are going to load points[0] which is (0,0,0) ... and the _jAdd will skip that as the "neutral element". The if here as a cost. 15/16 we pay it for no real reason (and we still pay the check in _jAdd). 1/16 the if avoids the overhead of a function call.

I'm going to benchmark which one is better and comment that so we don't go back and forward.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked, skipping the mloads in 1/16 cases is a bigger gain than the loss of the if in the other 15/16 cases. Keeping the if is the more effective solution here

u1 <<= 2;
u2 <<= 2;
}
}
return _affineFromJacobian(x, y, z);
}

/**
* @dev Precompute a matrice of useful jacobian points associated to a given P. This can be seen as a 4x4 matrix
* that contains combinaison of P and G (generator) up to 3 times each. See table below:
Amxx marked this conversation as resolved.
Show resolved Hide resolved
*
* ┌────┬─────────────────────┐
* │ i │ 0 1 2 3 │
* ├────┼─────────────────────┤
* │ 0 │ 0 p 2p 3p │
* │ 4 │ g g+p g+2p g+3p │
* │ 8 │ 2g 2g+p 2g+2p 2g+3p │
* │ 12 │ 3g 3g+p 3g+2p 3g+3p │
* └────┴─────────────────────┘
*/
function _preComputeJacobianPoints(uint256 px, uint256 py) private pure returns (JPoint[16] memory points) {
points[0x00] = JPoint(0, 0, 0);
points[0x01] = JPoint(px, py, 1);
points[0x04] = JPoint(GX, GY, 1);
points[0x02] = _jDoublePoint(points[0x01]);
points[0x08] = _jDoublePoint(points[0x04]);
points[0x03] = _jAddPoint(points[0x01], points[0x02]);
points[0x05] = _jAddPoint(points[0x01], points[0x04]);
points[0x06] = _jAddPoint(points[0x02], points[0x04]);
points[0x07] = _jAddPoint(points[0x03], points[0x04]);
points[0x09] = _jAddPoint(points[0x01], points[0x08]);
points[0x0a] = _jAddPoint(points[0x02], points[0x08]);
points[0x0b] = _jAddPoint(points[0x03], points[0x08]);
points[0x0c] = _jAddPoint(points[0x04], points[0x08]);
points[0x0d] = _jAddPoint(points[0x01], points[0x0c]);
points[0x0e] = _jAddPoint(points[0x02], points[0x0c]);
points[0x0f] = _jAddPoint(points[0x03], points[0x0C]);
}

function _jAddPoint(JPoint memory p1, JPoint memory p2) private pure returns (JPoint memory) {
(uint256 x, uint256 y, uint256 z) = _jAdd(p1.x, p1.y, p1.z, p2.x, p2.y, p2.z);
return JPoint(x, y, z);
}

function _jDoublePoint(JPoint memory p) private pure returns (JPoint memory) {
(uint256 x, uint256 y, uint256 z) = _jDouble(p.x, p.y, p.z);
return JPoint(x, y, z);
}

/**
*@dev From Fermat's little theorem https://en.wikipedia.org/wiki/Fermat%27s_little_theorem:
* `a**(p-1) ≡ 1 mod p`. This means that `a**(p-2)` is an inverse of a in Fp.
*/
function _invModN(uint256 value) private view returns (uint256) {
return Math.modExp(value, N2, N);
}

function _invModP(uint256 value) private view returns (uint256) {
return Math.modExp(value, P2, P);
}
}