From 364fb499c7e18697c3ef5d43b1d46dd06d746fff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Mu=C5=82a?= Date: Thu, 12 May 2022 13:25:29 +0200 Subject: [PATCH] zstd: faster next state update in BMI2 version of decode Use the Go-code approach: use single getBits to obtain three bitfields. --- zstd/_generate/gen.go | 108 ++++++++++++++++- zstd/seqdec.go | 11 +- zstd/seqdec_amd64.s | 276 ++++++++++++++++++++---------------------- 3 files changed, 236 insertions(+), 159 deletions(-) diff --git a/zstd/_generate/gen.go b/zstd/_generate/gen.go index 2089402cea..23a26b21d4 100644 --- a/zstd/_generate/gen.go +++ b/zstd/_generate/gen.go @@ -289,12 +289,48 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute // Update states, max tablelog 28 { - Comment("Update Literal Length State") - o.updateState(name+"_llState", llState, brValue, brBitsRead, "llTable") - Comment("Update Match Length State") - o.updateState(name+"_mlState", mlState, brValue, brBitsRead, "mlTable") - Comment("Update Offset State") - o.updateState(name+"_ofState", ofState, brValue, brBitsRead, "ofTable") + if o.bmi2 { + // Get total number of bits (it is safe, as nBits is <= 9, thus 3*9 < 255) + total := GP64() + LEAQ(Mem{Base: llState, Index: mlState, Scale: 1}, total) + ADDQ(ofState, total) + MOVBQZX(total.As8(), total) // total = llState.As8() + mlState.As8() + ofState.As8() + + // Read `total` bits + bits := o.getBitsValue(name+"_getBits", total, brValue, brBitsRead) + + // Update states + Comment("Update Offset State") + { + nBits := ofState // Note: SHRXQ uses lower 6 bits of shift amount and BZHIQ lower 8 bits of count + lowBits := GP64() + BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1)) + SHRXQ(nBits, bits, bits) // bits >>= nBits + o.nextState(name+"_ofState", ofState, lowBits, "ofTable") + } + Comment("Update Match Length State") + { + nBits := mlState + lowBits := GP64() + BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1)) + SHRXQ(nBits, bits, bits) // lowBits >>= nBits + o.nextState(name+"_mlState", mlState, lowBits, "mlTable") + } + Comment("Update Literal Length State") + { + nBits := llState + lowBits := GP64() + BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1)) + o.nextState(name+"_llState", llState, lowBits, "llTable") + } + } else { + Comment("Update Literal Length State") + o.updateState(name+"_llState", llState, brValue, brBitsRead, "llTable") + Comment("Update Match Length State") + o.updateState(name+"_mlState", mlState, brValue, brBitsRead, "mlTable") + Comment("Update Offset State") + o.updateState(name+"_ofState", ofState, brValue, brBitsRead, "ofTable") + } } Label(name + "_skip_update") @@ -624,6 +660,39 @@ func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtu MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state) } +func (o options) nextState(name string, state, lowBits reg.GPVirtual, table string) { + DX := GP64() + if o.bmi2 { + tmp := GP64() + MOVQ(U32(16|(16<<8)), tmp) + BEXTRQ(tmp, state, DX) + } else { + MOVQ(state, DX) + SHRQ(U8(16), DX) + MOVWQZX(DX.As16(), DX) + } + + ADDQ(lowBits, DX) + + // Load table pointer + tablePtr := GP64() + Comment("Load ctx." + table) + ctx := Dereference(Param("ctx")) + tableA, err := ctx.Field(table).Base().Resolve() + if err != nil { + panic(err) + } + MOVQ(tableA.Addr, tablePtr) + + // Check if below tablelog + assert(func(ok LabelRef) { + CMPQ(DX, U32(512)) + JB(ok) + }) + // Load new state + MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state) +} + // getBits will return nbits bits from brValue. // If nbits == 0 it *may* jump to jmpZero, otherwise 0 is returned. func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual, jmpZero LabelRef) reg.GPVirtual { @@ -649,6 +718,33 @@ func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual, return BX } +// getBits will return nbits bits from brValue. +// If nbits == 0 then 0 is returned. +func (o options) getBitsValue(name string, nBits, brValue, brBitsRead reg.GPVirtual) reg.GPVirtual { + BX := GP64() + CX := reg.CL + if o.bmi2 { + LEAQ(Mem{Base: brBitsRead, Index: nBits, Scale: 1}, CX.As64()) + MOVQ(brValue, BX) + MOVQ(CX.As64(), brBitsRead) + ROLQ(CX, BX) + BZHIQ(nBits, BX, BX) + } else { + XORQ(BX, BX) + CMPQ(nBits, U8(0)) + JZ(LabelRef(name + "_get_bits_value_zero")) + MOVQ(brBitsRead, CX.As64()) + ADDQ(nBits, brBitsRead) + MOVQ(brValue, BX) + SHLQ(CX, BX) + MOVQ(nBits, CX.As64()) + NEGQ(CX.As64()) + SHRQ(CX, BX) + Label(name + "_get_bits_value_zero") + } + return BX +} + func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual, offsets *[3]reg.GPVirtual) (offset reg.GPVirtual) { offset = GP64() MOVQ(moP, offset) diff --git a/zstd/seqdec.go b/zstd/seqdec.go index e80139dd9c..b8295c74a4 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -188,6 +188,7 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { } } } + // Add final literals copy(out[t:], s.literals) if debugDecoder { @@ -203,12 +204,11 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { // decode sequences from the stream with the provided history. func (s *sequenceDecs) decodeSync(hist []byte) error { - if true { - supported, err := s.decodeSyncSimple(hist) - if supported { - return err - } + supported, err := s.decodeSyncSimple(hist) + if supported { + return err } + br := s.br seqs := s.nSeqs startSize := len(s.out) @@ -396,6 +396,7 @@ func (s *sequenceDecs) decodeSync(hist []byte) error { ofState = ofTable[ofState.newState()&maxTableMask] } else { bits := br.get32BitsFast(nBits) + lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31)) llState = llTable[(llState.newState()+lowBits)&maxTableMask] diff --git a/zstd/seqdec_amd64.s b/zstd/seqdec_amd64.s index 9665833df1..212c6cac30 100644 --- a/zstd/seqdec_amd64.s +++ b/zstd/seqdec_amd64.s @@ -705,60 +705,55 @@ sequenceDecs_decode_bmi2_fill_2_end: MOVQ CX, (R9) // Fill bitreader for state updates - MOVQ R13, (SP) - MOVQ $0x00000808, CX - BEXTRQ CX, R8, R13 - MOVQ ctx+16(FP), CX - CMPQ 96(CX), $0x00 - JZ sequenceDecs_decode_bmi2_skip_update - - // Update Literal Length State - MOVBQZX SI, R14 - MOVQ $0x00001010, CX - BEXTRQ CX, SI, SI + MOVQ R13, (SP) + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R13 + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decode_bmi2_skip_update + LEAQ (SI)(DI*1), R14 + ADDQ R8, R14 + MOVBQZX R14, R14 LEAQ (DX)(R14*1), CX MOVQ AX, R15 MOVQ CX, DX ROLQ CL, R15 BZHIQ R14, R15, R15 - ADDQ R15, SI - // Load ctx.llTable + // Update Offset State + BZHIQ R8, R15, CX + SHRXQ R8, R15, R15 + MOVQ $0x00001010, R14 + BEXTRQ R14, R8, R8 + ADDQ CX, R8 + + // Load ctx.ofTable MOVQ ctx+16(FP), CX - MOVQ (CX), CX - MOVQ (CX)(SI*8), SI + MOVQ 48(CX), CX + MOVQ (CX)(R8*8), R8 // Update Match Length State - MOVBQZX DI, R14 - MOVQ $0x00001010, CX - BEXTRQ CX, DI, DI - LEAQ (DX)(R14*1), CX - MOVQ AX, R15 - MOVQ CX, DX - ROLQ CL, R15 - BZHIQ R14, R15, R15 - ADDQ R15, DI + BZHIQ DI, R15, CX + SHRXQ DI, R15, R15 + MOVQ $0x00001010, R14 + BEXTRQ R14, DI, DI + ADDQ CX, DI // Load ctx.mlTable MOVQ ctx+16(FP), CX MOVQ 24(CX), CX MOVQ (CX)(DI*8), DI - // Update Offset State - MOVBQZX R8, R14 - MOVQ $0x00001010, CX - BEXTRQ CX, R8, R8 - LEAQ (DX)(R14*1), CX - MOVQ AX, R15 - MOVQ CX, DX - ROLQ CL, R15 - BZHIQ R14, R15, R15 - ADDQ R15, R8 + // Update Literal Length State + BZHIQ SI, R15, CX + MOVQ $0x00001010, R14 + BEXTRQ R14, SI, SI + ADDQ CX, SI - // Load ctx.ofTable + // Load ctx.llTable MOVQ ctx+16(FP), CX - MOVQ 48(CX), CX - MOVQ (CX)(R8*8), R8 + MOVQ (CX), CX + MOVQ (CX)(SI*8), SI sequenceDecs_decode_bmi2_skip_update: // Adjust offset @@ -965,60 +960,55 @@ sequenceDecs_decode_56_bmi2_fill_end: MOVQ CX, (R9) // Fill bitreader for state updates - MOVQ R13, (SP) - MOVQ $0x00000808, CX - BEXTRQ CX, R8, R13 - MOVQ ctx+16(FP), CX - CMPQ 96(CX), $0x00 - JZ sequenceDecs_decode_56_bmi2_skip_update - - // Update Literal Length State - MOVBQZX SI, R14 - MOVQ $0x00001010, CX - BEXTRQ CX, SI, SI + MOVQ R13, (SP) + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R13 + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decode_56_bmi2_skip_update + LEAQ (SI)(DI*1), R14 + ADDQ R8, R14 + MOVBQZX R14, R14 LEAQ (DX)(R14*1), CX MOVQ AX, R15 MOVQ CX, DX ROLQ CL, R15 BZHIQ R14, R15, R15 - ADDQ R15, SI - // Load ctx.llTable + // Update Offset State + BZHIQ R8, R15, CX + SHRXQ R8, R15, R15 + MOVQ $0x00001010, R14 + BEXTRQ R14, R8, R8 + ADDQ CX, R8 + + // Load ctx.ofTable MOVQ ctx+16(FP), CX - MOVQ (CX), CX - MOVQ (CX)(SI*8), SI + MOVQ 48(CX), CX + MOVQ (CX)(R8*8), R8 // Update Match Length State - MOVBQZX DI, R14 - MOVQ $0x00001010, CX - BEXTRQ CX, DI, DI - LEAQ (DX)(R14*1), CX - MOVQ AX, R15 - MOVQ CX, DX - ROLQ CL, R15 - BZHIQ R14, R15, R15 - ADDQ R15, DI + BZHIQ DI, R15, CX + SHRXQ DI, R15, R15 + MOVQ $0x00001010, R14 + BEXTRQ R14, DI, DI + ADDQ CX, DI // Load ctx.mlTable MOVQ ctx+16(FP), CX MOVQ 24(CX), CX MOVQ (CX)(DI*8), DI - // Update Offset State - MOVBQZX R8, R14 - MOVQ $0x00001010, CX - BEXTRQ CX, R8, R8 - LEAQ (DX)(R14*1), CX - MOVQ AX, R15 - MOVQ CX, DX - ROLQ CL, R15 - BZHIQ R14, R15, R15 - ADDQ R15, R8 + // Update Literal Length State + BZHIQ SI, R15, CX + MOVQ $0x00001010, R14 + BEXTRQ R14, SI, SI + ADDQ CX, SI - // Load ctx.ofTable + // Load ctx.llTable MOVQ ctx+16(FP), CX - MOVQ 48(CX), CX - MOVQ (CX)(R8*8), R8 + MOVQ (CX), CX + MOVQ (CX)(SI*8), SI sequenceDecs_decode_56_bmi2_skip_update: // Adjust offset @@ -2265,60 +2255,55 @@ sequenceDecs_decodeSync_bmi2_fill_2_end: MOVQ CX, 24(SP) // Fill bitreader for state updates - MOVQ R12, (SP) - MOVQ $0x00000808, CX - BEXTRQ CX, R8, R12 - MOVQ ctx+16(FP), CX - CMPQ 96(CX), $0x00 - JZ sequenceDecs_decodeSync_bmi2_skip_update - - // Update Literal Length State - MOVBQZX SI, R13 - MOVQ $0x00001010, CX - BEXTRQ CX, SI, SI + MOVQ R12, (SP) + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R12 + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decodeSync_bmi2_skip_update + LEAQ (SI)(DI*1), R13 + ADDQ R8, R13 + MOVBQZX R13, R13 LEAQ (DX)(R13*1), CX MOVQ AX, R14 MOVQ CX, DX ROLQ CL, R14 BZHIQ R13, R14, R14 - ADDQ R14, SI - // Load ctx.llTable + // Update Offset State + BZHIQ R8, R14, CX + SHRXQ R8, R14, R14 + MOVQ $0x00001010, R13 + BEXTRQ R13, R8, R8 + ADDQ CX, R8 + + // Load ctx.ofTable MOVQ ctx+16(FP), CX - MOVQ (CX), CX - MOVQ (CX)(SI*8), SI + MOVQ 48(CX), CX + MOVQ (CX)(R8*8), R8 // Update Match Length State - MOVBQZX DI, R13 - MOVQ $0x00001010, CX - BEXTRQ CX, DI, DI - LEAQ (DX)(R13*1), CX - MOVQ AX, R14 - MOVQ CX, DX - ROLQ CL, R14 - BZHIQ R13, R14, R14 - ADDQ R14, DI + BZHIQ DI, R14, CX + SHRXQ DI, R14, R14 + MOVQ $0x00001010, R13 + BEXTRQ R13, DI, DI + ADDQ CX, DI // Load ctx.mlTable MOVQ ctx+16(FP), CX MOVQ 24(CX), CX MOVQ (CX)(DI*8), DI - // Update Offset State - MOVBQZX R8, R13 - MOVQ $0x00001010, CX - BEXTRQ CX, R8, R8 - LEAQ (DX)(R13*1), CX - MOVQ AX, R14 - MOVQ CX, DX - ROLQ CL, R14 - BZHIQ R13, R14, R14 - ADDQ R14, R8 + // Update Literal Length State + BZHIQ SI, R14, CX + MOVQ $0x00001010, R13 + BEXTRQ R13, SI, SI + ADDQ CX, SI - // Load ctx.ofTable + // Load ctx.llTable MOVQ ctx+16(FP), CX - MOVQ 48(CX), CX - MOVQ (CX)(R8*8), R8 + MOVQ (CX), CX + MOVQ (CX)(SI*8), SI sequenceDecs_decodeSync_bmi2_skip_update: // Adjust offset @@ -3300,60 +3285,55 @@ sequenceDecs_decodeSync_safe_bmi2_fill_2_end: MOVQ CX, 24(SP) // Fill bitreader for state updates - MOVQ R12, (SP) - MOVQ $0x00000808, CX - BEXTRQ CX, R8, R12 - MOVQ ctx+16(FP), CX - CMPQ 96(CX), $0x00 - JZ sequenceDecs_decodeSync_safe_bmi2_skip_update - - // Update Literal Length State - MOVBQZX SI, R13 - MOVQ $0x00001010, CX - BEXTRQ CX, SI, SI + MOVQ R12, (SP) + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R12 + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decodeSync_safe_bmi2_skip_update + LEAQ (SI)(DI*1), R13 + ADDQ R8, R13 + MOVBQZX R13, R13 LEAQ (DX)(R13*1), CX MOVQ AX, R14 MOVQ CX, DX ROLQ CL, R14 BZHIQ R13, R14, R14 - ADDQ R14, SI - // Load ctx.llTable + // Update Offset State + BZHIQ R8, R14, CX + SHRXQ R8, R14, R14 + MOVQ $0x00001010, R13 + BEXTRQ R13, R8, R8 + ADDQ CX, R8 + + // Load ctx.ofTable MOVQ ctx+16(FP), CX - MOVQ (CX), CX - MOVQ (CX)(SI*8), SI + MOVQ 48(CX), CX + MOVQ (CX)(R8*8), R8 // Update Match Length State - MOVBQZX DI, R13 - MOVQ $0x00001010, CX - BEXTRQ CX, DI, DI - LEAQ (DX)(R13*1), CX - MOVQ AX, R14 - MOVQ CX, DX - ROLQ CL, R14 - BZHIQ R13, R14, R14 - ADDQ R14, DI + BZHIQ DI, R14, CX + SHRXQ DI, R14, R14 + MOVQ $0x00001010, R13 + BEXTRQ R13, DI, DI + ADDQ CX, DI // Load ctx.mlTable MOVQ ctx+16(FP), CX MOVQ 24(CX), CX MOVQ (CX)(DI*8), DI - // Update Offset State - MOVBQZX R8, R13 - MOVQ $0x00001010, CX - BEXTRQ CX, R8, R8 - LEAQ (DX)(R13*1), CX - MOVQ AX, R14 - MOVQ CX, DX - ROLQ CL, R14 - BZHIQ R13, R14, R14 - ADDQ R14, R8 + // Update Literal Length State + BZHIQ SI, R14, CX + MOVQ $0x00001010, R13 + BEXTRQ R13, SI, SI + ADDQ CX, SI - // Load ctx.ofTable + // Load ctx.llTable MOVQ ctx+16(FP), CX - MOVQ 48(CX), CX - MOVQ (CX)(R8*8), R8 + MOVQ (CX), CX + MOVQ (CX)(SI*8), SI sequenceDecs_decodeSync_safe_bmi2_skip_update: // Adjust offset