diff --git a/zstd/_generate/gen.go b/zstd/_generate/gen.go index 23a26b21d4..c7fe02b305 100644 --- a/zstd/_generate/gen.go +++ b/zstd/_generate/gen.go @@ -80,6 +80,7 @@ func main() { decodeSync.generateProcedure("sequenceDecs_decodeSync_safe_amd64") decodeSync.setBMI2(true) decodeSync.generateProcedure("sequenceDecs_decodeSync_safe_bmi2") + Generate() } diff --git a/zstd/_generate/gen_fse.go b/zstd/_generate/gen_fse.go new file mode 100644 index 0000000000..95366e80e5 --- /dev/null +++ b/zstd/_generate/gen_fse.go @@ -0,0 +1,362 @@ +package main + +//go:generate go run gen_fse.go -out ../fse_decoder_amd64.s -pkg=zstd + +import ( + "flag" + + _ "github.com/klauspost/compress" + + . "github.com/mmcloughlin/avo/build" + "github.com/mmcloughlin/avo/buildtags" + . "github.com/mmcloughlin/avo/operand" + "github.com/mmcloughlin/avo/reg" +) + +func main() { + flag.Parse() + + Constraint(buildtags.Not("appengine").ToConstraint()) + Constraint(buildtags.Not("noasm").ToConstraint()) + Constraint(buildtags.Term("gc").ToConstraint()) + Constraint(buildtags.Not("noasm").ToConstraint()) + + buildDtable := buildDtable{} + buildDtable.generateProcedure("buildDtable_asm") + Generate() +} + +const ( + errorCorruptedNormalizedCounter = 1 + errorNewStateTooBig = 2 + errorNewStateNoBits = 3 +) + +type buildDtable struct { + bmi2 bool + + // values used across all methods + actualTableLog reg.GPVirtual + tableSize reg.GPVirtual + highThreshold reg.GPVirtual + symbolNext reg.GPVirtual // array []uint16 + dt reg.GPVirtual // array []uint64 +} + +func (b *buildDtable) generateProcedure(name string) { + Package("github.com/klauspost/compress/zstd") + TEXT(name, 0, "func (s *fseDecoder, ctx *buildDtableAsmContext ) int") + Doc(name+" implements fseDecoder.buildDtable in asm", "") + Pragma("noescape") + + ctx := Dereference(Param("ctx")) + s := Dereference(Param("s")) + + Comment("Load values") + { + // tableSize = (1 << s.actualTableLog) + b.tableSize = GP64() + b.actualTableLog = GP64() + 
Load(s.Field("actualTableLog"), b.actualTableLog) + XORQ(b.tableSize, b.tableSize) + BTSQ(b.actualTableLog, b.tableSize) + + // symbolNext = &s.stateTable[0] + b.symbolNext = GP64() + Load(ctx.Field("stateTable"), b.symbolNext) + + // dt = &s.dt[0] + b.dt = GP64() + Load(ctx.Field("dt"), b.dt) + + // highThreshold = tableSize - 1 + b.highThreshold = GP64() + LEAQ(Mem{Base: b.tableSize, Disp: -1}, b.highThreshold) + } + + norm := GP64() + Load(ctx.Field("norm"), norm) + + symbolLen := GP64() + Load(s.Field("symbolLen"), symbolLen) + Comment("End load values") + + b.init(norm, symbolLen) + b.spread(norm, symbolLen) + b.buildTable() + + b.returnCode(0) +} + +func (b *buildDtable) init(norm, symbolLen reg.GPVirtual) { + Comment("Init, lay down lowprob symbols") + /* + for i, v := range s.norm[:s.symbolLen] { + if v == -1 { + s.dt[highThreshold].setAddBits(uint8(i)) + highThreshold-- + symbolNext[i] = 1 + } else { + symbolNext[i] = uint16(v) + } + } + */ + + i := New64() + JMP(LabelRef("init_main_loop_condition")) + Label("init_main_loop") + + v := GP64() + MOVWQSX(Mem{Base: norm, Index: i, Scale: 2}, v) + + CMPW(v.As16(), I16(-1)) + JNE(LabelRef("do_not_update_high_threshold")) + + { + // s.dt[highThreshold].setAddBits(uint8(i)) + MOVB(i.As8(), Mem{Base: b.dt, Index: b.highThreshold, Scale: 8, Disp: 1}) // set highThreshold*8 + 1 byte + // highThreshold-- + DECQ(b.highThreshold) + + // symbolNext[i] = 1 + MOVQ(U64(1), v) + } + + Label("do_not_update_high_threshold") + { + // symbolNext[i] = uint16(v) + MOVW(v.As16(), Mem{Base: b.symbolNext, Index: i, Scale: 2}) + + INCQ(i) + Label("init_main_loop_condition") + CMPQ(i, symbolLen) + JL(LabelRef("init_main_loop")) + } + + Label("init_end") +} + +func (b *buildDtable) spread(norm, symbolLen reg.GPVirtual) { + Comment("Spread symbols") + /* + tableMask := tableSize - 1 + step := tableStep(tableSize) + position := uint32(0) + for ss, v := range s.norm[:s.symbolLen] { + for i := 0; i < int(v); i++ { + 
s.dt[position].setAddBits(uint8(ss)) + position = (position + step) & tableMask + for position > highThreshold { + // lowprob area + position = (position + step) & tableMask + } + } + } + */ + step := GP64() + Comment("Calculate table step") + { + // tmp1 = tableSize >> 1 + tmp1 := Copy64(b.tableSize) + SHRQ(U8(1), tmp1) + + // tmp3 = tableSize >> 3 + tmp3 := Copy64(b.tableSize) + SHRQ(U8(3), tmp3) + + // step = tmp1 + tmp3 + 3 + LEAQ(Mem{Base: tmp1, Index: tmp3, Scale: 1, Disp: 3}, step) + } + + Comment("Fill add bits values") + + // tableMask = tableSize - 1 (tableSize is a pow of 2) + tableMask := GP64() + LEAQ(Mem{Base: b.tableSize, Disp: -1}, tableMask) + + // position := 0 + position := New64() + + // ss := 0 + ss := New64() + JMP(LabelRef("spread_main_loop_condition")) + Label("spread_main_loop") + { + i := New64() + v := GP64() + MOVWQSX(Mem{Base: norm, Index: ss, Scale: 2}, v) + JMP(LabelRef("spread_inner_loop_condition")) + Label("spread_inner_loop") + + { + // s.dt[position].setAddBits(uint8(ss)) + MOVB(ss.As8(), Mem{Base: b.dt, Index: position, Scale: 8, Disp: 1}) + + Label("adjust_position") + // position = (position + step) & tableMask + ADDQ(step, position) + ANDQ(tableMask, position) + + // for position > highThreshold { + // // lowprob area + // position = (position + step) & tableMask + // } + CMPQ(position, b.highThreshold) + JG(LabelRef("adjust_position")) + } + INCQ(i) + Label("spread_inner_loop_condition") + CMPQ(i, v) + JL(LabelRef("spread_inner_loop")) + } + + INCQ(ss) + Label("spread_main_loop_condition") + CMPQ(ss, symbolLen) + JL(LabelRef("spread_main_loop")) + + /* + if position != 0 { + // position must reach all cells once, otherwise normalizedCounter is incorrect + return errors.New("corrupted input (position != 0)") + } + */ + TESTQ(position, position) + { + JZ(LabelRef("spread_check_ok")) + b.returnError(errorCorruptedNormalizedCounter, position) + } + Label("spread_check_ok") +} + +func (b *buildDtable) buildTable() { + 
Comment("Build Decoding table") + /* + tableSize := uint16(1 << s.actualTableLog) + for u, v := range s.dt[:tableSize] { + symbol := v.addBits() + nextState := symbolNext[symbol] + symbolNext[symbol] = nextState + 1 + nBits := s.actualTableLog - byte(highBits(uint32(nextState))) + s.dt[u&maxTableMask].setNBits(nBits) + newState := (nextState << nBits) - tableSize + if newState > tableSize { + return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) + } + if newState == uint16(u) && nBits == 0 { + // Seems weird that this is possible with nbits > 0. + return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) + } + s.dt[u&maxTableMask].setNewState(newState) + } + */ + u := New64() + Label("build_table_main_table") + { + // v := s.dt[u] + // symbol := v.addBits() + symbol := GP64() + MOVBQZX(Mem{Base: b.dt, Index: u, Scale: 8, Disp: 1}, symbol) + + // nextState := symbolNext[symbol] + nextState := GP64() + ptr := Mem{Base: b.symbolNext, Index: symbol, Scale: 2} + MOVWQZX(ptr, nextState) + + // symbolNext[symbol] = nextState + 1 + { + tmp := GP64() + LEAQ(Mem{Base: nextState, Disp: 1}, tmp) + MOVW(tmp.As16(), ptr) + } + + // nBits := s.actualTableLog - byte(highBits(uint32(nextState))) + nBits := reg.RCX // As we use nBits to shift + { + highBits := Copy64(nextState) + BSRQ(highBits, highBits) + + MOVQ(b.actualTableLog, nBits) + SUBQ(highBits, nBits) + } + + // newState := (nextState << nBits) - tableSize + newState := Copy64(nextState) + SHLQ(reg.CL, newState) + SUBQ(b.tableSize, newState) + + // s.dt[u&maxTableMask].setNBits(nBits) // sets byte #0 + // s.dt[u&maxTableMask].setNewState(newState) // sets word #1 (bytes #2 & #3) + { + MOVB(nBits.As8(), Mem{Base: b.dt, Index: u, Scale: 8}) + MOVW(newState.As16(), Mem{Base: b.dt, Index: u, Scale: 8, Disp: 2}) + } + + // if newState > tableSize { + // return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) + // } + { + CMPQ(newState, b.tableSize) + 
JLE(LabelRef("build_table_check1_ok")) + + b.returnError(errorNewStateTooBig, newState, b.tableSize) + Label("build_table_check1_ok") + } + + // if newState == uint16(u) && nBits == 0 { + // // Seems weird that this is possible with nbits > 0. + // return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) + // } + { + TESTB(nBits.As8(), nBits.As8()) + JNZ(LabelRef("build_table_check2_ok")) + CMPW(newState.As16(), u.As16()) + JNE(LabelRef("build_table_check2_ok")) + + b.returnError(errorNewStateNoBits, newState, u) + Label("build_table_check2_ok") + } + } + INCQ(u) + CMPQ(u, b.tableSize) + JL(LabelRef("build_table_main_table")) +} + +// returnCode sets function result and terminates the function. +func (b *buildDtable) returnCode(code int) { + a, err := ReturnIndex(0).Resolve() + if err != nil { + panic(err) + } + MOVQ(I32(code), a.Addr) + RET() +} + +// returnError sets error params and terminates function with given exit code. +func (b *buildDtable) returnError(code int, args ...reg.GPVirtual) { + ctx := Dereference(Param("ctx")) + + if len(args) > 0 { + Store(args[0], ctx.Field("errParam1")) + } + + if len(args) > 1 { + Store(args[1], ctx.Field("errParam2")) + } + + b.returnCode(code) +} + +func New64() reg.GPVirtual { + cnt := GP64() + XORQ(cnt, cnt) + + return cnt +} + +func Copy64(val reg.GPVirtual) reg.GPVirtual { + tmp := GP64() + MOVQ(val, tmp) + + return tmp +} diff --git a/zstd/fse_decoder.go b/zstd/fse_decoder.go index 23333b9692..2f8860a722 100644 --- a/zstd/fse_decoder.go +++ b/zstd/fse_decoder.go @@ -180,7 +180,6 @@ func (s *fseDecoder) readNCount(b *byteReader, maxSymbol uint16) error { return fmt.Errorf("corruption detected (total %d != %d)", gotTotal, 1<> 3) - // println(s.norm[:s.symbolLen], s.symbolLen) return s.buildDtable() } @@ -269,68 +268,6 @@ func (s *fseDecoder) setRLE(symbol decSymbol) { s.dt[0] = symbol } -// buildDtable will build the decoding table. 
-func (s *fseDecoder) buildDtable() error { - tableSize := uint32(1 << s.actualTableLog) - highThreshold := tableSize - 1 - symbolNext := s.stateTable[:256] - - // Init, lay down lowprob symbols - { - for i, v := range s.norm[:s.symbolLen] { - if v == -1 { - s.dt[highThreshold].setAddBits(uint8(i)) - highThreshold-- - symbolNext[i] = 1 - } else { - symbolNext[i] = uint16(v) - } - } - } - // Spread symbols - { - tableMask := tableSize - 1 - step := tableStep(tableSize) - position := uint32(0) - for ss, v := range s.norm[:s.symbolLen] { - for i := 0; i < int(v); i++ { - s.dt[position].setAddBits(uint8(ss)) - position = (position + step) & tableMask - for position > highThreshold { - // lowprob area - position = (position + step) & tableMask - } - } - } - if position != 0 { - // position must reach all cells once, otherwise normalizedCounter is incorrect - return errors.New("corrupted input (position != 0)") - } - } - - // Build Decoding table - { - tableSize := uint16(1 << s.actualTableLog) - for u, v := range s.dt[:tableSize] { - symbol := v.addBits() - nextState := symbolNext[symbol] - symbolNext[symbol] = nextState + 1 - nBits := s.actualTableLog - byte(highBits(uint32(nextState))) - s.dt[u&maxTableMask].setNBits(nBits) - newState := (nextState << nBits) - tableSize - if newState > tableSize { - return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) - } - if newState == uint16(u) && nBits == 0 { - // Seems weird that this is possible with nbits > 0. - return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) - } - s.dt[u&maxTableMask].setNewState(newState) - } - } - return nil -} - // transform will transform the decoder table into a table usable for // decoding without having to apply the transformation while decoding. // The state will contain the base value and the number of bits to read. 
diff --git a/zstd/fse_decoder_amd64.go b/zstd/fse_decoder_amd64.go
new file mode 100644
index 0000000000..e74df436cf
--- /dev/null
+++ b/zstd/fse_decoder_amd64.go
@@ -0,0 +1,64 @@
+//go:build amd64 && !appengine && !noasm && gc
+// +build amd64,!appengine,!noasm,gc
+
+package zstd
+
+import (
+	"fmt"
+)
+
+type buildDtableAsmContext struct {
+	// inputs
+	stateTable *uint16
+	norm       *int16
+	dt         *uint64
+
+	// outputs --- set by the procedure in the case of error;
+	// for interpretation please see the error handling part below
+	errParam1 uint64
+	errParam2 uint64
+}
+
+// buildDtable_asm is an x86 assembly implementation of fseDecoder.buildDtable.
+// Function returns non-zero exit code on error.
+//
+//go:noescape
+func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) int
+
+// please keep in sync with _generate/gen_fse.go
+const (
+	errorCorruptedNormalizedCounter = 1
+	errorNewStateTooBig             = 2
+	errorNewStateNoBits             = 3
+)
+
+// buildDtable will build the decoding table.
+func (s *fseDecoder) buildDtable() error {
+	ctx := buildDtableAsmContext{
+		stateTable: (*uint16)(&s.stateTable[0]),
+		norm:       (*int16)(&s.norm[0]),
+		dt:         (*uint64)(&s.dt[0]),
+	}
+	code := buildDtable_asm(s, &ctx)
+
+	if code != 0 {
+		switch code {
+		case errorCorruptedNormalizedCounter:
+			position := ctx.errParam1
+			return fmt.Errorf("corrupted input (position=%d, expected 0)", position)
+
+		case errorNewStateTooBig:
+			newState := decSymbol(ctx.errParam1)
+			size := ctx.errParam2
+			return fmt.Errorf("newState (%d) outside table size (%d)", newState, size)
+
+		case errorNewStateNoBits:
+			newState := decSymbol(ctx.errParam1)
+			oldState := decSymbol(ctx.errParam2)
+			return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, oldState)
+
+		default:
+			return fmt.Errorf("buildDtable_asm returned unhandled nonzero code = %d", code)
+		}
+	}
+	return nil
+}
diff --git a/zstd/fse_decoder_amd64.s b/zstd/fse_decoder_amd64.s
new file mode 100644
index 0000000000..da32b4420e
--- /dev/null
+++ b/zstd/fse_decoder_amd64.s @@ -0,0 +1,127 @@ +// Code generated by command: go run gen_fse.go -out ../fse_decoder_amd64.s -pkg=zstd. DO NOT EDIT. + +//go:build !appengine && !noasm && gc && !noasm +// +build !appengine,!noasm,gc,!noasm + +// func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) int +TEXT ·buildDtable_asm(SB), $0-24 + MOVQ ctx+8(FP), CX + MOVQ s+0(FP), DI + + // Load values + MOVBQZX 4098(DI), DX + XORQ AX, AX + BTSQ DX, AX + MOVQ (CX), BX + MOVQ 16(CX), SI + LEAQ -1(AX), R8 + MOVQ 8(CX), CX + MOVWQZX 4096(DI), DI + + // End load values + // Init, lay down lowprob symbols + XORQ R9, R9 + JMP init_main_loop_condition + +init_main_loop: + MOVWQSX (CX)(R9*2), R10 + CMPW R10, $-1 + JNE do_not_update_high_threshold + MOVB R9, 1(SI)(R8*8) + DECQ R8 + MOVQ $0x0000000000000001, R10 + +do_not_update_high_threshold: + MOVW R10, (BX)(R9*2) + INCQ R9 + +init_main_loop_condition: + CMPQ R9, DI + JL init_main_loop + + // Spread symbols + // Calculate table step + MOVQ AX, R9 + SHRQ $0x01, R9 + MOVQ AX, R10 + SHRQ $0x03, R10 + LEAQ 3(R9)(R10*1), R9 + + // Fill add bits values + LEAQ -1(AX), R10 + XORQ R11, R11 + XORQ R12, R12 + JMP spread_main_loop_condition + +spread_main_loop: + XORQ R13, R13 + MOVWQSX (CX)(R12*2), R14 + JMP spread_inner_loop_condition + +spread_inner_loop: + MOVB R12, 1(SI)(R11*8) + +adjust_position: + ADDQ R9, R11 + ANDQ R10, R11 + CMPQ R11, R8 + JG adjust_position + INCQ R13 + +spread_inner_loop_condition: + CMPQ R13, R14 + JL spread_inner_loop + INCQ R12 + +spread_main_loop_condition: + CMPQ R12, DI + JL spread_main_loop + TESTQ R11, R11 + JZ spread_check_ok + MOVQ ctx+8(FP), AX + MOVQ R11, 24(AX) + MOVQ $+1, ret+16(FP) + RET + +spread_check_ok: + // Build Decoding table + XORQ DI, DI + +build_table_main_table: + MOVBQZX 1(SI)(DI*8), CX + MOVWQZX (BX)(CX*2), R8 + LEAQ 1(R8), R9 + MOVW R9, (BX)(CX*2) + MOVQ R8, R9 + BSRQ R9, R9 + MOVQ DX, CX + SUBQ R9, CX + SHLQ CL, R8 + SUBQ AX, R8 + MOVB CL, (SI)(DI*8) + MOVW R8, 2(SI)(DI*8) + 
CMPQ R8, AX + JLE build_table_check1_ok + MOVQ ctx+8(FP), CX + MOVQ R8, 24(CX) + MOVQ AX, 32(CX) + MOVQ $+2, ret+16(FP) + RET + +build_table_check1_ok: + TESTB CL, CL + JNZ build_table_check2_ok + CMPW R8, DI + JNE build_table_check2_ok + MOVQ ctx+8(FP), AX + MOVQ R8, 24(AX) + MOVQ DI, 32(AX) + MOVQ $+3, ret+16(FP) + RET + +build_table_check2_ok: + INCQ DI + CMPQ DI, AX + JL build_table_main_table + MOVQ $+0, ret+16(FP) + RET diff --git a/zstd/fse_decoder_generic.go b/zstd/fse_decoder_generic.go new file mode 100644 index 0000000000..332e51fe44 --- /dev/null +++ b/zstd/fse_decoder_generic.go @@ -0,0 +1,72 @@ +//go:build !amd64 || appengine || !gc || noasm +// +build !amd64 appengine !gc noasm + +package zstd + +import ( + "errors" + "fmt" +) + +// buildDtable will build the decoding table. +func (s *fseDecoder) buildDtable() error { + tableSize := uint32(1 << s.actualTableLog) + highThreshold := tableSize - 1 + symbolNext := s.stateTable[:256] + + // Init, lay down lowprob symbols + { + for i, v := range s.norm[:s.symbolLen] { + if v == -1 { + s.dt[highThreshold].setAddBits(uint8(i)) + highThreshold-- + symbolNext[i] = 1 + } else { + symbolNext[i] = uint16(v) + } + } + } + + // Spread symbols + { + tableMask := tableSize - 1 + step := tableStep(tableSize) + position := uint32(0) + for ss, v := range s.norm[:s.symbolLen] { + for i := 0; i < int(v); i++ { + s.dt[position].setAddBits(uint8(ss)) + position = (position + step) & tableMask + for position > highThreshold { + // lowprob area + position = (position + step) & tableMask + } + } + } + if position != 0 { + // position must reach all cells once, otherwise normalizedCounter is incorrect + return errors.New("corrupted input (position != 0)") + } + } + + // Build Decoding table + { + tableSize := uint16(1 << s.actualTableLog) + for u, v := range s.dt[:tableSize] { + symbol := v.addBits() + nextState := symbolNext[symbol] + symbolNext[symbol] = nextState + 1 + nBits := s.actualTableLog - 
byte(highBits(uint32(nextState))) + s.dt[u&maxTableMask].setNBits(nBits) + newState := (nextState << nBits) - tableSize + if newState > tableSize { + return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) + } + if newState == uint16(u) && nBits == 0 { + // Seems weird that this is possible with nbits > 0. + return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) + } + s.dt[u&maxTableMask].setNewState(newState) + } + } + return nil +}