From f8322cd2914757277018aa20318ca08ed5b157ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Mu=C5=82a?= Date: Fri, 17 Jun 2022 14:45:17 +0200 Subject: [PATCH] [skip ci] zstd: translate fseDecoder.buildDtable into asm --- zstd/_generate/gen.go | 1 + zstd/_generate/gen_fse.go | 425 ++++++++++++++++++++++++++++++++++++++ zstd/fse_decoder.go | 72 ++++++- zstd/fse_decoder_amd64.s | 136 ++++++++++++ 4 files changed, 633 insertions(+), 1 deletion(-) create mode 100644 zstd/_generate/gen_fse.go create mode 100644 zstd/fse_decoder_amd64.s diff --git a/zstd/_generate/gen.go b/zstd/_generate/gen.go index 23a26b21d4..c7fe02b305 100644 --- a/zstd/_generate/gen.go +++ b/zstd/_generate/gen.go @@ -80,6 +80,7 @@ func main() { decodeSync.generateProcedure("sequenceDecs_decodeSync_safe_amd64") decodeSync.setBMI2(true) decodeSync.generateProcedure("sequenceDecs_decodeSync_safe_bmi2") + Generate() } diff --git a/zstd/_generate/gen_fse.go b/zstd/_generate/gen_fse.go new file mode 100644 index 0000000000..365f001498 --- /dev/null +++ b/zstd/_generate/gen_fse.go @@ -0,0 +1,425 @@ +package main + +//go:generate go run gen_fse.go -out ../fse_decoder_amd64.s -pkg=zstd + +import ( + "flag" + "fmt" + + _ "github.com/klauspost/compress" + + . "github.com/mmcloughlin/avo/build" + "github.com/mmcloughlin/avo/buildtags" + . "github.com/mmcloughlin/avo/operand" + "github.com/mmcloughlin/avo/reg" +) + +func main() { + flag.Parse() + + Constraint(buildtags.Not("appengine").ToConstraint()) + Constraint(buildtags.Not("noasm").ToConstraint()) + Constraint(buildtags.Term("gc").ToConstraint()) + Constraint(buildtags.Not("noasm").ToConstraint()) + + buildDtable := buildDtable{} + buildDtable.generateProcedure("buildDtable_asm") + Generate() +} + +const errorCorruptedNormalizedCounter = 1 +const errorNewStateTooBig = 2 +const errorNewStateNoBits = 3 + +type buildDtable struct { + bmi2 bool + + // values used across all methods + actualTableLog reg.GPVirtual + tableSize reg.GPVirtual + highThreshold reg.GPVirtual + symbolNext reg.GPVirtual // array []uint16 + dt reg.GPVirtual // array []uint64 + + // return values + errParam1 Mem + errParam2 Mem + + dumpId int +} + +func (b *buildDtable) generateProcedure(name string) { + Package("github.com/klauspost/compress/zstd") + TEXT(name, 0, "func (s *fseDecoder, ctx *buildDtableAsmContext ) (int, uint64, uint64)") + Doc(name+" implements fseDecoder.buildDtable in asm", "") + Pragma("noescape") + + { + param1, err := ReturnIndex(1).Resolve() + if err != nil { + panic(err) + } + + param2, err := ReturnIndex(2).Resolve() + if err != nil { + panic(err) + } + + b.errParam1 = param1.Addr + b.errParam2 = param2.Addr + } + + ctx := Dereference(Param("ctx")) + s := Dereference(Param("s")) + + Comment("Load values") + { + // tableSize = (1 << s.actualTableLog) + b.tableSize = GP64() + b.actualTableLog = GP64() + Load(s.Field("actualTableLog"), b.actualTableLog) + XORQ(b.tableSize, b.tableSize) + BTSQ(b.actualTableLog, b.tableSize) + + // symbolNext = &s.stateTable[0] + b.symbolNext = GP64() + Load(ctx.Field("stateTable"), b.symbolNext) + + // dt = &s.dt[0] + b.dt = GP64() + Load(ctx.Field("dt"), b.dt) + + // highThreshold = tableSize - 1 + b.highThreshold = GP64() + LEAQ(Mem{Base: b.tableSize, Disp: -1}, b.highThreshold) + } + Comment("End load values") + + norm := GP64() + Load(ctx.Field("norm"), norm) + + symbolLen := GP64() + Load(s.Field("symbolLen"), symbolLen) + + b.init(norm, symbolLen) + + b.spread(norm, symbolLen) + b.buildTable() + + returnCode := func(code int) { + a, err := ReturnIndex(0).Resolve() + if err != nil { + panic(err) + } + MOVQ(I32(code), a.Addr) + } + + returnCode(0) + RET() + + Label("error_corrupted_normalized_counter") + returnCode(errorCorruptedNormalizedCounter) + RET() + + Label("error_new_state_too_big") + returnCode(errorNewStateTooBig) + RET() + + Label("error_new_state_no_bits") + returnCode(errorNewStateNoBits) + RET() +} + +func (b *buildDtable) init(norm, symbolLen reg.GPVirtual) { + Comment("Init, lay down lowprob symbols") + /* + for i, v := range s.norm[:s.symbolLen] { + if v == -1 { + s.dt[highThreshold].setAddBits(uint8(i)) + highThreshold-- + symbolNext[i] = 1 + } else { + symbolNext[i] = uint16(v) + } + } + */ + + i := New64() + JMP(LabelRef("init_main_loop_condition")) + Label("init_main_loop") + + v := GP64() + MOVWQSX(Mem{Base: norm, Index: i, Scale: 2}, v) + + CMPW(v.As16(), I16(-1)) + JNE(LabelRef("do_not_update_high_threshold")) + + { + // s.dt[highThreshold].setAddBits(uint8(i)) + MOVB(i.As8(), Mem{Base: b.dt, Index: b.highThreshold, Scale: 8, Disp: 1}) // set highThreshold*8 + 1 byte + // highThreshold-- + DECQ(b.highThreshold) + + // symbolNext[i] = 1 + MOVQ(U64(1), v) + } + + Label("do_not_update_high_threshold") + { + // symbolNext[i] = uint16(v) + MOVW(v.As16(), Mem{Base: b.symbolNext, Index: i, Scale: 2}) + + INCQ(i) + Label("init_main_loop_condition") + CMPQ(i, symbolLen) + JL(LabelRef("init_main_loop")) + } + + Label("init_end") +} + +func (b *buildDtable) spread(norm, symbolLen reg.GPVirtual) { + Comment("Spread symbols") + /* + tableMask := tableSize - 1 + step := tableStep(tableSize) + position := uint32(0) + for ss, v := range s.norm[:s.symbolLen] { + for i := 0; i < int(v); i++ { + s.dt[position].setAddBits(uint8(ss)) + position = (position + step) & tableMask + for position > highThreshold { + // lowprob area + position = (position + step) & tableMask + } + } + } + */ + step := GP64() + Comment("Calculate table step") + { + // tmp1 = tableSize >> 1 + tmp1 := Copy64(b.tableSize) + SHRQ(U8(1), tmp1) + + // tmp3 = tableSize >> 3 + tmp3 := Copy64(b.tableSize) + SHRQ(U8(3), tmp3) + + // step = tmp1 + tmp3 + 3 + LEAQ(Mem{Base: tmp1, Index: tmp3, Scale: 1, Disp: 3}, step) + } + + Comment("Fill add bits values") + + // tableMask = tableSize - 1 (tableSize is a pow of 2) + tableMask := GP64() + LEAQ(Mem{Base: b.tableSize, Disp: -1}, tableMask) + + // position := 0 + position := New64() + + // ss := 0 + ss := New64() + JMP(LabelRef("spread_main_loop_condition")) + Label("spread_main_loop") + { + i := New64() + v := GP64() + MOVWQSX(Mem{Base: norm, Index: ss, Scale: 2}, v) + JMP(LabelRef("spread_inner_loop_condition")) + Label("spread_inner_loop") + + { + // s.dt[position].setAddBits(uint8(ss)) + MOVB(ss.As8(), Mem{Base: b.dt, Index: position, Scale: 8, Disp: 1}) + + Label("adjust_position") + // position = (position + step) & tableMask + ADDQ(step, position) + ANDQ(tableMask, position) + + // for position > highThreshold { + // // lowprob area + // position = (position + step) & tableMask + // } + CMPQ(position, b.highThreshold) + JG(LabelRef("adjust_position")) + } + INCQ(i) + Label("spread_inner_loop_condition") + CMPQ(i, v) + JL(LabelRef("spread_inner_loop")) + } + + INCQ(ss) + Label("spread_main_loop_condition") + CMPQ(ss, symbolLen) + JLE(LabelRef("spread_main_loop")) + + /* + if position != 0 { + // position must reach all cells once, otherwise normalizedCounter is incorrect + return errors.New("corrupted input (position != 0)") + } + */ + TESTQ(position, position) + { + JZ(LabelRef("spread_check_ok")) + MOVQ(position, b.errParam1) + JMP(LabelRef("error_corrupted_normalized_counter")) + } + Label("spread_check_ok") +} + +func (b *buildDtable) buildTable() { + Comment("Build Decoding table") + /* + tableSize := uint16(1 << s.actualTableLog) + for u, v := range s.dt[:tableSize] { + symbol := v.addBits() + nextState := symbolNext[symbol] + symbolNext[symbol] = nextState + 1 + nBits := s.actualTableLog - byte(highBits(uint32(nextState))) + s.dt[u&maxTableMask].setNBits(nBits) + newState := (nextState << nBits) - tableSize + if newState > tableSize { + return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) + } + if newState == uint16(u) && nBits == 0 { + // Seems weird that this is possible with nbits > 0. + return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) + } + s.dt[u&maxTableMask].setNewState(newState) + } + */ + u := New64() + Label("build_table_main_table") + { + // v := s.dt[u] + v := GP64() + MOVQ(Mem{Base: b.dt, Index: u, Scale: 8}, v) + + // symbol := v.addBits() + symbol := GP64() + MOVBQZX(Mem{Base: b.dt, Index: u, Scale: 8, Disp: 1}, symbol) + + // nextState := symbolNext[symbol] + nextState := GP64() + ptr := Mem{Base: b.symbolNext, Index: symbol, Scale: 2} + MOVWQZX(ptr, nextState) + + // symbolNext[symbol] = nextState + 1 + { + tmp := GP64() + LEAQ(Mem{Base: nextState, Disp: 1}, tmp) + MOVW(tmp.As16(), ptr) + } + + // nBits := s.actualTableLog - byte(highBits(uint32(nextState))) + nBits := reg.RCX // As we use nBits to shift + { + highBits := GP64() + MOVWQZX(nextState.As16(), highBits) + BSRQ(highBits, highBits) + + MOVQ(b.actualTableLog, nBits) + SUBQ(highBits, nBits) + } + + // newState := (nextState << nBits) - tableSize + newState := GP64() + MOVQ(nextState, newState) + SHLQ(reg.CL, newState) + SUBQ(b.tableSize, newState) + + { + tmp := GP64() + MOVQ(nBits, tmp) + } + + // s.dt[u&maxTableMask].setNBits(nBits) // sets byte #0 + // s.dt[u&maxTableMask].setNewState(newState) // sets word #1 (bytes #2 & #3) + { + MOVB(nBits.As8(), Mem{Base: b.dt, Index: u, Scale: 8}) + MOVW(newState.As16(), Mem{Base: b.dt, Index: u, Scale: 8, Disp: 2}) + } + + // if newState > tableSize { + // return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) + // } + { + CMPQ(newState, b.tableSize) + JLE(LabelRef("build_table_check1_ok")) + + MOVQ(newState, b.errParam1) + MOVQ(b.tableSize, b.errParam2) + JMP(LabelRef("error_new_state_too_big")) + Label("build_table_check1_ok") + } + + // if newState == uint16(u) && nBits == 0 { + // // Seems weird that this is possible with nbits > 0. + // return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) + // } + { + TESTB(nBits.As8(), nBits.As8()) + JNZ(LabelRef("build_table_check2_ok")) + CMPW(newState.As16(), u.As16()) + JNE(LabelRef("build_table_check2_ok")) + MOVQ(newState, b.errParam1) + MOVQ(u, b.errParam2) + JMP(LabelRef("error_new_state_no_bits")) + Label("build_table_check2_ok") + } + } + INCQ(u) + CMPQ(u, b.tableSize) + JL(LabelRef("build_table_main_table")) +} + +func (b *buildDtable) dump(args ...reg.GPVirtual) { + Comment("dump START") + ctx := Dereference(Param("ctx")) + dumpIdx := GP64() + Load(ctx.Field("dumpIdx"), dumpIdx) + + length := GP64() + Load(ctx.Field("dump").Len(), length) + CMPQ(dumpIdx, length) + + label := fmt.Sprintf("dump_skip_%d", b.dumpId) + b.dumpId += 1 + + JGE(LabelRef(label)) + + ptr := GP64() + Load(ctx.Field("dump").Base(), ptr) + + if len(args) > 3 { + panic("up to 3 args for dump") + } + + for i, v := range args { + MOVQ(v, Mem{Base: ptr, Index: dumpIdx, Scale: 8, Disp: i * 8}) + } + + ADDQ(U8(3), dumpIdx) + Store(dumpIdx, ctx.Field("dumpIdx")) + Label(label) + Comment("dump END") +} + +func New64() reg.GPVirtual { + cnt := GP64() + XORQ(cnt, cnt) + + return cnt +} + +func Copy64(val reg.GPVirtual) reg.GPVirtual { + tmp := GP64() + MOVQ(val, tmp) + + return tmp +} diff --git a/zstd/fse_decoder.go b/zstd/fse_decoder.go index 23333b9692..28ee832008 100644 --- a/zstd/fse_decoder.go +++ b/zstd/fse_decoder.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "os" ) const ( @@ -180,7 +181,6 @@ func (s *fseDecoder) readNCount(b *byteReader, maxSymbol uint16) error { return fmt.Errorf("corruption detected (total %d != %d)", gotTotal, 1<> 3) - // println(s.norm[:s.symbolLen], s.symbolLen) return s.buildDtable() } @@ -269,12 +269,81 @@ func (s *fseDecoder) setRLE(symbol decSymbol) { s.dt[0] = symbol } +type buildDtableAsmContext struct { + stateTable *uint16 + norm *int16 + dt *uint64 + dtLength int + + dumpIdx int + dump []uint64 +} + +// go:noescape +func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) (int, uint64, uint64) + +const errorCorruptedNormalizedCounter = 1 +const errorNewStateTooBig = 2 +const errorNewStateNoBits = 3 + +var buildDtableCount int + // buildDtable will build the decoding table. func (s *fseDecoder) buildDtable() error { tableSize := uint32(1 << s.actualTableLog) highThreshold := tableSize - 1 symbolNext := s.stateTable[:256] + useAsm := (os.Getenv("ASM") == "1") + + if useAsm { + ctx := buildDtableAsmContext{ + stateTable: (*uint16)(&s.stateTable[0]), + norm: (*int16)(&s.norm[0]), + dt: (*uint64)(&s.dt[0]), + dump: make([]uint64, 10000*3), // for debugging: remove + } + code, errParam1, errParam2 := buildDtable_asm(s, &ctx) + + // XXX: for debugging, remove + if false { + for i := 0; i < ctx.dumpIdx; i += 3 { + v0 := ctx.dump[i] + v1 := ctx.dump[i+1] + v2 := ctx.dump[i+2] + + _ = v0 + _ = v1 + _ = v2 + + fmt.Printf("v0=%d, v1=%d, v2=%d\n", v0, v1, v2) + } + } + + if code != 0 { + switch code { + case errorCorruptedNormalizedCounter: + position := errParam1 + return fmt.Errorf("corrupted input (position=%d, expected 0)", position) + + case errorNewStateTooBig: + newState := decSymbol(errParam1) + size := errParam2 + return fmt.Errorf("newState (%d) outside table size (%d)", newState, size) + + case errorNewStateNoBits: + newState := decSymbol(errParam1) + oldState := decSymbol(errParam2) + return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, oldState) + + default: + return fmt.Errorf("buildDtable_asm returned unhandled nonzero code = %d", code) + } + + } + return nil + } + // Init, lay down lowprob symbols { for i, v := range s.norm[:s.symbolLen] { @@ -287,6 +356,7 @@ func (s *fseDecoder) buildDtable() error { } } } + // Spread symbols { tableMask := tableSize - 1 diff --git a/zstd/fse_decoder_amd64.s b/zstd/fse_decoder_amd64.s new file mode 100644 index 0000000000..26ee9a74b2 --- /dev/null +++ b/zstd/fse_decoder_amd64.s @@ -0,0 +1,136 @@ +// Code generated by command: go run gen_fse.go. DO NOT EDIT. + +//go:build !appengine && !noasm && gc && !noasm +// +build !appengine,!noasm,gc,!noasm + +// func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) (int, uint64, uint64) +TEXT ·buildDtable_asm(SB), $0-40 + MOVQ ctx+8(FP), CX + MOVQ s+0(FP), DI + + // Load values + MOVBQZX 4098(DI), DX + XORQ AX, AX + BTSQ DX, AX + MOVQ (CX), BX + MOVQ 16(CX), SI + LEAQ -1(AX), R8 + + // End load values + MOVQ 8(CX), CX + MOVWQZX 4096(DI), DI + + // Init, lay down lowprob symbols + XORQ R9, R9 + JMP init_main_loop_condition + +init_main_loop: + MOVWQSX (CX)(R9*2), R10 + CMPW R10, $-1 + JNE do_not_update_high_threshold + MOVB R9, 1(SI)(R8*8) + DECQ R8 + MOVQ $0x0000000000000001, R10 + +do_not_update_high_threshold: + MOVW R10, (BX)(R9*2) + INCQ R9 + +init_main_loop_condition: + CMPQ R9, DI + JL init_main_loop + + // Spread symbols + // Calculate table step + MOVQ AX, R9 + SHRQ $0x01, R9 + MOVQ AX, R10 + SHRQ $0x03, R10 + LEAQ 3(R9)(R10*1), R9 + + // Fill add bits values + LEAQ -1(AX), R10 + XORQ R11, R11 + XORQ R12, R12 + JMP spread_main_loop_condition + +spread_main_loop: + XORQ R13, R13 + MOVWQSX (CX)(R12*2), R14 + JMP spread_inner_loop_condition + +spread_inner_loop: + MOVB R12, 1(SI)(R11*8) + +adjust_position: + ADDQ R9, R11 + ANDQ R10, R11 + CMPQ R11, R8 + JG adjust_position + INCQ R13 + +spread_inner_loop_condition: + CMPQ R13, R14 + JL spread_inner_loop + INCQ R12 + +spread_main_loop_condition: + CMPQ R12, DI + JLE spread_main_loop + TESTQ R11, R11 + JZ spread_check_ok + MOVQ R11, ret1+24(FP) + JMP error_corrupted_normalized_counter + +spread_check_ok: + // Build Decoding table + XORQ DI, DI + +build_table_main_table: + MOVQ (SI)(DI*8), CX + MOVBQZX 1(SI)(DI*8), CX + MOVWQZX (BX)(CX*2), R8 + LEAQ 1(R8), R9 + MOVW R9, (BX)(CX*2) + MOVWQZX R8, R9 + BSRQ R9, R9 + MOVQ DX, CX + SUBQ R9, CX + SHLQ CL, R8 + SUBQ AX, R8 + MOVQ CX, R9 + MOVB CL, (SI)(DI*8) + MOVW R8, 2(SI)(DI*8) + CMPQ R8, AX + JLE build_table_check1_ok + MOVQ R8, ret1+24(FP) + MOVQ AX, ret2+32(FP) + JMP error_new_state_too_big + +build_table_check1_ok: + TESTB CL, CL + JNZ build_table_check2_ok + CMPW R8, DI + JNE build_table_check2_ok + MOVQ R8, ret1+24(FP) + MOVQ DI, ret2+32(FP) + JMP error_new_state_no_bits + +build_table_check2_ok: + INCQ DI + CMPQ DI, AX + JL build_table_main_table + MOVQ $+0, ret+16(FP) + RET + +error_corrupted_normalized_counter: + MOVQ $+1, ret+16(FP) + RET + +error_new_state_too_big: + MOVQ $+2, ret+16(FP) + RET + +error_new_state_no_bits: + MOVQ $+3, ret+16(FP) + RET