Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: fix inclusion proof verification flake #956

Merged
merged 2 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
70 changes: 52 additions & 18 deletions pkg/api/trillian_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ type Response struct {
getConsistencyProofResult *trillian.GetConsistencyProofResponse
}

func unmarshalLogRoot(logRoot []byte) (types.LogRootV1, error) {
var root types.LogRootV1
if err := root.UnmarshalBinary(logRoot); err != nil {
return types.LogRootV1{}, err
}
return root, nil
}

func (t *TrillianClient) root() (types.LogRootV1, error) {
rqst := &trillian.GetLatestSignedLogRootRequest{
LogId: t.logID,
Expand All @@ -77,11 +85,7 @@ func (t *TrillianClient) root() (types.LogRootV1, error) {
if err != nil {
return types.LogRootV1{}, err
}
var root types.LogRootV1
if err := root.UnmarshalBinary(resp.SignedLogRoot.LogRoot); err != nil {
return types.LogRootV1{}, err
}
return root, nil
return unmarshalLogRoot(resp.SignedLogRoot.LogRoot)
}

func (t *TrillianClient) addLeaf(byteValue []byte) *Response {
Expand Down Expand Up @@ -210,11 +214,19 @@ func (t *TrillianClient) getLeafAndProofByIndex(index int64) *Response {
ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
defer cancel()

root, err := t.root()
rootResp := t.getLatest(0)
if rootResp.err != nil {
return &Response{
status: status.Code(rootResp.err),
err: rootResp.err,
}
}

root, err := unmarshalLogRoot(rootResp.getLatestResult.SignedLogRoot.LogRoot)
if err != nil {
return &Response{
status: status.Code(err),
err: err,
status: status.Code(rootResp.err),
err: rootResp.err,
}
}

Expand All @@ -232,24 +244,39 @@ func (t *TrillianClient) getLeafAndProofByIndex(index int64) *Response {
err: err,
}
}
return &Response{
status: status.Code(err),
err: err,
getLeafAndProofResult: &trillian.GetEntryAndProofResponse{
Proof: resp.Proof,
Leaf: resp.Leaf,
SignedLogRoot: rootResp.getLatestResult.SignedLogRoot,
},
}
}

return &Response{
status: status.Code(err),
err: err,
getLeafAndProofResult: resp,
status: status.Code(err),
err: err,
}
}

func (t *TrillianClient) getProofByHash(hashValue []byte) *Response {
ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
defer cancel()

root, err := t.root()
rootResp := t.getLatest(0)
if rootResp.err != nil {
return &Response{
status: status.Code(rootResp.err),
err: rootResp.err,
}
}
root, err := unmarshalLogRoot(rootResp.getLatestResult.SignedLogRoot.LogRoot)
if err != nil {
return &Response{
status: status.Code(err),
err: err,
status: status.Code(rootResp.err),
err: rootResp.err,
}
}

Expand All @@ -263,20 +290,27 @@ func (t *TrillianClient) getProofByHash(hashValue []byte) *Response {
if resp != nil {
v := client.NewLogVerifier(rfc6962.DefaultHasher)
for _, proof := range resp.Proof {

if err := v.VerifyInclusionByHash(&root, hashValue, proof); err != nil {
return &Response{
status: status.Code(err),
err: err,
}
}
}
// Return an inclusion proof response with the requested
return &Response{
status: status.Code(err),
err: err,
getProofResult: &trillian.GetInclusionProofByHashResponse{
Proof: resp.Proof,
SignedLogRoot: rootResp.getLatestResult.SignedLogRoot,
},
}
}

return &Response{
status: status.Code(err),
err: err,
getProofResult: resp,
status: status.Code(err),
err: err,
}
}

Expand Down
64 changes: 64 additions & 0 deletions tests/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
"golang.org/x/sync/errgroup"
"io/ioutil"
"os"
"os/exec"
Expand Down Expand Up @@ -758,3 +759,66 @@ func TestTufVerifyUpload(t *testing.T) {
out = runCli(t, "search", "--public-key", rootPath, "--pki-format", "tuf")
outputContains(t, out, uuid)
}

// Regression test for https://github.com/sigstore/rekor/pull/956
// Requesting an inclusion proof concurrently with an entry write triggers
// a race where the inclusion proof returned does not verify because the
// tree head changes.
func TestInclusionProofRace(t *testing.T) {
// Create a random artifact and sign it.
artifactPath := filepath.Join(t.TempDir(), "artifact")
sigPath := filepath.Join(t.TempDir(), "signature.asc")

createdX509SignedArtifact(t, artifactPath, sigPath)
dataBytes, _ := ioutil.ReadFile(artifactPath)
h := sha256.Sum256(dataBytes)
dataSHA := hex.EncodeToString(h[:])

// Write the public key to a file
pubPath := filepath.Join(t.TempDir(), "pubKey.asc")
if err := ioutil.WriteFile(pubPath, []byte(rsaCert), 0644); err != nil {
t.Fatal(err)
}

// Upload an entry
runCli(t, "upload", "--type=hashedrekord", "--pki-format=x509", "--artifact-hash", dataSHA, "--signature", sigPath, "--public-key", pubPath)

// Constantly uploads new signatures on an entry.
var uploadRoutine = func(pubPath string) error {
// Create a random artifact and sign it.
artifactPath := filepath.Join(t.TempDir(), "artifact")
sigPath := filepath.Join(t.TempDir(), "signature.asc")

createdX509SignedArtifact(t, artifactPath, sigPath)
dataBytes, _ := ioutil.ReadFile(artifactPath)
h := sha256.Sum256(dataBytes)
dataSHA := hex.EncodeToString(h[:])

// Upload an entry
out := runCli(t, "upload", "--type=hashedrekord", "--pki-format=x509", "--artifact-hash", dataSHA, "--signature", sigPath, "--public-key", pubPath)
outputContains(t, out, "Created entry at")

return nil
}

// Attempts to verify the original entry.
var verifyRoutine = func(dataSHA, sigPath, pubPath string) error {
out := runCli(t, "verify", "--type=hashedrekord", "--pki-format=x509", "--artifact-hash", dataSHA, "--signature", sigPath, "--public-key", pubPath)

if strings.Contains(out, "calculated root") || strings.Contains(out, "wrong") {
return fmt.Errorf(out)
}

return nil
}

var g errgroup.Group
for i := 0; i < 50; i++ {
g.Go(func() error { return uploadRoutine(pubPath) })
g.Go(func() error { return verifyRoutine(dataSHA, sigPath, pubPath) })
}

if err := g.Wait(); err != nil {
t.Fatal(err)
}
}