Skip to content

Commit

Permalink
zstd: Fix amd64 not always detecting corrupt data (#785)
Browse files Browse the repository at this point in the history
* zstd: Fix amd64 not always detecting corrupt data

Fix undetected corrupt data in amd64 assembly.

In rare cases overreads would not get returned as errors, if a multiple of 256 bits was overread.
This would make the "bitsread" equal the expected 64.

Whenever all bytes have been read from memory, we start checking on every fill whether more than 64 bits have been read. This ensures that an overflow can never occur.

No invalid memory was accessed; this is merely a question of whether errors are reported.

Fixes https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=57290
  • Loading branch information
klauspost committed Mar 22, 2023
1 parent 7633d62 commit 69a8ecc
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 23 deletions.
4 changes: 2 additions & 2 deletions internal/fuzz/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ func AddFromZip(f *testing.F, filename string, t InputType, short bool) {
t = TypeRaw // Fallback
if len(b) >= 4 {
sz := binary.BigEndian.Uint32(b)
if sz == uint32(len(b))-4 {
f.Add(b[4:])
if sz <= uint32(len(b))-4 {
f.Add(b[4 : 4+sz])
continue
}
}
Expand Down
22 changes: 18 additions & 4 deletions zstd/_generate/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ const errorNotEnoughLiterals = 4
// error reported when capacity of `out` is too small
const errorNotEnoughSpace = 5

// error reported when bits are overread.
const errorOverread = 6

const maxMatchLen = 131074

// size of struct seqVals
Expand Down Expand Up @@ -247,8 +250,9 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
{
brPointer := GP64()
MOVQ(brPointerStash, brPointer)

Comment("Fill bitreader to have enough for the offset and match length.")
o.bitreaderFill(name+"_fill", brValue, brBitsRead, brOffset, brPointer)
o.bitreaderFill(name+"_fill", brValue, brBitsRead, brOffset, brPointer, LabelRef("error_overread"))

Comment("Update offset")
// Up to 32 extra bits
Expand All @@ -261,7 +265,7 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
// If we need more than 56 in total, we must refill here.
if !o.fiftysix {
Comment("Fill bitreader to have enough for the remaining")
o.bitreaderFill(name+"_fill_2", brValue, brBitsRead, brOffset, brPointer)
o.bitreaderFill(name+"_fill_2", brValue, brBitsRead, brOffset, brPointer, LabelRef("error_overread"))
}

Comment("Update literal length")
Expand Down Expand Up @@ -502,6 +506,12 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
o.returnWithCode(errorNotEnoughLiterals)
}

Comment("Return with overread error")
{
Label("error_overread")
o.returnWithCode(errorOverread)
}

if !o.useSeqs {
Comment("Return with not enough output space error")
Label("error_not_enough_space")
Expand Down Expand Up @@ -529,7 +539,7 @@ func (o options) returnWithCode(returnCode uint32) {
}

// bitreaderFill will make sure at least 56 bits are available.
func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPointer reg.GPVirtual) {
func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPointer reg.GPVirtual, overread LabelRef) {
// bitreader_fill begin
CMPQ(brOffset, U8(8)) // b.off >= 8
JL(LabelRef(name + "_byte_by_byte"))
Expand All @@ -545,7 +555,7 @@ func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPoi

Label(name + "_byte_by_byte")
CMPQ(brOffset, U8(0)) /* for b.off > 0 */
JLE(LabelRef(name + "_end"))
JLE(LabelRef(name + "_check_overread"))

CMPQ(brBitsRead, U8(7)) /* for brBitsRead > 7 */
JLE(LabelRef(name + "_end"))
Expand All @@ -565,6 +575,10 @@ func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPoi
}
JMP(LabelRef(name + "_byte_by_byte"))

Label(name + "_check_overread")
CMPQ(brBitsRead, U8(64))
JA(overread)

Label(name + "_end")
}

Expand Down
12 changes: 10 additions & 2 deletions zstd/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func FuzzDecAllNoBMI2(f *testing.F) {
func FuzzDecoder(f *testing.F) {
fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-raw.zip", fuzz.TypeRaw, testing.Short())
fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-encoded.zip", fuzz.TypeGoFuzz, testing.Short())
//fuzz.AddFromZip(f, "testdata/fuzz/decode-oss.zip", fuzz.TypeOSSFuzz, false)

brLow := newBytesReader(nil)
brHi := newBytesReader(nil)
Expand All @@ -92,18 +93,25 @@ func FuzzDecoder(f *testing.F) {
}
defer decHi.Close()

if debugDecoder {
fmt.Println("LOW CONCURRENT")
}
b1, err1 := io.ReadAll(decLow)

if debugDecoder {
fmt.Println("HI NOT CONCURRENT")
}
b2, err2 := io.ReadAll(decHi)
if err1 != err2 {
if (err1 == nil) != (err2 == nil) {
t.Errorf("err low: %v, hi: %v", err1, err2)
t.Errorf("err low concurrent: %v, hi: %v", err1, err2)
}
}
if err1 != nil {
b1, b2 = b1[:0], b2[:0]
}
if !bytes.Equal(b1, b2) {
t.Fatalf("Output mismatch, low: %v, hi: %v", err1, err2)
t.Fatalf("Output mismatch, low concurrent: %v, hi: %v", err1, err2)
}
})
}
Expand Down
5 changes: 4 additions & 1 deletion zstd/seqdec.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,12 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
maxBlockSize = s.windowSize
}

if debugDecoder {
println("decodeSync: decoding", seqs, "sequences", br.remain(), "bits remain on stream")
}
for i := seqs - 1; i >= 0; i-- {
if br.overread() {
printf("reading sequence %d, exceeded available data\n", seqs-i)
printf("reading sequence %d, exceeded available data. Overread by %d\n", seqs-i, -br.remain())
return io.ErrUnexpectedEOF
}
var ll, mo, ml int
Expand Down
16 changes: 16 additions & 0 deletions zstd/seqdec_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package zstd

import (
"fmt"
"io"

"github.com/klauspost/compress/internal/cpuinfo"
)
Expand Down Expand Up @@ -134,6 +135,9 @@ func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {
return true, fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available",
ctx.ll, ctx.litRemain+ctx.ll)

case errorOverread:
return true, io.ErrUnexpectedEOF

case errorNotEnoughSpace:
size := ctx.outPosition + ctx.ll + ctx.ml
if debugDecoder {
Expand Down Expand Up @@ -202,6 +206,9 @@ const errorNotEnoughLiterals = 4
// error reported when capacity of `out` is too small
const errorNotEnoughSpace = 5

// error reported when bits are overread.
const errorOverread = 6

// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
//
// Please refer to seqdec_generic.go for the reference implementation.
Expand Down Expand Up @@ -247,6 +254,10 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
litRemain: len(s.literals),
}

if debugDecoder {
println("decode: decoding", len(seqs), "sequences", br.remain(), "bits remain on stream")
}

s.seqSize = 0
lte56bits := s.maxBits+s.offsets.fse.actualTableLog+s.matchLengths.fse.actualTableLog+s.litLengths.fse.actualTableLog <= 56
var errCode int
Expand Down Expand Up @@ -277,6 +288,8 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
case errorNotEnoughLiterals:
ll := ctx.seqs[i].ll
return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, ctx.litRemain+ll)
case errorOverread:
return io.ErrUnexpectedEOF
}

return fmt.Errorf("sequenceDecs_decode_amd64 returned erronous code %d", errCode)
Expand All @@ -291,6 +304,9 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
if s.seqSize > maxBlockSize {
return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
}
if debugDecoder {
println("decode: ", br.remain(), "bits remain on stream. code:", errCode)
}
err := br.close()
if err != nil {
printf("Closing sequences: %v, %+v\n", err, *br)
Expand Down

0 comments on commit 69a8ecc

Please sign in to comment.