Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: Fix crash on amd64 (no BMI) + Go fuzz test #645

Merged
merged 10 commits into from Jul 20, 2022
116 changes: 73 additions & 43 deletions zstd/_generate/gen.go
Expand Up @@ -183,6 +183,11 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
ec.moPtr = moP
ec.mlPtr = mlP
ec.llPtr = llP
zero := GP64()
XORQ(zero, zero)
MOVQ(zero, moP)
MOVQ(zero, mlP)
MOVQ(zero, llP)

ec.outBase = GP64()
ec.outEndPtr = AllocLocal(8)
Expand Down Expand Up @@ -338,11 +343,14 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
Comment("Adjust offset")

var offset reg.GPVirtual
end := LabelRef(name + "_after_adjust")
if o.useSeqs {
offset = o.adjustOffset(name+"_adjust", moP, llP, R14, &offsets)
offset = o.adjustOffset(name+"_adjust", moP, llP, R14, &offsets, end)
} else {
offset = o.adjustOffsetInMemory(name+"_adjust", moP, llP, R14)
offset = o.adjustOffsetInMemory(name+"_adjust", moP, llP, R14, end)
}
Label(name + "_after_adjust")

MOVQ(offset, moP) // Store offset

Comment("Check values")
Expand Down Expand Up @@ -586,26 +594,25 @@ func (o options) updateLength(name string, brValue, brBitsRead, state reg.GPVirt
MOVQ(state, AX.As64()) // So we can grab high bytes.
MOVQ(brBitsRead, CX.As64())
MOVQ(brValue, BX)
SHLQ(CX, BX) // BX = br.value << br.bitsRead (part of getBits)
MOVB(AX.As8H(), CX.As8L()) // CX = moB (ofState.addBits(), that is byte #1 of moState)
ADDQ(CX.As64(), brBitsRead) // br.bitsRead += n (part of getBits)
NEGL(CX.As32()) // CX = 64 - n
SHRQ(CX, BX) // BX = (br.value << br.bitsRead) >> (64 - n) -- getBits() result
SHRQ(U8(32), AX) // AX = mo (ofState.baselineInt(), that's the higher dword of moState)
SHLQ(CX, BX) // BX = br.value << br.bitsRead (part of getBits)
MOVB(AX.As8H(), CX.As8L()) // CX = moB (ofState.addBits(), that is byte #1 of moState)
SHRQ(U8(32), AX) // AX = mo (ofState.baselineInt(), that's the higher dword of moState)
// If addBits == 0, skip
TESTQ(CX.As64(), CX.As64())
CMOVQEQ(CX.As64(), BX) // BX is zero if n is zero
JZ(LabelRef(name + "_zero"))

// Check if AX is reasonable
assert(func(ok LabelRef) {
CMPQ(AX, U32(1<<28))
JB(ok)
})
// Check if BX is reasonable
assert(func(ok LabelRef) {
CMPQ(BX, U32(1<<28))
JB(ok)
})
ADDQ(BX, AX) // AX - mo + br.getBits(moB)
ADDQ(CX.As64(), brBitsRead) // br.bitsRead += n (part of getBits)
// If overread, skip
CMPQ(brBitsRead, U8(64))
JA(LabelRef(name + "_zero"))
CMPQ(CX.As64(), U8(64))
JAE(LabelRef(name + "_zero"))

NEGQ(CX.As64()) // CX = 64 - n
SHRQ(CX, BX) // BX = (br.value << br.bitsRead) >> (64 - n) -- getBits() result
ADDQ(BX, AX) // AX - mo + br.getBits(moB)

Label(name + "_zero")
MOVQ(AX, out) // Store result
}
}
Expand Down Expand Up @@ -717,7 +724,7 @@ func (o options) getBits(nBits, brValue, brBitsRead reg.GPVirtual) reg.GPVirtual
return BX
}

func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual, offsets *[3]reg.GPVirtual) (offset reg.GPVirtual) {
func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual, offsets *[3]reg.GPVirtual, end LabelRef) (offset reg.GPVirtual) {
offset = GP64()
MOVQ(moP, offset)
{
Expand All @@ -733,7 +740,7 @@ func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual,
MOVQ(offsets[1], offsets[2]) // s.prevOffset[2] = s.prevOffset[1]
MOVQ(offsets[0], offsets[1]) // s.prevOffset[1] = s.prevOffset[0]
MOVQ(offset, offsets[0]) // s.prevOffset[0] = offset
JMP(LabelRef(name + "_end"))
JMP(end)
}

Label(name + "_offsetB_1_or_0")
Expand Down Expand Up @@ -762,7 +769,7 @@ func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual,
TESTQ(offset, offset)
JNZ(LabelRef(name + "_offset_nonzero"))
MOVQ(offsets[0], offset)
JMP(LabelRef(name + "_end"))
JMP(end)
}
}
Label(name + "_offset_nonzero")
Expand Down Expand Up @@ -821,13 +828,13 @@ func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual,
MOVQ(temp, offsets[0])
MOVQ(temp, offset) // return temp
}
Label(name + "_end")
JMP(end)
return offset
}

// adjustOffsetInMemory is an adjustOffset version that does not cache prevOffset values in registers.
// It fetches and stores values directly into the fields of `sequenceDecs` structure.
func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPVirtual) (offset reg.GPVirtual) {
func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPVirtual, end LabelRef) (offset reg.GPVirtual) {
s := Dereference(Param("s"))

po0, _ := s.Field("prevOffset").Index(0).Resolve()
Expand All @@ -849,26 +856,19 @@ func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPV
MOVUPS(po0.Addr, tmp) // tmp = (s.prevOffset[0], s.prevOffset[1])
MOVQ(offset, po0.Addr) // s.prevOffset[0] = offset
MOVUPS(tmp, po1.Addr) // s.prevOffset[1], s.prevOffset[2] = s.prevOffset[0], s.prevOffset[1]
JMP(LabelRef(name + "_end"))
JMP(end)
}

Label(name + "_offsetB_1_or_0")
// if litLen == 0 {
// offset++
// }

{
if true {
CMPQ(llP, U32(0))
JNE(LabelRef(name + "_offset_maybezero"))
INCQ(offset)
JMP(LabelRef(name + "_offset_nonzero"))
} else {
// No idea why this doesn't work:
tmp := GP64()
LEAQ(Mem{Base: offset, Disp: 1}, tmp)
CMPQ(llP, U32(0))
CMOVQEQ(tmp, offset)
}
CMPQ(llP, U32(0))
JNE(LabelRef(name + "_offset_maybezero"))
INCQ(offset)
JMP(LabelRef(name + "_offset_nonzero"))

// if offset == 0 {
// return s.prevOffset[0]
Expand All @@ -878,11 +878,27 @@ func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPV
TESTQ(offset, offset)
JNZ(LabelRef(name + "_offset_nonzero"))
MOVQ(po0.Addr, offset)
JMP(LabelRef(name + "_end"))
JMP(end)
}
}
Label(name + "_offset_nonzero")
{
// Offset must be 1 -> 3
assert(func(ok LabelRef) {
// Test is above or equal (shouldn't be equal)
CMPQ(offset, U32(0))
JAE(ok)
})
assert(func(ok LabelRef) {
// Check if Above 0.
CMPQ(offset, U32(0))
JA(ok)
})
assert(func(ok LabelRef) {
// Check if Below or Equal to 3.
CMPQ(offset, U32(3))
JBE(ok)
})
// if offset == 3 {
// temp = s.prevOffset[0] - 1
// } else {
Expand All @@ -906,9 +922,23 @@ func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPV
CMPQ(offset, U8(3))
CMOVQEQ(DX, CX)
CMOVQEQ(R15, DX)
prevOffset := GP64()
LEAQ(po0.Addr, prevOffset) // &prevOffset[0]
ADDQ(Mem{Base: prevOffset, Index: CX, Scale: 8}, DX)
assert(func(ok LabelRef) {
CMPQ(CX, U32(0))
JAE(ok)
})
assert(func(ok LabelRef) {
CMPQ(CX, U32(3))
JB(ok)
})
if po0.Addr.Index != nil {
// Use temporary (not currently needed)
prevOffset := GP64()
LEAQ(po0.Addr, prevOffset) // &prevOffset[0]
ADDQ(Mem{Base: prevOffset, Index: CX, Scale: 8}, DX)
} else {
ADDQ(Mem{Base: po0.Addr.Base, Disp: po0.Addr.Disp, Index: CX, Scale: 8}, DX)
}

temp := DX
// if temp == 0 {
// temp = 1
Expand All @@ -935,7 +965,7 @@ func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPV
MOVQ(temp, po0.Addr) // s.prevOffset[0] = temp
MOVQ(temp, offset) // return temp
}
Label(name + "_end")
JMP(end)
return offset
}

Expand Down
15 changes: 9 additions & 6 deletions zstd/bytebuf.go
Expand Up @@ -23,7 +23,7 @@ type byteBuffer interface {
readByte() (byte, error)

// Skip n bytes.
skipN(n int) error
skipN(n int64) error
}

// in-memory buffer
Expand Down Expand Up @@ -62,9 +62,12 @@ func (b *byteBuf) readByte() (byte, error) {
return r, nil
}

func (b *byteBuf) skipN(n int) error {
func (b *byteBuf) skipN(n int64) error {
bb := *b
if len(bb) < n {
if n < 0 {
return fmt.Errorf("negative skip (%d) requested", n)
}
if int64(len(bb)) < n {
return io.ErrUnexpectedEOF
}
*b = bb[n:]
Expand Down Expand Up @@ -120,9 +123,9 @@ func (r *readerWrapper) readByte() (byte, error) {
return r.tmp[0], nil
}

func (r *readerWrapper) skipN(n int) error {
n2, err := io.CopyN(ioutil.Discard, r.r, int64(n))
if n2 != int64(n) {
func (r *readerWrapper) skipN(n int64) error {
n2, err := io.CopyN(ioutil.Discard, r.r, n)
if n2 != n {
err = io.ErrUnexpectedEOF
}
return err
Expand Down
41 changes: 23 additions & 18 deletions zstd/dict_test.go
Expand Up @@ -13,24 +13,7 @@ import (
func TestDecoder_SmallDict(t *testing.T) {
// All files have CRC
zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
var dicts [][]byte
for _, tt := range zr.File {
if !strings.HasSuffix(tt.Name, ".dict") {
continue
}
func() {
r, err := tt.Open()
if err != nil {
t.Fatal(err)
}
defer r.Close()
in, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
dicts = append(dicts, in)
}()
}
dicts := readDicts(t, zr)
dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -453,3 +436,25 @@ func TestDecoder_MoreDicts2(t *testing.T) {
})
}
}

// readDicts returns the contents of every entry in zr whose name ends in
// ".dict", in the order the entries appear in zr.File.
//
// It is a test helper: any failure to open or read an entry aborts the
// calling test or benchmark through tb, and the failure is attributed to the
// caller rather than to this function.
func readDicts(tb testing.TB, zr *zip.Reader) [][]byte {
	tb.Helper()

	var dicts [][]byte
	for _, tt := range zr.File {
		if !strings.HasSuffix(tt.Name, ".dict") {
			continue
		}
		// Read inside a closure so the deferred Close runs once per entry
		// instead of accumulating until readDicts returns.
		func() {
			tb.Helper()

			r, err := tt.Open()
			if err != nil {
				tb.Fatalf("opening %q: %v", tt.Name, err)
			}
			defer r.Close()

			in, err := ioutil.ReadAll(r)
			if err != nil {
				tb.Fatalf("reading %q: %v", tt.Name, err)
			}
			dicts = append(dicts, in)
		}()
	}
	return dicts
}
2 changes: 1 addition & 1 deletion zstd/framedec.go
Expand Up @@ -106,7 +106,7 @@ func (d *frameDec) reset(br byteBuffer) error {
}
n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
println("Skipping frame with", n, "bytes.")
err = br.skipN(int(n))
err = br.skipN(int64(n))
if err != nil {
if debugDecoder {
println("Reading discarded frame", err)
Expand Down
4 changes: 2 additions & 2 deletions zstd/fse_decoder_amd64.go
Expand Up @@ -34,8 +34,8 @@ const (
// buildDtable will build the decoding table.
func (s *fseDecoder) buildDtable() error {
ctx := buildDtableAsmContext{
stateTable: (*uint16)(&s.stateTable[0]),
norm: (*int16)(&s.norm[0]),
stateTable: &s.stateTable[0],
norm: &s.norm[0],
dt: (*uint64)(&s.dt[0]),
}
code := buildDtable_asm(s, &ctx)
Expand Down