From 2081fbbf3372a00437b0eadc8ad53f2f17e035a8 Mon Sep 17 00:00:00 2001 From: priyawadhwa Date: Tue, 30 Aug 2022 15:26:10 -0400 Subject: [PATCH] Validate tree ID on calls to /api/v1/log/entries/retrieve (#1017) * Validate tree ID on calls to /api/v1/log/entries/retrieve Signed-off-by: Priya Wadhwa * Don't unnecessarily extract uuid Signed-off-by: Priya Wadhwa Signed-off-by: Priya Wadhwa --- pkg/api/entries.go | 16 +++++++---- tests/e2e_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 73 insertions(+), 12 deletions(-) diff --git a/pkg/api/entries.go b/pkg/api/entries.go index abc3cefad..c16220389 100644 --- a/pkg/api/entries.go +++ b/pkg/api/entries.go @@ -330,15 +330,19 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo var searchHashes [][]byte for _, entryID := range params.Entry.EntryUUIDs { - uuid, err := sharding.GetUUIDFromIDString(entryID) - if err != nil { - return handleRekorAPIError(params, http.StatusBadRequest, err, fmt.Sprintf("could not get UUID from ID string %v", entryID)) - } - if logEntry, err := retrieveLogEntry(httpReqCtx, entryID); err == nil { + if sharding.ValidateEntryID(entryID) == nil { + logEntry, err := retrieveLogEntry(httpReqCtx, entryID) + if err != nil { + return handleRekorAPIError(params, http.StatusBadRequest, err, fmt.Sprintf("error getting log entry for %s", entryID)) + } resultPayload = append(resultPayload, logEntry) continue } - // If we couldn't get the entry, search for the hash later + // At this point, check if we got a uuid instead of an EntryID, so search for the hash later + uuid := entryID + if err := sharding.ValidateUUID(uuid); err != nil { + return handleRekorAPIError(params, http.StatusBadRequest, err, fmt.Sprintf("validating uuid %s", uuid)) + } hash, err := hex.DecodeString(uuid) if err != nil { return handleRekorAPIError(params, http.StatusBadRequest, err, malformedUUID) diff --git a/tests/e2e_test.go b/tests/e2e_test.go index 545a90c44..34d018b42 100644 --- a/tests/e2e_test.go +++ b/tests/e2e_test.go @@ -279,12 +279,7 @@ func TestGetCLI(t *testing.T) { outputContains(t, out, uuid) // Exercise GET with the new EntryID (TreeID + UUID) - out = runCli(t, "loginfo") - tidStr := strings.TrimSpace(strings.Split(out, "TreeID: ")[1]) - tid, err := strconv.ParseInt(tidStr, 10, 64) - if err != nil { - t.Errorf(err.Error()) - } + tid := getTreeID(t) entryID, err := sharding.CreateEntryIDFromParts(fmt.Sprintf("%x", tid), uuid) if err != nil { t.Error(err) @@ -1219,3 +1214,65 @@ func getBody(t *testing.T, limit int) []byte { s += "]}" return []byte(s) } + +func getTreeID(t *testing.T) int64 { + out := runCli(t, "loginfo") + tidStr := strings.TrimSpace(strings.Split(out, "TreeID: ")[1]) + tid, err := strconv.ParseInt(tidStr, 10, 64) + if err != nil { + t.Errorf(err.Error()) + } + t.Log("Tree ID:", tid) + return tid +} + +// This test confirms that we validate tree ID when using the /api/v1/log/entries/retrieve endpoint +// https://github.com/sigstore/rekor/issues/1014 +func TestSearchValidateTreeID(t *testing.T) { + // Create something and add it to the log + artifactPath := filepath.Join(t.TempDir(), "artifact") + sigPath := filepath.Join(t.TempDir(), "signature.asc") + + createdPGPSignedArtifact(t, artifactPath, sigPath) + + // Write the public key to a file + pubPath := filepath.Join(t.TempDir(), "pubKey.asc") + if err := ioutil.WriteFile(pubPath, []byte(publicKey), 0644); err != nil { + t.Fatal(err) + } + out := runCli(t, "upload", "--artifact", artifactPath, "--signature", sigPath, "--public-key", pubPath) + outputContains(t, out, "Created entry at") + + uuid, err := sharding.GetUUIDFromIDString(getUUIDFromUploadOutput(t, out)) + if err != nil { + t.Error(err) + } + // Make sure we can get by Entry ID + tid := getTreeID(t) + entryID, err := sharding.CreateEntryIDFromParts(fmt.Sprintf("%x", tid), uuid) + if err != nil { + t.Fatal(err) + } + body := "{\"entryUUIDs\":[\"%s\"]}" + resp, err := http.Post("http://localhost:3000/api/v1/log/entries/retrieve", "application/json", bytes.NewBuffer([]byte(fmt.Sprintf(body, entryID.ReturnEntryIDString())))) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Fatalf("expected 200 status code but got %d", resp.StatusCode) + } + + // Make sure we fail with a random tree ID + fakeTID := tid + 1 + entryID, err = sharding.CreateEntryIDFromParts(fmt.Sprintf("%x", fakeTID), uuid) + if err != nil { + t.Fatal(err) + } + resp, err = http.Post("http://localhost:3000/api/v1/log/entries/retrieve", "application/json", bytes.NewBuffer([]byte(fmt.Sprintf(body, entryID.ReturnEntryIDString())))) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 400 { + t.Fatalf("expected 400 status code but got %d", resp.StatusCode) + } +}