diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 5f7de8007b519..527d9f4f470ba 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -705,6 +705,8 @@ func (c *bls12381G1Add) Run(input []byte) ([]byte, error) { return nil, err } + // No need to check the subgroup here, as specified by EIP-2537 + // Compute r = p_0 + p_1 p0.Add(p0, p1) @@ -734,6 +736,11 @@ func (c *bls12381G1Mul) Run(input []byte) ([]byte, error) { if p0, err = decodePointG1(input[:128]); err != nil { return nil, err } + // 'point is on curve' check already done, + // Here we need to apply subgroup checks. + if !p0.IsInSubGroup() { + return nil, errBLS12381G1PointSubgroup + } // Decode scalar value e := new(big.Int).SetBytes(input[128:]) @@ -787,6 +794,11 @@ func (c *bls12381G1MultiExp) Run(input []byte) ([]byte, error) { if err != nil { return nil, err } + // 'point is on curve' check already done, + // Here we need to apply subgroup checks. + if !p.IsInSubGroup() { + return nil, errBLS12381G1PointSubgroup + } points[i] = *p // Decode scalar value scalars[i] = *new(fr.Element).SetBytes(input[t1:t2]) @@ -827,6 +839,8 @@ func (c *bls12381G2Add) Run(input []byte) ([]byte, error) { return nil, err } + // No need to check the subgroup here, as specified by EIP-2537 + // Compute r = p_0 + p_1 r := new(bls12381.G2Affine) r.Add(p0, p1) @@ -857,6 +871,11 @@ func (c *bls12381G2Mul) Run(input []byte) ([]byte, error) { if p0, err = decodePointG2(input[:256]); err != nil { return nil, err } + // 'point is on curve' check already done, + // Here we need to apply subgroup checks. + if !p0.IsInSubGroup() { + return nil, errBLS12381G2PointSubgroup + } // Decode scalar value e := new(big.Int).SetBytes(input[256:]) @@ -910,6 +929,11 @@ func (c *bls12381G2MultiExp) Run(input []byte) ([]byte, error) { if err != nil { return nil, err } + // 'point is on curve' check already done, + // Here we need to apply subgroup checks. + if !p.IsInSubGroup() { + return nil, errBLS12381G2PointSubgroup + } points[i] = *p // Decode scalar value scalars[i] = *new(fr.Element).SetBytes(input[t1:t2])