Skip to content

Commit

Permalink
[skip ci] zstd: translate fseDecoder.buildDtable into asm
Browse files Browse the repository at this point in the history
  • Loading branch information
WojciechMula committed May 20, 2022
1 parent 3909335 commit 07789a8
Show file tree
Hide file tree
Showing 2 changed files with 342 additions and 0 deletions.
339 changes: 339 additions & 0 deletions zstd/_generate/gen.go
Expand Up @@ -80,6 +80,11 @@ func main() {
decodeSync.generateProcedure("sequenceDecs_decodeSync_safe_amd64")
decodeSync.setBMI2(true)
decodeSync.generateProcedure("sequenceDecs_decodeSync_safe_bmi2")

if false {
buildDtable := buildDtable{}
buildDtable.generateProcedure("buildDtable_asm")
}
Generate()
}

Expand Down Expand Up @@ -1401,3 +1406,337 @@ func (d *decodeSync) generateProcedure(name string) {

d.decode.generateBody(name, d.execute.executeSingleTriple)
}

// ------------------------------------------------------------------------

// Non-zero status codes stored in the first result of the generated
// buildDtable routine; the success path stores 0.
const (
	// errorCorruptedNormalizedCounter: spreading did not end at position 0.
	errorCorruptedNormalizedCounter = 1 + iota
	// errorNewStateTooBig: a computed newState exceeded the table size.
	errorNewStateTooBig
	// errorNewStateNoBits: newState equals the current state with nBits == 0.
	errorNewStateNoBits
)

// buildDtable carries the generator state used while emitting the assembly
// version of fseDecoder.buildDtable.
type buildDtable struct {
	// bmi2 is not consulted by the generation steps visible in this section.
	bmi2 bool

	// Registers populated by generateProcedure and shared by all steps.
	//
	// tableSize holds 1 << actualTableLog; highThreshold starts at
	// tableSize - 1 and is decremented by the init step.
	actualTableLog, tableSize, highThreshold reg.GPVirtual

	symbolNext reg.GPVirtual // array []uint16
	dt         reg.GPVirtual // array []uint32
}

// generateProcedure emits the assembly routine called name, which mirrors
// fseDecoder.buildDtable.
//
// The emitted function takes the decoder pointer s and returns three values:
// a status code (0 on success, otherwise one of the error* constants) followed
// by two detail values that the failing check fills in before jumping to its
// error label.
func (b *buildDtable) generateProcedure(name string) {
	Package("github.com/klauspost/compress/zstd")
	TEXT(name, 0, "func (s *fseDecoder) (int, uint64, uint64)")
	Doc(name+" implements fseDecoder.buildDtable in asm", "")
	Pragma("noescape")

	s := Dereference(Param("s"))

	Comment("Load values")
	{
		// tableSize = (1 << s.actualTableLog)
		b.tableSize = GP64()

		actualTableLog := GP64()
		Load(s.Field("actualTableLog"), actualTableLog)
		XORQ(b.tableSize, b.tableSize)
		BTSQ(actualTableLog, b.tableSize)
		b.actualTableLog = GP64()
		MOVQ(actualTableLog, b.actualTableLog)

		// symbolNext = &s.stateTable[0]
		b.symbolNext = GP64()
		Load(s.Field("stateTable"), b.symbolNext)

		// dt = &s.dt[0]
		b.dt = GP64()
		Load(s.Field("dt"), b.dt)

		// highThreshold = tableSize - 1
		b.highThreshold = GP64()
		LEAQ(Mem{Base: b.tableSize, Disp: -1}, b.highThreshold)
	}

	norm := GP64()
	Load(s.Field("norm"), norm)

	// The loop bound is the number of symbols in use, which lives in its own
	// field; it must not be read from the norm field.
	symbolLen := GP64()
	Load(s.Field("symbolLen"), symbolLen)

	b.init(norm, symbolLen)
	b.spread(norm, symbolLen)
	b.buildTable()

	// returnCode stores code into the first (int) result slot.
	returnCode := func(code int) {
		a, err := ReturnIndex(0).Resolve()
		if err != nil {
			panic(err)
		}
		MOVQ(I32(code), a.Addr)
	}

	// Success: all three steps fell through without jumping to an error label.
	returnCode(0)
	RET()

	// Error exits. The steps above jump here; buildTable has already stored
	// the detail values in results 1 and 2 where applicable.
	Label("error_corrupted_normalized_counter")
	returnCode(errorCorruptedNormalizedCounter)
	RET()

	Label("error_new_state_too_big")
	returnCode(errorNewStateTooBig)
	RET()

	Label("error_new_state_no_bits")
	returnCode(errorNewStateNoBits)
	RET()
}

// init emits the first pass over the normalized counters: symbols with a count
// of -1 get one cell at the top of the table (highThreshold moves down) and a
// next-state counter of 1; every other symbol's next-state counter is set to
// its count.
//
// norm is the register holding the address of the int16 counters and
// symbolLen the register holding how many of them to process.
func (b *buildDtable) init(norm, symbolLen reg.GPVirtual) {
	Comment("Init, lay down lowprob symbols")

	/*
		for i, v := range s.norm[:s.symbolLen] {
			if v == -1 {
				s.dt[highThreshold].setAddBits(uint8(i))
				highThreshold--
				symbolNext[i] = 1
			} else {
				symbolNext[i] = uint16(v)
			}
		}
	*/

	// i := 0 (a freshly allocated register holds no defined value)
	i := GP64()
	XORQ(i, i)
	Label("init_main_loop")

	// v := s.norm[i]; the counters are 16-bit, hence the element scale of 2.
	v := GP64()
	MOVW(Mem{Base: norm, Index: i, Scale: 2}, v.As16())
	CMPW(v.As16(), I16(-1))
	JNE(LabelRef("do_not_update_high_threshold"))

	{
		// s.dt[highThreshold].setAddBits(uint8(i))
		MOVB(i.As8(), Mem{Base: b.dt, Index: b.highThreshold, Scale: 4, Disp: 1}) // set highThreshold*4 + 1 byte
		// highThreshold--
		DECQ(b.highThreshold)

		// symbolNext[i] = 1
		MOVW(U16(1), Mem{Base: b.symbolNext, Index: i, Scale: 2})

		INCQ(i)
		CMPQ(i, symbolLen)
		JL(LabelRef("init_main_loop"))
		JMP(LabelRef("init_end"))
	}

	Label("do_not_update_high_threshold")
	{
		// symbolNext[i] = uint16(v)
		MOVW(v.As16(), Mem{Base: b.symbolNext, Index: i, Scale: 2})

		INCQ(i)
		CMPQ(i, symbolLen)
		JL(LabelRef("init_main_loop"))
	}

	Label("init_end")
}

// spread emits the pass that distributes every symbol over the decoding table:
// a symbol with count v claims v cells, stepping through the table with a fixed
// stride and skipping the low-probability area above highThreshold. If the
// walk does not end at position 0 the emitted code jumps to
// error_corrupted_normalized_counter.
//
// norm is the register holding the address of the int16 counters and
// symbolLen the register holding how many of them to process.
func (b *buildDtable) spread(norm, symbolLen reg.GPVirtual) {
	Comment("Spread symbols")
	/*
		tableMask := tableSize - 1
		step := tableStep(tableSize)
		position := uint32(0)
		for ss, v := range s.norm[:s.symbolLen] {
			for i := 0; i < int(v); i++ {
				s.dt[position].setAddBits(uint8(ss))
				position = (position + step) & tableMask
				for position > highThreshold {
					// lowprob area
					position = (position + step) & tableMask
				}
			}
		}
	*/
	step := GP64()
	Comment("Calculate table step")
	{
		// tmp1 = tableSize >> 1
		tmp1 := GP64()
		MOVQ(b.tableSize, tmp1)
		SHRQ(U8(1), tmp1)

		// tmp3 = tableSize >> 3
		tmp3 := GP64()
		MOVQ(b.tableSize, tmp3)
		SHRQ(U8(3), tmp3)

		// step = tmp1 + tmp3 + 3
		LEAQ(Mem{Base: tmp1, Index: tmp3, Scale: 1, Disp: 3}, step)
	}

	Comment("Fill add bits values")

	// tableMask = tableSize - 1 (tableSize is a pow of 2)
	tableMask := GP64()
	LEAQ(Mem{Base: b.tableSize, Disp: -1}, tableMask)

	// position := 0
	position := GP64()
	XORQ(position, position)

	// ss := 0
	ss := GP64()
	XORQ(ss, ss)
	Label("spread_main_loop")
	{
		// v := int(s.norm[ss]). The counters are 16-bit (element scale 2) and
		// signed, so sign-extend into the full register: the 64-bit countdown
		// below must not see stale upper bits.
		v := GP64()
		MOVWQSX(Mem{Base: norm, Index: ss, Scale: 2}, v)

		// The reference loop "for i := 0; i < int(v); i++" runs zero times
		// for v <= 0 (unused and low-probability symbols), so skip them.
		TESTQ(v, v)
		JLE(LabelRef("spread_next_symbol"))

		Label("spread_inner_loop")
		{
			// s.dt[position].setAddBits(uint8(ss))
			MOVB(ss.As8(), Mem{Base: b.dt, Index: position, Scale: 4, Disp: 1})

			Label("adjust_position")
			// position = (position + step) & tableMask
			ADDQ(step, position)
			ANDQ(tableMask, position)

			// for position > highThreshold {
			//      // lowprob area
			//      position = (position + step) & tableMask
			// }
			CMPQ(position, b.highThreshold)
			JG(LabelRef("adjust_position"))
		}
		// v is strictly positive here, so counting down to zero yields
		// exactly v iterations.
		DECQ(v)
		JNZ(LabelRef("spread_inner_loop"))
	}

	Label("spread_next_symbol")
	INCQ(ss)
	CMPQ(ss, symbolLen)
	JL(LabelRef("spread_main_loop"))

	/*
		if position != 0 {
			// position must reach all cells once, otherwise normalizedCounter is incorrect
			return errors.New("corrupted input (position != 0)")
		}
	*/
	TESTQ(position, position)
	JNZ(LabelRef("error_corrupted_normalized_counter"))
}

// buildTable emits the final pass, which walks all tableSize entries of dt and
// fills in nBits (byte 0) and newState (bytes 2-3) for each, using the symbol
// stored in byte 1 by the earlier passes.
//
// Two checks can abort the emitted routine. Each stores its detail values in
// results 1 and 2 and then jumps to the matching error label:
//   - newState > tableSize       -> error_new_state_too_big
//   - newState == u && nBits == 0 -> error_new_state_no_bits
//
// NOTE(review): every dt access uses Scale 4, i.e. assumes 4-byte entries as
// stated on the buildDtable.dt field comment — confirm against decSymbol.
func (b *buildDtable) buildTable() {
	Comment("Build Decoding table")
	/*
		tableSize := uint16(1 << s.actualTableLog)
		for u, v := range s.dt[:tableSize] {
			symbol := v.addBits()
			nextState := symbolNext[symbol]
			symbolNext[symbol] = nextState + 1
			nBits := s.actualTableLog - byte(highBits(uint32(nextState)))
			s.dt[u&maxTableMask].setNBits(nBits)
			newState := (nextState << nBits) - tableSize
			if newState > tableSize {
				return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize)
			}
			if newState == uint16(u) && nBits == 0 {
				// Seems weird that this is possible with nbits > 0.
				return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u)
			}
			s.dt[u&maxTableMask].setNewState(newState)
		}
	*/
	u := GP64()
	XORQ(u, u)
	Label("build_table_main_table")
	{
		// v := s.dt[u]
		// symbol := v.addBits()
		//
		// Read byte 1 of the entry straight from memory with zero extension.
		// Going through AH is avoided: AH cannot be combined with a
		// destination register that needs a REX prefix, and the allocator is
		// free to pick one for a virtual register.
		symbol := GP64()
		MOVBQZX(Mem{Base: b.dt, Index: u, Scale: 4, Disp: 1}, symbol)

		// nextState := symbolNext[symbol]
		// symbolNext holds 16-bit entries: load exactly one, zero-extended.
		nextState := GP64()
		ptr := Mem{Base: b.symbolNext, Index: symbol, Scale: 2}
		MOVWQZX(ptr, nextState)

		// symbolNext[symbol] = nextState + 1
		// Store 16 bits only so neighbouring entries are left untouched.
		{
			tmp := GP64()
			LEAQ(Mem{Base: nextState, Disp: 1}, tmp)
			MOVW(tmp.As16(), ptr)
		}

		// nBits := s.actualTableLog - byte(highBits(uint32(nextState)))
		nBits := reg.RCX // As we use nBits to shift
		{
			// BSR yields the index of the most significant set bit. That is
			// taken to be what highBits() returns (highBits is not defined in
			// this file — confirm), so no further adjustment is applied.
			highBits := GP64()
			BSRQ(nextState, highBits)

			MOVQ(b.actualTableLog, nBits)
			SUBQ(highBits, nBits)
		}

		// newState := (nextState << nBits) - tableSize
		// The reference computes this in uint16; truncate the 64-bit result
		// so a wrapped value is seen (and reported) exactly as there.
		newState := GP64()
		MOVQ(nextState, newState)
		SHLQ(reg.CL, newState)
		SUBQ(b.tableSize, newState)
		MOVWQZX(newState.As16(), newState)

		// s.dt[u&maxTableMask].setNBits(nBits)            // sets byte #0
		// s.dt[u&maxTableMask].setNewState(newState)      // sets word #1 (bytes #2 & #3)
		{
			MOVB(nBits.As8(), Mem{Base: b.dt, Index: u, Scale: 4})
			MOVW(newState.As16(), Mem{Base: b.dt, Index: u, Scale: 4, Disp: 2})
		}

		// Result slots 1 and 2 carry the details of a failed check.
		param1, err := ReturnIndex(1).Resolve()
		if err != nil {
			panic(err)
		}

		param2, err := ReturnIndex(2).Resolve()
		if err != nil {
			panic(err)
		}
		// if newState > tableSize {
		//     return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize)
		// }
		{
			// Unsigned comparison, matching the uint16 operands of the
			// reference.
			CMPQ(newState, b.tableSize)
			JBE(LabelRef("build_table_check1_ok"))

			MOVQ(newState, param1.Addr)
			MOVQ(b.tableSize, param2.Addr)
			JMP(LabelRef("error_new_state_too_big"))
			Label("build_table_check1_ok")
		}

		// if newState == uint16(u) && nBits == 0 {
		//     // Seems weird that this is possible with nbits > 0.
		//     return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u)
		// }
		{
			TESTB(nBits.As8(), nBits.As8())
			JNZ(LabelRef("build_table_check2_ok"))
			CMPW(newState.As16(), u.As16())
			JNE(LabelRef("build_table_check2_ok"))
			MOVQ(newState, param1.Addr)
			MOVQ(u, param2.Addr)
			JMP(LabelRef("error_new_state_no_bits"))
			Label("build_table_check2_ok")
		}
	}
	INCQ(u)
	CMPQ(u, b.tableSize)
	JL(LabelRef("build_table_main_table"))
}
3 changes: 3 additions & 0 deletions zstd/fse_decoder.go
Expand Up @@ -269,6 +269,9 @@ func (s *fseDecoder) setRLE(symbol decSymbol) {
s.dt[0] = symbol
}

// buildDtable_asm is the assembly counterpart of buildDtable. It returns three
// values; presumably a status code followed by two detail values for error
// reporting — confirm against the generator that emits it.
//
// The directive below must be written without a space after the slashes,
// otherwise the compiler treats it as an ordinary comment.
//
//go:noescape
func buildDtable_asm(s *fseDecoder) (int, uint64, uint64)

// buildDtable will build the decoding table.
func (s *fseDecoder) buildDtable() error {
tableSize := uint32(1 << s.actualTableLog)
Expand Down

0 comments on commit 07789a8

Please sign in to comment.