diff --git a/.gitignore b/.gitignore
index b35f8449bf..d31b378152 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,10 @@ _testmain.go
 *.test
 *.prof
 /s2/cmd/_s2sx/sfx-exe
+
+# Linux perf files
+perf.data
+perf.data.old
+
+# gdb history
+.gdb_history
diff --git a/zstd/_generate/gen.go b/zstd/_generate/gen.go
new file mode 100644
index 0000000000..084513ee6c
--- /dev/null
+++ b/zstd/_generate/gen.go
@@ -0,0 +1,551 @@
+package main
+
+//go:generate go run gen.go -out seqdec_amd64.s -stubs delme.go -pkg=zstd
+
+import (
+	"flag"
+	"fmt"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"runtime"
+
+	_ "github.com/klauspost/compress"
+
+	. "github.com/mmcloughlin/avo/build"
+	"github.com/mmcloughlin/avo/buildtags"
+	. "github.com/mmcloughlin/avo/operand"
+	"github.com/mmcloughlin/avo/reg"
+)
+
+// debug inserts extra checks here and there in the generated code.
+const debug = false
+
+// error reported when mo == 0 && ml > 0
+const errorMatchLenOfsMismatch = 1
+
+// error reported when ml > maxMatchLen
+const errorMatchLenTooBig = 2
+
+const maxMatchLen = 131074
+
+func main() {
+	flag.Parse()
+	out := flag.Lookup("out")
+	os.Remove(filepath.Join("..", out.Value.String()))
+	stub := flag.Lookup("stubs")
+	if stub.Value.String() != "" {
+		os.Remove(stub.Value.String())
+		defer os.Remove(stub.Value.String())
+	}
+
+	Constraint(buildtags.Not("appengine").ToConstraint())
+	Constraint(buildtags.Not("noasm").ToConstraint())
+	Constraint(buildtags.Term("gc").ToConstraint())
+
+	o := options{
+		bmi2: false,
+	}
+	o.genDecodeSeqAsm("sequenceDecs_decode_amd64")
+	o.bmi2 = true
+	o.genDecodeSeqAsm("sequenceDecs_decode_bmi2")
+	Generate()
+	b, err := ioutil.ReadFile(out.Value.String())
+	if err != nil {
+		panic(err)
+	}
+	const readOnly = 0444
+	err = ioutil.WriteFile(filepath.Join("..", out.Value.String()), b, readOnly)
+	if err != nil {
+		panic(err)
+	}
+	os.Remove(out.Value.String())
+}
+
+func debugval(v Op) {
+	value := reg.R15
+	MOVQ(v, value)
+	INT(Imm(3))
+}
+
+func debugval32(v Op) {
+	value := reg.R15L
+	MOVL(v, value)
+	INT(Imm(3))
+}
+
+var assertCounter int
+
+// assert will insert code if debug is enabled.
+// The code should jump to 'ok' if the assertion succeeds.
+func assert(fn func(ok LabelRef)) {
+	if debug {
+		caller := [100]uintptr{0}
+		runtime.Callers(2, caller[:])
+		frame, _ := runtime.CallersFrames(caller[:]).Next()
+
+		ok := fmt.Sprintf("assert_check_%d_ok_srcline_%d", assertCounter, frame.Line)
+		fn(LabelRef(ok))
+		// Emit several since delve is imprecise.
+		INT(Imm(3))
+		INT(Imm(3))
+		Label(ok)
+		assertCounter++
+	}
+}
+
+type options struct {
+	bmi2 bool
+}
+
+func (o options) genDecodeSeqAsm(name string) {
+	Package("github.com/klauspost/compress/zstd")
+	TEXT(name, 0, "func(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int")
+	Doc(name+" decodes a sequence", "")
+	Pragma("noescape")
+
+	brValue := GP64()
+	brBitsRead := GP64()
+	brOffset := GP64()
+	llState := GP64()
+	mlState := GP64()
+	ofState := GP64()
+	seqBase := GP64()
+
+	// 1. load bitReader (done once)
+	brPointerStash := AllocLocal(8)
+	{
+		br := Dereference(Param("br"))
+		brPointer := GP64()
+		Load(br.Field("value"), brValue)
+		Load(br.Field("bitsRead"), brBitsRead)
+		Load(br.Field("off"), brOffset)
+		Load(br.Field("in").Base(), brPointer)
+		ADDQ(brOffset, brPointer) // Add current offset to read pointer.
+		MOVQ(brPointer, brPointerStash)
+	}
+	{
+		ctx := Dereference(Param("ctx"))
+		Load(ctx.Field("llState"), llState)
+		Load(ctx.Field("mlState"), mlState)
+		Load(ctx.Field("ofState"), ofState)
+		Load(ctx.Field("seqs").Base(), seqBase)
+	}
+
+	moP := Mem{Base: seqBase, Disp: 2 * 8} // Pointer to current mo
+	mlP := Mem{Base: seqBase, Disp: 1 * 8} // Pointer to current ml
+	llP := Mem{Base: seqBase, Disp: 0 * 8} // Pointer to current ll
+
+	// MAIN LOOP:
+	Label(name + "_main_loop")
+
+	{
+		brPointer := GP64()
+		MOVQ(brPointerStash, brPointer)
+		Comment("Fill bitreader to have enough for the offset.")
+		o.bitreaderFill(name+"_fill", brValue, brBitsRead, brOffset, brPointer)
+
+		Comment("Update offset")
+		o.updateLength(name+"_of_update", brValue, brBitsRead, ofState, moP)
+
+		// Refill if needed.
+		Comment("Fill bitreader for match and literal")
+		o.bitreaderFill(name+"_fill_2", brValue, brBitsRead, brOffset, brPointer)
+
+		Comment("Update match length")
+		o.updateLength(name+"_ml_update", brValue, brBitsRead, mlState, mlP)
+
+		Comment("Update literal length")
+		o.updateLength(name+"_ll_update", brValue, brBitsRead, llState, llP)
+
+		Comment("Fill bitreader for state updates")
+		o.bitreaderFill(name+"_fill_3", brValue, brBitsRead, brOffset, brPointer)
+		MOVQ(brPointer, brPointerStash)
+	}
+
+	R14 := GP64()
+	MOVQ(ofState, R14) // copy ofState, its current value is needed below
+	// Reload ctx
+	ctx := Dereference(Param("ctx"))
+	iteration, err := ctx.Field("iteration").Resolve()
+	if err != nil {
+		panic(err)
+	}
+	// if ctx.iteration != 0, do update
+	CMPQ(iteration.Addr, U8(0))
+	JZ(LabelRef(name + "_skip_update"))
+
+	// Update states
+	{
+		Comment("Update Literal Length State")
+		o.updateState(name+"_llState", llState, brValue, brBitsRead, "llTable")
+		Comment("Update Match Length State")
+		o.updateState(name+"_mlState", mlState, brValue, brBitsRead, "mlTable")
+		Comment("Update Offset State")
+		o.updateState(name+"_ofState", ofState, brValue, brBitsRead, "ofTable")
+	}
+	Label(name + "_skip_update")
+
+	// mo = s.adjustOffset(mo, ll, moB)
+	SHRQ(U8(8), R14) // moB (from the ofState before its update)
+	MOVBQZX(R14.As8(), R14)
+
+	Comment("Adjust offset")
+
+	offset := o.adjustOffset(name+"_adjust", moP, llP, R14)
+	MOVQ(offset, moP) // Store offset
+
+	Comment("Check values")
+	ml := GP64()
+	MOVQ(mlP, ml)
+	ll := GP64()
+	MOVQ(llP, ll)
+
+	// Update length
+	{
+		length := GP64()
+		LEAQ(Mem{Base: ml, Index: ll, Scale: 1}, length)
+		s := Dereference(Param("s"))
+		seqSizeP, err := s.Field("seqSize").Resolve()
+		if err != nil {
+			panic(err)
+		}
+		ADDQ(length, seqSizeP.Addr) // s.seqSize += ml + ll
+	}
+
+	// Reload ctx
+	ctx = Dereference(Param("ctx"))
+	litRemainP, err := ctx.Field("litRemain").Resolve()
+	if err != nil {
+		panic(err)
+	}
+	SUBQ(ll, litRemainP.Addr) // ctx.litRemain -= ll
+	{
+		// if ml > maxMatchLen {
+		//    return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
+		// }
+		CMPQ(ml, U32(maxMatchLen))
+		JA(LabelRef(name + "_error_match_len_too_big"))
+	}
+	{
+		// if mo == 0 && ml > 0 {
+		//    return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
+		// }
+		TESTQ(offset, offset)
+		JNZ(LabelRef(name + "_match_len_ofs_ok")) // mo != 0
+		TESTQ(ml, ml)
+		JNZ(LabelRef(name + "_error_match_len_ofs_mismatch"))
+	}
+
+	Label(name + "_match_len_ofs_ok")
+	ADDQ(U8(24), seqBase) // sizeof(seqVals) == 3*8
+	ctx = Dereference(Param("ctx"))
+	iterationP, err := ctx.Field("iteration").Resolve()
+	if err != nil {
+		panic(err)
+	}
+
+	DECQ(iterationP.Addr)
+	JNS(LabelRef(name + "_main_loop"))
"_main_loop")) + + // update bitreader state before returning + br := Dereference(Param("br")) + Store(brValue, br.Field("value")) + Store(brBitsRead.As8(), br.Field("bitsRead")) + Store(brOffset, br.Field("off")) + + Comment("Return success") + o.returnWithCode(0) + + Comment("Return with match length error") + Label(name + "_error_match_len_ofs_mismatch") + o.returnWithCode(errorMatchLenOfsMismatch) + + Comment("Return with match too long error") + Label(name + "_error_match_len_too_big") + o.returnWithCode(errorMatchLenTooBig) +} + +func (o options) returnWithCode(returnCode uint32) { + a, err := ReturnIndex(0).Resolve() + if err != nil { + panic(err) + } + MOVQ(U32(returnCode), a.Addr) + RET() +} + +func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPointer reg.GPVirtual) { + // bitreader_fill begin + CMPQ(brBitsRead, U8(32)) // b.bitsRead < 32 + JL(LabelRef(name + "_end")) + + CMPQ(brOffset, U8(4)) // b.off >= 4 + JL(LabelRef(name + "_byte_by_byte")) + + // Label(name + "_fast") + SHLQ(U8(32), brValue) // b.value << 32 | uint32(mem) + SUBQ(U8(4), brPointer) + SUBQ(U8(4), brOffset) + SUBQ(U8(32), brBitsRead) + tmp := GP64() + MOVLQZX(Mem{Base: brPointer}, tmp) + ORQ(tmp, brValue) + JMP(LabelRef(name + "_end")) + + Label(name + "_byte_by_byte") + CMPQ(brOffset, U8(0)) /* for b.off > 0 */ + JLE(LabelRef(name + "_end")) + + SHLQ(U8(8), brValue) /* b.value << 8 | uint8(mem) */ + SUBQ(U8(1), brPointer) + SUBQ(U8(1), brOffset) + SUBQ(U8(8), brBitsRead) + + tmp = GP64() + MOVBQZX(Mem{Base: brPointer}, tmp) + ORQ(tmp, brValue) + + JMP(LabelRef(name + "_byte_by_byte")) + + Label(name + "_end") +} + +func (o options) updateLength(name string, brValue, brBitsRead, state reg.GPVirtual, out Mem) { + if o.bmi2 { + DX := GP64() + extr := GP64() + MOVQ(U32(8|(8<<8)), extr) + BEXTRQ(extr, state, DX) // addBits = (state >> 8) &xff + BX := GP64() + MOVQ(brValue, BX) + // TODO: We should be able to extra bits with BEXTRQ + CX := reg.CL + LEAQ(Mem{Base: brBitsRead, Index: DX, Scale: 1}, CX.As64()) // CX: shift = r.bitsRead + n + ROLQ(CX, BX) + BZHIQ(DX.As64(), BX, BX) + MOVQ(CX.As64(), brBitsRead) // br.bitsRead += moB + res := GP64() // AX + MOVQ(state, res) + SHRQ(U8(32), res) // AX = mo (ofState.baselineInt(), that's the higher dword of moState) + ADDQ(BX, res) // AX - mo + br.getBits(moB) + MOVQ(res, out) + } else { + BX := GP64() + CX := reg.CL + AX := reg.RAX + MOVQ(state, AX.As64()) // So we can grab high bytes. 
+		MOVQ(brBitsRead, CX.As64())
+		MOVQ(brValue, BX)
+		SHLQ(CX, BX)                // BX = br.value << br.bitsRead (part of getBits)
+		MOVB(AX.As8H(), CX.As8L())  // CX = moB (ofState.addBits(), that is byte #1 of moState)
+		ADDQ(CX.As64(), brBitsRead) // br.bitsRead += n (part of getBits)
+		NEGL(CX.As32())             // CX = 64 - n
+		SHRQ(CX, BX)                // BX = (br.value << br.bitsRead) >> (64 - n) -- getBits() result
+		SHRQ(U8(32), AX)            // AX = mo (ofState.baselineInt(), that's the higher dword of moState)
+		TESTQ(CX.As64(), CX.As64())
+		CMOVQEQ(CX.As64(), BX) // BX is zero if n is zero
+
+		// Check if AX is reasonable
+		assert(func(ok LabelRef) {
+			CMPQ(AX, U32(1<<28))
+			JB(ok)
+		})
+		// Check if BX is reasonable
+		assert(func(ok LabelRef) {
+			CMPQ(BX, U32(1<<28))
+			JB(ok)
+		})
+		ADDQ(BX, AX)  // AX = mo + br.getBits(moB)
+		MOVQ(AX, out) // Store result
+	}
+}
+
+func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtual, table string) {
+	name = name + "_updateState"
+	AX := GP64()
+	MOVBQZX(state.As8(), AX) // AX = nBits
+	// Check we have a reasonable nBits
+	assert(func(ok LabelRef) {
+		CMPQ(AX, U8(9))
+		JBE(ok)
+	})
+
+	DX := GP64()
+	MOVQ(state, DX) // TODO: maybe use BEXTR?
+	SHRQ(U8(16), DX)
+	MOVWQZX(DX.As16(), DX)
+
+	if !o.bmi2 {
+		// TODO: Probably reasonable to skip if AX == 0.
+		CMPQ(AX, U8(0))
+		JZ(LabelRef(name + "_skip"))
+	}
+
+	{
+		lowBits := o.getBits(name+"_getBits", AX, brValue, brBitsRead)
+		// Check if below tablelog
+		assert(func(ok LabelRef) {
+			CMPQ(lowBits, U32(512))
+			JB(ok)
+		})
+		ADDQ(lowBits, DX)
+	}
+
+	Label(name + "_skip")
+	// Load table pointer
+	tablePtr := GP64()
+	Comment("Load ctx." + table)
+	ctx := Dereference(Param("ctx"))
+	tableA, err := ctx.Field(table).Base().Resolve()
+	if err != nil {
+		panic(err)
+	}
+	MOVQ(tableA.Addr, tablePtr)
+
+	// Check if below tablelog
+	assert(func(ok LabelRef) {
+		CMPQ(DX, U32(512))
+		JB(ok)
+	})
+	// Load new state
+	MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state)
+}
+
+func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual) reg.GPVirtual {
+	BX := GP64()
+	CX := reg.CL
+	if o.bmi2 {
+		LEAQ(Mem{Base: brBitsRead, Index: nBits, Scale: 1}, CX.As64())
+		MOVQ(brValue, BX)
+		MOVQ(CX.As64(), brBitsRead)
+		ROLQ(CX, BX)
+		BZHIQ(nBits, BX, BX)
+	} else {
+		MOVQ(brBitsRead, CX.As64())
+		ADDQ(nBits, brBitsRead)
+		MOVQ(brValue, BX)
+		SHLQ(CX, BX)
+		MOVQ(nBits, CX.As64())
+		NEGQ(CX.As64())
+		SHRQ(CX, BX)
+		TESTQ(nBits, nBits)
+		CMOVQEQ(nBits, BX)
+	}
+	return BX
+}
+
+func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual) (offset reg.GPVirtual) {
+	s := Dereference(Param("s"))
+
+	po0, _ := s.Field("prevOffset").Index(0).Resolve()
+	po1, _ := s.Field("prevOffset").Index(1).Resolve()
+	po2, _ := s.Field("prevOffset").Index(2).Resolve()
+	offset = GP64()
+	MOVQ(moP, offset)
+	{
+		// if offsetB > 1 {
+		//    s.prevOffset[2] = s.prevOffset[1]
+		//    s.prevOffset[1] = s.prevOffset[0]
+		//    s.prevOffset[0] = offset
+		//    return offset
+		// }
+		CMPQ(offsetB, U8(1))
+		JBE(LabelRef(name + "_offsetB_1_or_0"))
+
+		// TODO: Test if 1 SSE2 move + write is faster...
+		tmp, tmp2 := GP64(), GP64()
+		MOVQ(po0.Addr, tmp)    // tmp = s.prevOffset[0]
+		MOVQ(po1.Addr, tmp2)   // tmp2 = s.prevOffset[1]
+		MOVQ(offset, po0.Addr) // s.prevOffset[0] = offset
+		MOVQ(tmp, po1.Addr)    // s.prevOffset[1] = s.prevOffset[0]
+		MOVQ(tmp2, po2.Addr)   // s.prevOffset[2] = s.prevOffset[1]
+		JMP(LabelRef(name + "_end"))
+	}
+
+	Label(name + "_offsetB_1_or_0")
+	// if litLen == 0 {
+	//    offset++
+	// }
+	{
+		if true {
+			CMPQ(llP, U32(0))
+			JNE(LabelRef(name + "_offset_maybezero"))
+			INCQ(offset)
+			JMP(LabelRef(name + "_offset_nonzero"))
+		} else {
+			// No idea why this doesn't work:
+			tmp := GP64()
+			LEAQ(Mem{Base: offset, Disp: 1}, tmp)
+			CMPQ(llP, U32(0))
+			CMOVQEQ(tmp, offset)
+		}
+
+		// if offset == 0 {
+		//    return s.prevOffset[0]
+		// }
+		{
+			Label(name + "_offset_maybezero")
+			TESTQ(offset, offset)
+			JNZ(LabelRef(name + "_offset_nonzero"))
+			MOVQ(po0.Addr, offset)
+			JMP(LabelRef(name + "_end"))
+		}
+	}
+	Label(name + "_offset_nonzero")
+	{
+		// if offset == 3 {
+		//    temp = s.prevOffset[0] - 1
+		// } else {
+		//    temp = s.prevOffset[offset]
+		// }
+		//
+		// this if got transformed into:
+		//
+		// ofs   := offset
+		// shift := 0
+		// if offset == 3 {
+		//    ofs   = 0
+		//    shift = -1
+		// }
+		// temp := s.prevOffset[ofs] + shift
+		// TODO: This should be easier...
+		CX, DX, R15 := GP64(), GP64(), GP64()
+		MOVQ(offset, CX)
+		XORQ(DX, DX)
+		MOVQ(I32(-1), R15)
+		CMPQ(offset, U8(3))
+		CMOVQEQ(DX, CX)
+		CMOVQEQ(R15, DX)
+		prevOffset := GP64()
+		LEAQ(po0.Addr, prevOffset) // &prevOffset[0]
+		ADDQ(Mem{Base: prevOffset, Index: CX, Scale: 8}, DX)
+		temp := DX
+		// if temp == 0 {
+		//    temp = 1
+		// }
+		JNZ(LabelRef(name + "_temp_valid"))
+		MOVQ(U32(1), temp)
+
+		Label(name + "_temp_valid")
+		// if offset != 1 {
+		//    s.prevOffset[2] = s.prevOffset[1]
+		// }
+		CMPQ(offset, U8(1))
+		JZ(LabelRef(name + "_skip"))
+		tmp := GP64()
+		MOVQ(po1.Addr, tmp)
+		MOVQ(tmp, po2.Addr) // s.prevOffset[2] = s.prevOffset[1]
+
+		Label(name + "_skip")
+		// s.prevOffset[1] = s.prevOffset[0]
+		// s.prevOffset[0] = temp
+		tmp = GP64()
+		MOVQ(po0.Addr, tmp)
+		MOVQ(tmp, po1.Addr)  // s.prevOffset[1] = s.prevOffset[0]
+		MOVQ(temp, po0.Addr) // s.prevOffset[0] = temp
+		MOVQ(temp, offset)   // return temp
+	}
+	Label(name + "_end")
+	return offset
+}
diff --git a/zstd/_generate/go.mod b/zstd/_generate/go.mod
new file mode 100644
index 0000000000..41c4458869
--- /dev/null
+++ b/zstd/_generate/go.mod
@@ -0,0 +1,10 @@
+module github.com/klauspost/compress/zstd/_generate
+
+go 1.15
+
+require (
+	github.com/klauspost/compress v1.15.1
+	github.com/mmcloughlin/avo v0.4.0
+)
+
+replace github.com/klauspost/compress => ../..
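Reviewer note: the refill blocks that `bitreaderFill` emits above read the input backwards, 4 bytes at a time while at least 4 bytes remain, else byte by byte, whenever 32 or more bits have been consumed. Below is a minimal Go sketch of that logic for reference; `bitReaderSketch` is a hypothetical stand-in for illustration (its fields mirror the `value`, `bitsRead`, `off` and `in` fields loaded at function entry), not the package's actual `bitReader`.

```go
// bitReaderSketch is a hypothetical illustration type, not the real bitReader.
type bitReaderSketch struct {
	in       []byte // compressed stream, consumed back to front
	off      uint   // next read ends at in[off]
	value    uint64 // bit buffer; new bytes shift in from the low end
	bitsRead uint8  // bits already consumed from value
}

// fill mirrors the generated refill block: top up the bit buffer once at
// least 32 bits have been consumed.
func (b *bitReaderSketch) fill() {
	if b.bitsRead < 32 {
		return // enough bits buffered already
	}
	if b.off >= 4 {
		// Fast path: pull in 4 bytes (little endian) at once.
		v := b.in[b.off-4 : b.off]
		low := uint32(v[0]) | uint32(v[1])<<8 | uint32(v[2])<<16 | uint32(v[3])<<24
		b.value = b.value<<32 | uint64(low)
		b.bitsRead -= 32
		b.off -= 4
		return
	}
	// Tail: fewer than 4 bytes left, pull them in one at a time.
	for b.off > 0 {
		b.value = b.value<<8 | uint64(b.in[b.off-1])
		b.bitsRead -= 8
		b.off--
	}
}
```

The generated code keeps `value`, `bitsRead` and the read pointer in registers across the whole main loop and only spills the pointer to the 8-byte stack slot allocated at entry, which is the point of inlining the refill rather than calling back into Go.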
diff --git a/zstd/_generate/go.sum b/zstd/_generate/go.sum
new file mode 100644
index 0000000000..b4b59140f0
--- /dev/null
+++ b/zstd/_generate/go.sum
@@ -0,0 +1,32 @@
+github.com/mmcloughlin/avo v0.4.0 h1:jeHDRktVD+578ULxWpQHkilor6pkdLF7u7EiTzDbfcU=
+github.com/mmcloughlin/avo v0.4.0/go.mod h1:RW9BfYA3TgO9uCdNrKU2h6J8cPD8ZLznvfgHAeszb1s=
+github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
+golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo=
+golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021 h1:giLT+HuUP/gXYrG2Plg9WTjj4qhfgaW424ZIFog3rlk=
+golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ=
+golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
diff --git a/zstd/seqdec.go b/zstd/seqdec.go
index 009d85cef0..39b4de464b 100644
--- a/zstd/seqdec.go
+++ b/zstd/seqdec.go
@@ -98,150 +98,6 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, out []byte) erro
 	return nil
 }
 
-// decode sequences from the stream with the provided history.
-func (s *sequenceDecs) decode(seqs []seqVals) error {
-	br := s.br
-
-	// Grab full sizes tables, to avoid bounds checks.
-	llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize]
-	llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
-	s.seqSize = 0
-	litRemain := len(s.literals)
-	maxBlockSize := maxCompressedBlockSize
-	if s.windowSize < maxBlockSize {
-		maxBlockSize = s.windowSize
-	}
-	for i := range seqs {
-		var ll, mo, ml int
-		if br.off > 4+((maxOffsetBits+16+16)>>3) {
-			// inlined function:
-			// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)
-
-			// Final will not read from stream.
-			var llB, mlB, moB uint8
-			ll, llB = llState.final()
-			ml, mlB = mlState.final()
-			mo, moB = ofState.final()
-
-			// extra bits are stored in reverse order.
-			br.fillFast()
-			mo += br.getBits(moB)
-			if s.maxBits > 32 {
-				br.fillFast()
-			}
-			ml += br.getBits(mlB)
-			ll += br.getBits(llB)
-
-			if moB > 1 {
-				s.prevOffset[2] = s.prevOffset[1]
-				s.prevOffset[1] = s.prevOffset[0]
-				s.prevOffset[0] = mo
-			} else {
-				// mo = s.adjustOffset(mo, ll, moB)
-				// Inlined for rather big speedup
-				if ll == 0 {
-					// There is an exception though, when current sequence's literals_length = 0.
-					// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
-					// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
-					mo++
-				}
-
-				if mo == 0 {
-					mo = s.prevOffset[0]
-				} else {
-					var temp int
-					if mo == 3 {
-						temp = s.prevOffset[0] - 1
-					} else {
-						temp = s.prevOffset[mo]
-					}
-
-					if temp == 0 {
-						// 0 is not valid; input is corrupted; force offset to 1
-						println("WARNING: temp was 0")
-						temp = 1
-					}
-
-					if mo != 1 {
-						s.prevOffset[2] = s.prevOffset[1]
-					}
-					s.prevOffset[1] = s.prevOffset[0]
-					s.prevOffset[0] = temp
-					mo = temp
-				}
-			}
-			br.fillFast()
-		} else {
-			if br.overread() {
-				if debugDecoder {
-					printf("reading sequence %d, exceeded available data\n", i)
-				}
-				return io.ErrUnexpectedEOF
-			}
-			ll, mo, ml = s.next(br, llState, mlState, ofState)
-			br.fill()
-		}
-
-		if debugSequences {
-			println("Seq", i, "Litlen:", ll, "mo:", mo, "(abs) ml:", ml)
-		}
-		// Evaluate.
-		// We might be doing this async, so do it early.
-		if mo == 0 && ml > 0 {
-			return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
-		}
-		if ml > maxMatchLen {
-			return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
-		}
-		s.seqSize += ll + ml
-		if s.seqSize > maxBlockSize {
-			return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
-		}
-		litRemain -= ll
-		if litRemain < 0 {
-			return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, litRemain+ll)
-		}
-		seqs[i] = seqVals{
-			ll: ll,
-			ml: ml,
-			mo: mo,
-		}
-		if i == len(seqs)-1 {
-			// This is the last sequence, so we shouldn't update state.
-			break
-		}
-
-		// Manually inlined, ~ 5-20% faster
-		// Update all 3 states at once. Approx 20% faster.
-		nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits()
-		if nBits == 0 {
-			llState = llTable[llState.newState()&maxTableMask]
-			mlState = mlTable[mlState.newState()&maxTableMask]
-			ofState = ofTable[ofState.newState()&maxTableMask]
-		} else {
-			bits := br.get32BitsFast(nBits)
-			lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
-			llState = llTable[(llState.newState()+lowBits)&maxTableMask]
-
-			lowBits = uint16(bits >> (ofState.nbBits() & 31))
-			lowBits &= bitMask[mlState.nbBits()&15]
-			mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask]
-
-			lowBits = uint16(bits) & bitMask[ofState.nbBits()&15]
-			ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask]
-		}
-	}
-	s.seqSize += litRemain
-	if s.seqSize > maxBlockSize {
-		return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
-	}
-	err := br.close()
-	if err != nil {
-		printf("Closing sequences: %v, %+v\n", err, *br)
-	}
-	return err
-}
-
 // execute will execute the decoded sequence with the provided history.
 // The sequence must be evaluated before being sent.
 func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
diff --git a/zstd/seqdec_amd64.go b/zstd/seqdec_amd64.go
new file mode 100644
index 0000000000..b6832ee257
--- /dev/null
+++ b/zstd/seqdec_amd64.go
@@ -0,0 +1,103 @@
+//go:build amd64 && !appengine && !noasm && gc
+// +build amd64,!appengine,!noasm,gc
+
+package zstd
+
+import (
+	"fmt"
+
+	"github.com/klauspost/compress/internal/cpuinfo"
+)
+
+type decodeAsmContext struct {
+	llTable   []decSymbol
+	mlTable   []decSymbol
+	ofTable   []decSymbol
+	llState   uint64
+	mlState   uint64
+	ofState   uint64
+	iteration int
+	seqs      []seqVals
+	litRemain int
+}
+
+// error reported when mo == 0 && ml > 0
+const errorMatchLenOfsMismatch = 1
+
+// error reported when ml > maxMatchLen
+const errorMatchLenTooBig = 2
+
+// sequenceDecs_decode_amd64 implements the main loop of sequenceDecs in x86 asm.
+//
+// Please refer to seqdec_generic.go for the reference implementation.
+func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
+
+// sequenceDecs_decode_bmi2 implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
+func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
+
+type sequenceDecs_decode_function = func(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
+
+var sequenceDecs_decode sequenceDecs_decode_function
+
+func init() {
+	if cpuinfo.HasBMI2() {
+		sequenceDecs_decode = sequenceDecs_decode_bmi2
+	} else {
+		sequenceDecs_decode = sequenceDecs_decode_amd64
+	}
+}
+
+// decode sequences from the stream without the provided history.
+func (s *sequenceDecs) decode(seqs []seqVals) error {
+	br := s.br
+
+	maxBlockSize := maxCompressedBlockSize
+	if s.windowSize < maxBlockSize {
+		maxBlockSize = s.windowSize
+	}
+
+	ctx := decodeAsmContext{
+		llTable:   s.litLengths.fse.dt[:maxTablesize],
+		mlTable:   s.matchLengths.fse.dt[:maxTablesize],
+		ofTable:   s.offsets.fse.dt[:maxTablesize],
+		llState:   uint64(s.litLengths.state.state),
+		mlState:   uint64(s.matchLengths.state.state),
+		ofState:   uint64(s.offsets.state.state),
+		seqs:      seqs,
+		iteration: len(seqs) - 1,
+		litRemain: len(s.literals),
+	}
+
+	s.seqSize = 0
+
+	errCode := sequenceDecs_decode(s, br, &ctx)
+	if errCode != 0 {
+		i := len(seqs) - ctx.iteration
+		switch errCode {
+		case errorMatchLenOfsMismatch:
+			ml := ctx.seqs[i].ml
+			return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
+
+		case errorMatchLenTooBig:
+			ml := ctx.seqs[i].ml
+			return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
+		}
+
+		return fmt.Errorf("sequenceDecs_decode_amd64 returned erroneous code %d", errCode)
+	}
+
+	if ctx.litRemain < 0 {
+		return fmt.Errorf("literal count is too big: total available %d, total requested %d",
+			len(s.literals), len(s.literals)-ctx.litRemain)
+	}
+
+	s.seqSize += ctx.litRemain
+	if s.seqSize > maxBlockSize {
+		return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
+	}
+	err := br.close()
+	if err != nil {
+		printf("Closing sequences: %v, %+v\n", err, *br)
+	}
+	return err
+}
diff --git a/zstd/seqdec_amd64.s b/zstd/seqdec_amd64.s
new file mode 100644
index 0000000000..d3d6fc9863
--- /dev/null
+++ b/zstd/seqdec_amd64.s
@@ -0,0 +1,593 @@
+// Code generated by command: go run gen.go -out seqdec_amd64.s -stubs delme.go -pkg=zstd. DO NOT EDIT.
+
+//go:build !appengine && !noasm && gc
+// +build !appengine,!noasm,gc
+
+// func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
+// Requires: CMOV
+TEXT ·sequenceDecs_decode_amd64(SB), $8-32
+	MOVQ    br+8(FP), AX
+	MOVQ    32(AX), DX
+	MOVBQZX 40(AX), BX
+	MOVQ    24(AX), SI
+	MOVQ    (AX), AX
+	ADDQ    SI, AX
+	MOVQ    AX, (SP)
+	MOVQ    ctx+16(FP), AX
+	MOVQ    72(AX), DI
+	MOVQ    80(AX), R8
+	MOVQ    88(AX), R9
+	MOVQ    104(AX), R10
+
+sequenceDecs_decode_amd64_main_loop:
+	MOVQ (SP), R11
+
+	// Fill bitreader to have enough for the offset.
+	CMPQ BX, $0x20
+	JL   sequenceDecs_decode_amd64_fill_end
+	CMPQ SI, $0x04
+	JL   sequenceDecs_decode_amd64_fill_byte_by_byte
+	SHLQ $0x20, DX
+	SUBQ $0x04, R11
+	SUBQ $0x04, SI
+	SUBQ $0x20, BX
+	MOVLQZX (R11), AX
+	ORQ  AX, DX
+	JMP  sequenceDecs_decode_amd64_fill_end
+
+sequenceDecs_decode_amd64_fill_byte_by_byte:
+	CMPQ SI, $0x00
+	JLE  sequenceDecs_decode_amd64_fill_end
+	SHLQ $0x08, DX
+	SUBQ $0x01, R11
+	SUBQ $0x01, SI
+	SUBQ $0x08, BX
+	MOVBQZX (R11), AX
+	ORQ  AX, DX
+	JMP  sequenceDecs_decode_amd64_fill_byte_by_byte
+
+sequenceDecs_decode_amd64_fill_end:
+	// Update offset
+	MOVQ    R9, AX
+	MOVQ    BX, CX
+	MOVQ    DX, R12
+	SHLQ    CL, R12
+	MOVB    AH, CL
+	ADDQ    CX, BX
+	NEGL    CX
+	SHRQ    CL, R12
+	SHRQ    $0x20, AX
+	TESTQ   CX, CX
+	CMOVQEQ CX, R12
+	ADDQ    R12, AX
+	MOVQ    AX, 16(R10)
+
+	// Fill bitreader for match and literal
+	CMPQ BX, $0x20
+	JL   sequenceDecs_decode_amd64_fill_2_end
+	CMPQ SI, $0x04
+	JL   sequenceDecs_decode_amd64_fill_2_byte_by_byte
+	SHLQ $0x20, DX
+	SUBQ $0x04, R11
+	SUBQ $0x04, SI
+	SUBQ $0x20, BX
+	MOVLQZX (R11), AX
+	ORQ  AX, DX
+	JMP  sequenceDecs_decode_amd64_fill_2_end
+
+sequenceDecs_decode_amd64_fill_2_byte_by_byte:
+	CMPQ SI, $0x00
+	JLE  sequenceDecs_decode_amd64_fill_2_end
+	SHLQ $0x08, DX
+	SUBQ $0x01, R11
+	SUBQ $0x01, SI
+	SUBQ $0x08, BX
+	MOVBQZX (R11), AX
+	ORQ  AX, DX
+	JMP  sequenceDecs_decode_amd64_fill_2_byte_by_byte
+
+sequenceDecs_decode_amd64_fill_2_end:
+	// Update match length
+	MOVQ    R8, AX
+	MOVQ    BX, CX
+	MOVQ    DX, R12
+	SHLQ    CL, R12
+	MOVB    AH, CL
+	ADDQ    CX, BX
+	NEGL    CX
+	SHRQ    CL, R12
+	SHRQ    $0x20, AX
+	TESTQ   CX, CX
+	CMOVQEQ CX, R12
+	ADDQ    R12, AX
+	MOVQ    AX, 8(R10)
+
+	// Update literal length
+	MOVQ    DI, AX
+	MOVQ    BX, CX
+	MOVQ    DX, R12
+	SHLQ    CL, R12
+	MOVB    AH, CL
+	ADDQ    CX, BX
+	NEGL    CX
+	SHRQ    CL, R12
+	SHRQ    $0x20, AX
+	TESTQ   CX, CX
+	CMOVQEQ CX, R12
+	ADDQ    R12, AX
+	MOVQ    AX, (R10)
+
+	// Fill bitreader for state updates
+	CMPQ BX, $0x20
+	JL   sequenceDecs_decode_amd64_fill_3_end
+	CMPQ SI, $0x04
+	JL   sequenceDecs_decode_amd64_fill_3_byte_by_byte
+	SHLQ $0x20, DX
+	SUBQ $0x04, R11
+	SUBQ $0x04, SI
+	SUBQ $0x20, BX
+	MOVLQZX (R11), AX
+	ORQ  AX, DX
+	JMP  sequenceDecs_decode_amd64_fill_3_end
+
+sequenceDecs_decode_amd64_fill_3_byte_by_byte:
+	CMPQ SI, $0x00
+	JLE  sequenceDecs_decode_amd64_fill_3_end
+	SHLQ $0x08, DX
+	SUBQ $0x01, R11
+	SUBQ $0x01, SI
+	SUBQ $0x08, BX
+	MOVBQZX (R11), AX
+	ORQ  AX, DX
+	JMP  sequenceDecs_decode_amd64_fill_3_byte_by_byte
+
+sequenceDecs_decode_amd64_fill_3_end:
+	MOVQ R11, (SP)
+	MOVQ R9, AX
+	MOVQ ctx+16(FP), CX
+	CMPQ 96(CX), $0x00
+	JZ   sequenceDecs_decode_amd64_skip_update
+
+	// Update Literal Length State
+	MOVBQZX DI, R11
+	SHRQ    $0x10, DI
+	MOVWQZX DI, DI
+	CMPQ    R11, $0x00
+	JZ      sequenceDecs_decode_amd64_llState_updateState_skip
+	MOVQ    BX, CX
+	ADDQ    R11, BX
+	MOVQ    DX, R12
+	SHLQ    CL, R12
+	MOVQ    R11, CX
+	NEGQ    CX
+	SHRQ    CL, R12
+	TESTQ   R11, R11
+	CMOVQEQ R11, R12
+	ADDQ    R12, DI
+
+sequenceDecs_decode_amd64_llState_updateState_skip:
+	// Load ctx.llTable
+	MOVQ ctx+16(FP), CX
+	MOVQ (CX), CX
+	MOVQ (CX)(DI*8), DI
+
+	// Update Match Length State
+	MOVBQZX R8, R11
+	SHRQ    $0x10, R8
+	MOVWQZX R8, R8
+	CMPQ    R11, $0x00
+	JZ      sequenceDecs_decode_amd64_mlState_updateState_skip
+	MOVQ    BX, CX
+	ADDQ    R11, BX
+	MOVQ    DX, R12
+	SHLQ    CL, R12
+	MOVQ    R11, CX
+	NEGQ    CX
+	SHRQ    CL, R12
+	TESTQ   R11, R11
+	CMOVQEQ R11, R12
+	ADDQ    R12, R8
+
+sequenceDecs_decode_amd64_mlState_updateState_skip:
+	// Load ctx.mlTable
+	MOVQ ctx+16(FP), CX
+	MOVQ 24(CX), CX
+	MOVQ (CX)(R8*8), R8
+
+	// Update Offset State
+	MOVBQZX R9, R11
+	SHRQ    $0x10, R9
+	MOVWQZX R9, R9
+	CMPQ    R11, $0x00
+	JZ      sequenceDecs_decode_amd64_ofState_updateState_skip
+	MOVQ    BX, CX
+	ADDQ    R11, BX
+	MOVQ    DX, R12
+	SHLQ    CL, R12
+	MOVQ    R11, CX
+	NEGQ    CX
+	SHRQ    CL, R12
+	TESTQ   R11, R11
+	CMOVQEQ R11, R12
+	ADDQ    R12, R9
+
+sequenceDecs_decode_amd64_ofState_updateState_skip:
+	// Load ctx.ofTable
+	MOVQ ctx+16(FP), CX
+	MOVQ 48(CX), CX
+	MOVQ (CX)(R9*8), R9
+
+sequenceDecs_decode_amd64_skip_update:
+	SHRQ    $0x08, AX
+	MOVBQZX AL, AX
+
+	// Adjust offset
+	MOVQ s+0(FP), CX
+	MOVQ 16(R10), R11
+	CMPQ AX, $0x01
+	JBE  sequenceDecs_decode_amd64_adjust_offsetB_1_or_0
+	MOVQ 144(CX), AX
+	MOVQ 152(CX), R12
+	MOVQ R11, 144(CX)
+	MOVQ AX, 152(CX)
+	MOVQ R12, 160(CX)
+	JMP  sequenceDecs_decode_amd64_adjust_end
+
+sequenceDecs_decode_amd64_adjust_offsetB_1_or_0:
+	CMPQ (R10), $0x00000000
+	JNE  sequenceDecs_decode_amd64_adjust_offset_maybezero
+	INCQ R11
+	JMP  sequenceDecs_decode_amd64_adjust_offset_nonzero
+
+sequenceDecs_decode_amd64_adjust_offset_maybezero:
+	TESTQ R11, R11
+	JNZ   sequenceDecs_decode_amd64_adjust_offset_nonzero
+	MOVQ  144(CX), R11
+	JMP   sequenceDecs_decode_amd64_adjust_end
+
+sequenceDecs_decode_amd64_adjust_offset_nonzero:
+	MOVQ    R11, AX
+	XORQ    R12, R12
+	MOVQ    $-1, R13
+	CMPQ    R11, $0x03
+	CMOVQEQ R12, AX
+	CMOVQEQ R13, R12
+	LEAQ    144(CX), R13
+	ADDQ    (R13)(AX*8), R12
+	JNZ     sequenceDecs_decode_amd64_adjust_temp_valid
+	MOVQ    $0x00000001, R12
+
+sequenceDecs_decode_amd64_adjust_temp_valid:
+	CMPQ R11, $0x01
+	JZ   sequenceDecs_decode_amd64_adjust_skip
+	MOVQ 152(CX), AX
+	MOVQ AX, 160(CX)
+
+sequenceDecs_decode_amd64_adjust_skip:
+	MOVQ 144(CX), AX
+	MOVQ AX, 152(CX)
+	MOVQ R12, 144(CX)
+	MOVQ R12, R11
+
+sequenceDecs_decode_amd64_adjust_end:
+	MOVQ R11, 16(R10)
+
+	// Check values
+	MOVQ  8(R10), AX
+	MOVQ  (R10), CX
+	LEAQ  (AX)(CX*1), R12
+	MOVQ  s+0(FP), R13
+	ADDQ  R12, 256(R13)
+	MOVQ  ctx+16(FP), R12
+	SUBQ  CX, 128(R12)
+	CMPQ  AX, $0x00020002
+	JA    sequenceDecs_decode_amd64_error_match_len_too_big
+	TESTQ R11, R11
+	JNZ   sequenceDecs_decode_amd64_match_len_ofs_ok
+	TESTQ AX, AX
+	JNZ   sequenceDecs_decode_amd64_error_match_len_ofs_mismatch
+
+sequenceDecs_decode_amd64_match_len_ofs_ok:
+	ADDQ $0x18, R10
+	MOVQ ctx+16(FP), AX
+	DECQ 96(AX)
+	JNS  sequenceDecs_decode_amd64_main_loop
+	MOVQ br+8(FP), AX
+	MOVQ DX, 32(AX)
+	MOVB BL, 40(AX)
+	MOVQ SI, 24(AX)
+
+	// Return success
+	MOVQ $0x00000000, ret+24(FP)
+	RET
+
+	// Return with match length error
+sequenceDecs_decode_amd64_error_match_len_ofs_mismatch:
+	MOVQ $0x00000001, ret+24(FP)
+	RET
+
+	// Return with match too long error
+sequenceDecs_decode_amd64_error_match_len_too_big:
+	MOVQ $0x00000002, ret+24(FP)
+	RET
+
+// func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
+// Requires: BMI, BMI2, CMOV
+TEXT ·sequenceDecs_decode_bmi2(SB), $8-32
+	MOVQ    br+8(FP), CX
+	MOVQ    32(CX), AX
+	MOVBQZX 40(CX), DX
+	MOVQ    24(CX), BX
+	MOVQ    (CX), CX
+	ADDQ    BX, CX
+	MOVQ    CX, (SP)
+	MOVQ    ctx+16(FP), CX
+	MOVQ    72(CX), SI
+	MOVQ    80(CX), DI
+	MOVQ    88(CX), R8
+	MOVQ    104(CX), R9
+
+sequenceDecs_decode_bmi2_main_loop:
+	MOVQ (SP), R10
+
+	// Fill bitreader to have enough for the offset.
+	CMPQ DX, $0x20
+	JL   sequenceDecs_decode_bmi2_fill_end
+	CMPQ BX, $0x04
+	JL   sequenceDecs_decode_bmi2_fill_byte_by_byte
+	SHLQ $0x20, AX
+	SUBQ $0x04, R10
+	SUBQ $0x04, BX
+	SUBQ $0x20, DX
+	MOVLQZX (R10), CX
+	ORQ  CX, AX
+	JMP  sequenceDecs_decode_bmi2_fill_end
+
+sequenceDecs_decode_bmi2_fill_byte_by_byte:
+	CMPQ BX, $0x00
+	JLE  sequenceDecs_decode_bmi2_fill_end
+	SHLQ $0x08, AX
+	SUBQ $0x01, R10
+	SUBQ $0x01, BX
+	SUBQ $0x08, DX
+	MOVBQZX (R10), CX
+	ORQ  CX, AX
+	JMP  sequenceDecs_decode_bmi2_fill_byte_by_byte
+
+sequenceDecs_decode_bmi2_fill_end:
+	// Update offset
+	MOVQ   $0x00000808, CX
+	BEXTRQ CX, R8, R11
+	MOVQ   AX, R12
+	LEAQ   (DX)(R11*1), CX
+	ROLQ   CL, R12
+	BZHIQ  R11, R12, R12
+	MOVQ   CX, DX
+	MOVQ   R8, CX
+	SHRQ   $0x20, CX
+	ADDQ   R12, CX
+	MOVQ   CX, 16(R9)
+
+	// Fill bitreader for match and literal
+	CMPQ DX, $0x20
+	JL   sequenceDecs_decode_bmi2_fill_2_end
+	CMPQ BX, $0x04
+	JL   sequenceDecs_decode_bmi2_fill_2_byte_by_byte
+	SHLQ $0x20, AX
+	SUBQ $0x04, R10
+	SUBQ $0x04, BX
+	SUBQ $0x20, DX
+	MOVLQZX (R10), CX
+	ORQ  CX, AX
+	JMP  sequenceDecs_decode_bmi2_fill_2_end
+
+sequenceDecs_decode_bmi2_fill_2_byte_by_byte:
+	CMPQ BX, $0x00
+	JLE  sequenceDecs_decode_bmi2_fill_2_end
+	SHLQ $0x08, AX
+	SUBQ $0x01, R10
+	SUBQ $0x01, BX
+	SUBQ $0x08, DX
+	MOVBQZX (R10), CX
+	ORQ  CX, AX
+	JMP  sequenceDecs_decode_bmi2_fill_2_byte_by_byte
+
+sequenceDecs_decode_bmi2_fill_2_end:
+	// Update match length
+	MOVQ   $0x00000808, CX
+	BEXTRQ CX, DI, R11
+	MOVQ   AX, R12
+	LEAQ   (DX)(R11*1), CX
+	ROLQ   CL, R12
+	BZHIQ  R11, R12, R12
+	MOVQ   CX, DX
+	MOVQ   DI, CX
+	SHRQ   $0x20, CX
+	ADDQ   R12, CX
+	MOVQ   CX, 8(R9)
+
+	// Update literal length
+	MOVQ   $0x00000808, CX
+	BEXTRQ CX, SI, R11
+	MOVQ   AX, R12
+	LEAQ   (DX)(R11*1), CX
+	ROLQ   CL, R12
+	BZHIQ  R11, R12, R12
+	MOVQ   CX, DX
+	MOVQ   SI, CX
+	SHRQ   $0x20, CX
+	ADDQ   R12, CX
+	MOVQ   CX, (R9)
+
+	// Fill bitreader for state updates
+	CMPQ DX, $0x20
+	JL   sequenceDecs_decode_bmi2_fill_3_end
+	CMPQ BX, $0x04
+	JL   sequenceDecs_decode_bmi2_fill_3_byte_by_byte
+	SHLQ $0x20, AX
+	SUBQ $0x04, R10
+	SUBQ $0x04, BX
+	SUBQ $0x20, DX
+	MOVLQZX (R10), CX
+	ORQ  CX, AX
+	JMP  sequenceDecs_decode_bmi2_fill_3_end
+
+sequenceDecs_decode_bmi2_fill_3_byte_by_byte:
+	CMPQ BX, $0x00
+	JLE  sequenceDecs_decode_bmi2_fill_3_end
+	SHLQ $0x08, AX
+	SUBQ $0x01, R10
+	SUBQ $0x01, BX
+	SUBQ $0x08, DX
+	MOVBQZX (R10), CX
+	ORQ  CX, AX
+	JMP  sequenceDecs_decode_bmi2_fill_3_byte_by_byte
+
+sequenceDecs_decode_bmi2_fill_3_end:
+	MOVQ R10, (SP)
+	MOVQ R8, R10
+	MOVQ ctx+16(FP), CX
+	CMPQ 96(CX), $0x00
+	JZ   sequenceDecs_decode_bmi2_skip_update
+
+	// Update Literal Length State
+	MOVBQZX SI, R11
+	SHRQ    $0x10, SI
+	MOVWQZX SI, SI
+	LEAQ    (DX)(R11*1), CX
+	MOVQ    AX, R12
+	MOVQ    CX, DX
+	ROLQ    CL, R12
+	BZHIQ   R11, R12, R12
+	ADDQ    R12, SI
+
+	// Load ctx.llTable
+	MOVQ ctx+16(FP), CX
+	MOVQ (CX), CX
+	MOVQ (CX)(SI*8), SI
+
+	// Update Match Length State
+	MOVBQZX DI, R11
+	SHRQ    $0x10, DI
+	MOVWQZX DI, DI
+	LEAQ    (DX)(R11*1), CX
+	MOVQ    AX, R12
+	MOVQ    CX, DX
+	ROLQ    CL, R12
+	BZHIQ   R11, R12, R12
+	ADDQ    R12, DI
+
+	// Load ctx.mlTable
+	MOVQ ctx+16(FP), CX
+	MOVQ 24(CX), CX
+	MOVQ (CX)(DI*8), DI
+
+	// Update Offset State
+	MOVBQZX R8, R11
+	SHRQ    $0x10, R8
+	MOVWQZX R8, R8
+	LEAQ    (DX)(R11*1), CX
+	MOVQ    AX, R12
+	MOVQ    CX, DX
+	ROLQ    CL, R12
+	BZHIQ   R11, R12, R12
+	ADDQ    R12, R8
+
+	// Load ctx.ofTable
+	MOVQ ctx+16(FP), CX
+	MOVQ 48(CX), CX
+	MOVQ (CX)(R8*8), R8
+
+sequenceDecs_decode_bmi2_skip_update:
+	SHRQ    $0x08, R10
+	MOVBQZX R10, R10
+
+	// Adjust offset
+	MOVQ s+0(FP), CX
+	MOVQ 16(R9), R11
+	CMPQ R10, $0x01
+	JBE  sequenceDecs_decode_bmi2_adjust_offsetB_1_or_0
+	MOVQ 144(CX), R10
+	MOVQ 152(CX), R12
+	MOVQ R11, 144(CX)
+	MOVQ R10, 152(CX)
+	MOVQ R12, 160(CX)
+	JMP  sequenceDecs_decode_bmi2_adjust_end
+
+sequenceDecs_decode_bmi2_adjust_offsetB_1_or_0:
+	CMPQ (R9), $0x00000000
+	JNE  sequenceDecs_decode_bmi2_adjust_offset_maybezero
+	INCQ R11
+	JMP  sequenceDecs_decode_bmi2_adjust_offset_nonzero
+
+sequenceDecs_decode_bmi2_adjust_offset_maybezero:
+	TESTQ R11, R11
+	JNZ   sequenceDecs_decode_bmi2_adjust_offset_nonzero
+	MOVQ  144(CX), R11
+	JMP   sequenceDecs_decode_bmi2_adjust_end
+
+sequenceDecs_decode_bmi2_adjust_offset_nonzero:
+	MOVQ    R11, R10
+	XORQ    R12, R12
+	MOVQ    $-1, R13
+	CMPQ    R11, $0x03
+	CMOVQEQ R12, R10
+	CMOVQEQ R13, R12
+	LEAQ    144(CX), R13
+	ADDQ    (R13)(R10*8), R12
+	JNZ     sequenceDecs_decode_bmi2_adjust_temp_valid
+	MOVQ    $0x00000001, R12
+
+sequenceDecs_decode_bmi2_adjust_temp_valid:
+	CMPQ R11, $0x01
+	JZ   sequenceDecs_decode_bmi2_adjust_skip
+	MOVQ 152(CX), R10
+	MOVQ R10, 160(CX)
+
+sequenceDecs_decode_bmi2_adjust_skip:
+	MOVQ 144(CX), R10
+	MOVQ R10, 152(CX)
+	MOVQ R12, 144(CX)
+	MOVQ R12, R11
+
+sequenceDecs_decode_bmi2_adjust_end:
+	MOVQ R11, 16(R9)
+
+	// Check values
+	MOVQ  8(R9), CX
+	MOVQ  (R9), R10
+	LEAQ  (CX)(R10*1), R12
+	MOVQ  s+0(FP), R13
+	ADDQ  R12, 256(R13)
+	MOVQ  ctx+16(FP), R12
+	SUBQ  R10, 128(R12)
+	CMPQ  CX, $0x00020002
+	JA    sequenceDecs_decode_bmi2_error_match_len_too_big
+	TESTQ R11, R11
+	JNZ   sequenceDecs_decode_bmi2_match_len_ofs_ok
+	TESTQ CX, CX
+	JNZ   sequenceDecs_decode_bmi2_error_match_len_ofs_mismatch
+
+sequenceDecs_decode_bmi2_match_len_ofs_ok:
+	ADDQ $0x18, R9
+	MOVQ ctx+16(FP), CX
+	DECQ 96(CX)
+	JNS  sequenceDecs_decode_bmi2_main_loop
+	MOVQ br+8(FP), CX
+	MOVQ AX, 32(CX)
+	MOVB DL, 40(CX)
+	MOVQ BX, 24(CX)
+
+	// Return success
+	MOVQ $0x00000000, ret+24(FP)
+	RET
+
+	// Return with match length error
+sequenceDecs_decode_bmi2_error_match_len_ofs_mismatch:
+	MOVQ $0x00000001, ret+24(FP)
+	RET
+
+	// Return with match too long error
+sequenceDecs_decode_bmi2_error_match_len_too_big:
+	MOVQ $0x00000002, ret+24(FP)
+	RET
diff --git a/zstd/seqdec_generic.go b/zstd/seqdec_generic.go
new file mode 100644
index 0000000000..62a20b2988
--- /dev/null
+++ b/zstd/seqdec_generic.go
@@ -0,0 +1,154 @@
+//go:build !amd64 || appengine || !gc || noasm
+// +build !amd64 appengine !gc noasm
+
+package zstd
+
+import (
+	"fmt"
+	"io"
+)
+
+// decode sequences from the stream without the provided history.
+func (s *sequenceDecs) decode(seqs []seqVals) error {
+	br := s.br
+
+	// Grab full sizes tables, to avoid bounds checks.
+	llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize]
+	llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
+	s.seqSize = 0
+	litRemain := len(s.literals)
+
+	maxBlockSize := maxCompressedBlockSize
+	if s.windowSize < maxBlockSize {
+		maxBlockSize = s.windowSize
+	}
+	for i := range seqs {
+		var ll, mo, ml int
+		if br.off > 4+((maxOffsetBits+16+16)>>3) {
+			// inlined function:
+			// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)
+
+			// Final will not read from stream.
+			var llB, mlB, moB uint8
+			ll, llB = llState.final()
+			ml, mlB = mlState.final()
+			mo, moB = ofState.final()
+
+			// extra bits are stored in reverse order.
+			br.fillFast()
+			mo += br.getBits(moB)
+			if s.maxBits > 32 {
+				br.fillFast()
+			}
+			ml += br.getBits(mlB)
+			ll += br.getBits(llB)
+
+			if moB > 1 {
+				s.prevOffset[2] = s.prevOffset[1]
+				s.prevOffset[1] = s.prevOffset[0]
+				s.prevOffset[0] = mo
+			} else {
+				// mo = s.adjustOffset(mo, ll, moB)
+				// Inlined for rather big speedup
+				if ll == 0 {
+					// There is an exception though, when current sequence's literals_length = 0.
+					// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
+					// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
+					mo++
+				}
+
+				if mo == 0 {
+					mo = s.prevOffset[0]
+				} else {
+					var temp int
+					if mo == 3 {
+						temp = s.prevOffset[0] - 1
+					} else {
+						temp = s.prevOffset[mo]
+					}
+
+					if temp == 0 {
+						// 0 is not valid; input is corrupted; force offset to 1
+						println("WARNING: temp was 0")
+						temp = 1
+					}
+
+					if mo != 1 {
+						s.prevOffset[2] = s.prevOffset[1]
+					}
+					s.prevOffset[1] = s.prevOffset[0]
+					s.prevOffset[0] = temp
+					mo = temp
+				}
+			}
+			br.fillFast()
+		} else {
+			if br.overread() {
+				if debugDecoder {
+					printf("reading sequence %d, exceeded available data\n", i)
+				}
+				return io.ErrUnexpectedEOF
+			}
+			ll, mo, ml = s.next(br, llState, mlState, ofState)
+			br.fill()
+		}
+
+		if debugSequences {
+			println("Seq", i, "Litlen:", ll, "mo:", mo, "(abs) ml:", ml)
+		}
+		// Evaluate.
+		// We might be doing this async, so do it early.
+		if mo == 0 && ml > 0 {
+			return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
+		}
+		if ml > maxMatchLen {
+			return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
+		}
+		s.seqSize += ll + ml
+		if s.seqSize > maxBlockSize {
+			return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
+		}
+		litRemain -= ll
+		if litRemain < 0 {
+			return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, litRemain+ll)
+		}
+		seqs[i] = seqVals{
+			ll: ll,
+			ml: ml,
+			mo: mo,
+		}
+		if i == len(seqs)-1 {
+			// This is the last sequence, so we shouldn't update state.
+			break
+		}
+
+		// Manually inlined, ~ 5-20% faster
+		// Update all 3 states at once. Approx 20% faster.
+		nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits()
+		if nBits == 0 {
+			llState = llTable[llState.newState()&maxTableMask]
+			mlState = mlTable[mlState.newState()&maxTableMask]
+			ofState = ofTable[ofState.newState()&maxTableMask]
+		} else {
+			bits := br.get32BitsFast(nBits)
+			lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
+			llState = llTable[(llState.newState()+lowBits)&maxTableMask]
+
+			lowBits = uint16(bits >> (ofState.nbBits() & 31))
+			lowBits &= bitMask[mlState.nbBits()&15]
+			mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask]
+
+			lowBits = uint16(bits) & bitMask[ofState.nbBits()&15]
+			ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask]
+		}
+	}
+	s.seqSize += litRemain
+	if s.seqSize > maxBlockSize {
+		return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
+	}
+	err := br.close()
+	if err != nil {
+		printf("Closing sequences: %v, %+v\n", err, *br)
+	}
+	return err
+}
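Reviewer note on the `TESTQ`/`CMOVQEQ` pairs that guard the shift sequences in the generated assembly: x86 `SHRQ` masks its count to 6 bits, so for `n == 0` the intended shift by `64 - n = 64` becomes a shift by 0 and would leave garbage in the result instead of zero; the conditional move forces the result to zero when no bits were requested. In Go the same expression is already well defined, since a shift by 64 or more on a uint64 yields 0. A hedged sketch of the bit-extraction semantics (illustrative only, not the package's actual `getBits`):

```go
// getBitsSketch extracts the next n unread bits from the top of the buffer,
// matching the scalar sequence the generator emits:
//
//	(value << bitsRead) >> (64 - n)
//
// In Go a shift count of 64 or more yields 0, so n == 0 needs no special
// case; the assembly needs the CMOVQEQ guard because x86 masks shift
// counts mod 64.
func getBitsSketch(value uint64, bitsRead, n uint8) (bits uint64, newBitsRead uint8) {
	bits = (value << bitsRead) >> (64 - uint(n))
	return bits, bitsRead + n
}
```

The `updateLength` helpers then add this value to the baseline held in the high dword of the 64-bit state word, which is why both paths end with an `ADDQ` of the extracted bits into the shifted-down state register before storing the sequence value.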