diff --git a/huff0/_generate/gen.go b/huff0/_generate/gen.go
index 04870d91dc..3a496e9d54 100644
--- a/huff0/_generate/gen.go
+++ b/huff0/_generate/gen.go
@@ -10,6 +10,7 @@ import (
 	_ "github.com/klauspost/compress"
 
 	. "github.com/mmcloughlin/avo/build"
+	"github.com/mmcloughlin/avo/gotypes"
 	. "github.com/mmcloughlin/avo/operand"
 	"github.com/mmcloughlin/avo/reg"
 )
@@ -19,9 +20,19 @@ func main() {
 
 	ConstraintExpr("amd64,!appengine,!noasm,gc")
 
-	decompress := decompress4x{}
-	decompress.generateProcedure("decompress4x_main_loop_amd64")
-	decompress.generateProcedure4x8bit("decompress4x_8b_main_loop_amd64")
+	{
+		decompress := decompress4x{}
+		decompress.generateProcedure("decompress4x_main_loop_amd64")
+		decompress.generateProcedure4x8bit("decompress4x_8b_main_loop_amd64")
+	}
+
+	{
+		decompress := decompress1x{}
+		decompress.generateProcedure("decompress1x_main_loop_amd64")
+
+		decompress.bmi2 = true
+		decompress.generateProcedure("decompress1x_main_loop_bmi2")
+	}
 
 	Generate()
 }
@@ -308,3 +319,208 @@ func (d decompress4x) fillFast32(id, atLeast int, br, exhausted reg.GPVirtual) (
 	Label("skip_fill" + strconv.Itoa(id))
 	return
 }
+
+// bitReader keeps the state of a bitReaderShifted in registers and emits code operating on it.
+type bitReader struct {
+	in       reg.GPVirtual
+	off      reg.GPVirtual
+	value    reg.GPVirtual
+	bitsRead reg.GPVirtual
+	id       int
+	bmi2     bool
+}
+
+// uniqId returns a fresh number used to build unique label names.
+func (b *bitReader) uniqId() string {
+	b.id += 1
+	return strconv.Itoa(b.id)
+}
+
+// load emits code reading the bit reader fields from the struct referenced by pointer.
+func (b *bitReader) load(pointer gotypes.Component) {
+	b.in = GP64()
+	b.off = GP64()
+	b.value = GP64()
+	b.bitsRead = GP64()
+
+	Load(pointer.Field("in").Base(), b.in)
+	Load(pointer.Field("off"), b.off)
+	Load(pointer.Field("value"), b.value)
+	Load(pointer.Field("bitsRead"), b.bitsRead)
+}
+
+// store emits code writing off, value and bitsRead back to the struct referenced by pointer.
+func (b *bitReader) store(pointer gotypes.Component) {
+	Store(b.off, pointer.Field("off"))
+	Store(b.value, pointer.Field("value"))
+	// Note: explicit As8(), without this avo reports: "could not deduce mov instruction"
+	Store(b.bitsRead.As8(), pointer.Field("bitsRead"))
+}
+
+// fillFast emits code that, once at least 32 bits were consumed, shifts the
+// next 4 input bytes into value. The non-BMI2 variant clobbers RCX.
+func (b *bitReader) fillFast() {
+	label := "bitReader_fillFast_" + b.uniqId() + "_end"
+	CMPQ(b.bitsRead, U8(32))
+	JL(LabelRef(label))
+
+	SUBQ(U8(32), b.bitsRead)
+	SUBQ(U8(4), b.off)
+
+	tmp := GP64()
+	MOVL(Mem{Base: b.in, Index: b.off, Scale: 1}, tmp.As32())
+	if b.bmi2 {
+		SHLXQ(b.bitsRead, tmp, tmp)
+	} else {
+		MOVQ(b.bitsRead, reg.RCX)
+		SHLQ(reg.CL, tmp)
+	}
+	ORQ(tmp, b.value)
+	Label(label)
+}
+
+// peekTopBits emits code computing value >> n into a new register. The non-BMI2 variant clobbers RCX.
+func (b *bitReader) peekTopBits(n reg.GPVirtual) reg.GPVirtual {
+	res := GP64()
+	if b.bmi2 {
+		SHRXQ(n, b.value, res)
+	} else {
+		MOVQ(n, reg.RCX)
+		MOVQ(b.value, res)
+		SHRQ(reg.CL, res)
+	}
+
+	return res
+}
+
+// advance emits code consuming n bits: bitsRead += n, value <<= n. The non-BMI2 variant clobbers RCX.
+func (b *bitReader) advance(n reg.Register) {
+	ADDQ(n, b.bitsRead)
+	if b.bmi2 {
+		SHLXQ(n, b.value, b.value)
+	} else {
+		MOVQ(n, reg.RCX)
+		SHLQ(reg.CL, b.value)
+	}
+}
+
+// decompress1x generates the asm main loop used by Decoder.Decompress1X.
+type decompress1x struct {
+	bmi2 bool
+}
+
+// generateProcedure emits the function called name; it decodes 4 symbols per iteration while br.off >= 8.
+func (d decompress1x) generateProcedure(name string) {
+	Package("github.com/klauspost/compress/huff0")
+	TEXT(name, 0, "func(ctx* decompress1xContext)")
+	Doc(name+" is an x86 assembler implementation of Decompress1X", "")
+	Pragma("noescape")
+
+	br := bitReader{}
+	br.bmi2 = d.bmi2
+
+	buffer := GP64()
+	bufferEnd := GP64() // the past-end address of buffer
+	dt := GP64()
+	peekBits := GP64()
+
+	{
+		ctx := Dereference(Param("ctx"))
+		Load(ctx.Field("out"), buffer)
+
+		outCap := GP64()
+		Load(ctx.Field("outCap"), outCap)
+		CMPQ(outCap, U8(4))
+		JB(LabelRef("error_max_decoded_size_exeeded"))
+
+		LEAQ(Mem{Base: buffer, Index: outCap, Scale: 1}, bufferEnd)
+
+		// load bitReader struct
+		pbr := Dereference(ctx.Field("pbr"))
+		br.load(pbr)
+
+		Load(ctx.Field("tbl"), dt)
+		Load(ctx.Field("peekBits"), peekBits)
+	}
+
+	JMP(LabelRef("loop_condition"))
+
+	Label("main_loop")
+
+	out := reg.AX // Fixed, as we need an 8H part
+
+	Comment("Check if we have room for 4 bytes in the output buffer")
+	{
+		tmp := GP64()
+		LEAQ(Mem{Base: buffer, Disp: 4}, tmp)
+		CMPQ(tmp, bufferEnd)
+		JGE(LabelRef("error_max_decoded_size_exeeded"))
+	}
+
+	decompress := func(id int, out reg.Register) {
+		d.decompress(id, &br, peekBits, dt, out)
+	}
+
+	Comment("Decode 4 values")
+	br.fillFast()
+	decompress(0, out.As8L())
+	decompress(1, out.As8H())
+	BSWAPL(out.As32())
+
+	br.fillFast()
+	decompress(2, out.As8H())
+	decompress(3, out.As8L())
+	BSWAPL(out.As32())
+
+	Comment("Store the decoded values")
+	MOVL(out.As32(), Mem{Base: buffer})
+	ADDQ(U8(4), buffer)
+
+	Label("loop_condition")
+	CMPQ(br.off, U8(8))
+	JGE(LabelRef("main_loop"))
+
+	Comment("Update ctx structure")
+	{
+		// calculate decoded as current `out` - initial `out`
+		ctx := Dereference(Param("ctx"))
+		decoded := GP64()
+		tmp := GP64()
+		MOVQ(buffer, decoded)
+		Load(ctx.Field("out"), tmp)
+		SUBQ(tmp, decoded)
+		Store(decoded, ctx.Field("decoded"))
+
+		pbr := Dereference(ctx.Field("pbr"))
+		br.store(pbr)
+	}
+
+	RET()
+
+	Comment("Report error")
+	Label("error_max_decoded_size_exeeded")
+	{
+		ctx := Dereference(Param("ctx"))
+		tmp := GP64()
+		MOVQ(I64(-1), tmp)
+		Store(tmp, ctx.Field("decoded"))
+	}
+
+	// A TEXT block has no fallthrough: the error path needs its own return.
+	RET()
+}
+
+// decompress emits code decoding one symbol into out and advancing br; it clobbers RCX.
+func (d decompress1x) decompress(id int, br *bitReader, peekBits, dt reg.GPVirtual, out reg.Register) {
+	// v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
+	k := br.peekTopBits(peekBits)
+	v := reg.RCX // Fixed, as we need 8H part
+	MOVW(Mem{Base: dt, Index: k, Scale: 2}, v.As16())
+
+	// buf[id] = uint8(v.entry >> 8)
+	MOVB(v.As8H(), out)
+
+	// br.advance(uint8(v.entry))
+	MOVBQZX(v.As8L(), v)
+	br.advance(v)
+}
diff --git a/huff0/decompress.go b/huff0/decompress.go
index 86ef85f724..4af9341773 100644
--- a/huff0/decompress.go
+++ b/huff0/decompress.go
@@ -236,108 +236,6 @@ func (d *Decoder) buffer() *[4][256]byte {
 	return &[4][256]byte{}
 }
 
-// Decompress1X will decompress a 1X encoded stream.
-// The cap of the output buffer will be the maximum decompressed size.
-// The length of the supplied input must match the end of a block exactly.
-func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
-	if len(d.dt.single) == 0 {
-		return nil, errors.New("no table loaded")
-	}
-	if use8BitTables && d.actualTableLog <= 8 {
-		return d.decompress1X8Bit(dst, src)
-	}
-	var br bitReaderShifted
-	err := br.init(src)
-	if err != nil {
-		return dst, err
-	}
-	maxDecodedSize := cap(dst)
-	dst = dst[:0]
-
-	// Avoid bounds check by always having full sized table.
-	const tlSize = 1 << tableLogMax
-	const tlMask = tlSize - 1
-	dt := d.dt.single[:tlSize]
-
-	// Use temp table to avoid bound checks/append penalty.
-	bufs := d.buffer()
-	buf := &bufs[0]
-	var off uint8
-
-	for br.off >= 8 {
-		br.fillFast()
-		v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
-		br.advance(uint8(v.entry))
-		buf[off+0] = uint8(v.entry >> 8)
-
-		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
-		br.advance(uint8(v.entry))
-		buf[off+1] = uint8(v.entry >> 8)
-
-		// Refill
-		br.fillFast()
-
-		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
-		br.advance(uint8(v.entry))
-		buf[off+2] = uint8(v.entry >> 8)
-
-		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
-		br.advance(uint8(v.entry))
-		buf[off+3] = uint8(v.entry >> 8)
-
-		off += 4
-		if off == 0 {
-			if len(dst)+256 > maxDecodedSize {
-				br.close()
-				d.bufs.Put(bufs)
-				return nil, ErrMaxDecodedSizeExceeded
-			}
-			dst = append(dst, buf[:]...)
-		}
-	}
-
-	if len(dst)+int(off) > maxDecodedSize {
-		d.bufs.Put(bufs)
-		br.close()
-		return nil, ErrMaxDecodedSizeExceeded
-	}
-	dst = append(dst, buf[:off]...)
-
-	// br < 8, so uint8 is fine
-	bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
-	for bitsLeft > 0 {
-		br.fill()
-		if false && br.bitsRead >= 32 {
-			if br.off >= 4 {
-				v := br.in[br.off-4:]
-				v = v[:4]
-				low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
-				br.value = (br.value << 32) | uint64(low)
-				br.bitsRead -= 32
-				br.off -= 4
-			} else {
-				for br.off > 0 {
-					br.value = (br.value << 8) | uint64(br.in[br.off-1])
-					br.bitsRead -= 8
-					br.off--
-				}
-			}
-		}
-		if len(dst) >= maxDecodedSize {
-			d.bufs.Put(bufs)
-			br.close()
-			return nil, ErrMaxDecodedSizeExceeded
-		}
-		v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
-		nBits := uint8(v.entry)
-		br.advance(nBits)
-		bitsLeft -= nBits
-		dst = append(dst, uint8(v.entry>>8))
-	}
-	d.bufs.Put(bufs)
-	return dst, br.close()
-}
-
 // decompress1X8Bit will decompress a 1X encoded stream with tablelog <= 8.
 // The cap of the output buffer will be the maximum decompressed size.
 // The length of the supplied input must match the end of a block exactly.
diff --git a/huff0/decompress_amd64.go b/huff0/decompress_amd64.go
index 3415e5da22..671e630a84 100644
--- a/huff0/decompress_amd64.go
+++ b/huff0/decompress_amd64.go
@@ -2,12 +2,14 @@
 // +build amd64,!appengine,!noasm,gc
 
 // This file contains the specialisation of Decoder.Decompress4X
-// that uses an asm implementation of its main loop.
+// and Decoder.Decompress1X that use an asm implementation of their main loops.
 package huff0
 
 import (
 	"errors"
 	"fmt"
+
+	"github.com/klauspost/compress/internal/cpuinfo"
 )
 
 // decompress4x_main_loop_x86 is an x86 assembler implementation
@@ -146,3 +148,92 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
 	}
 	return dst, nil
 }
+
+// decompress1x_main_loop_amd64 is an x86 assembler implementation
+// of the main loop of Decompress1X.
+//go:noescape
+func decompress1x_main_loop_amd64(ctx *decompress1xContext)
+
+// decompress1x_main_loop_bmi2 is an x86 with BMI2 assembler implementation
+// of the main loop of Decompress1X.
+//go:noescape
+func decompress1x_main_loop_bmi2(ctx *decompress1xContext)
+
+// decompress1xContext carries the inputs and the result of the asm main loops.
+// The generated assembly addresses the fields by fixed offsets, so the asm
+// must be regenerated whenever this layout changes.
+type decompress1xContext struct {
+	pbr      *bitReaderShifted
+	peekBits uint8
+	out      *byte
+	outCap   int
+	tbl      *dEntrySingle
+	decoded  int // bytes written, or error_max_decoded_size_exeeded
+}
+
+// Error reported by asm implementations
+const error_max_decoded_size_exeeded = -1
+
+// Decompress1X will decompress a 1X encoded stream.
+// The cap of the output buffer will be the maximum decompressed size.
+// The length of the supplied input must match the end of a block exactly.
+func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
+	if len(d.dt.single) == 0 {
+		return nil, errors.New("no table loaded")
+	}
+	var br bitReaderShifted
+	err := br.init(src)
+	if err != nil {
+		return dst, err
+	}
+	maxDecodedSize := cap(dst)
+	dst = dst[:0]
+
+	const tlSize = 1 << tableLogMax
+	const tlMask = tlSize - 1
+
+	if maxDecodedSize >= 4 {
+		// The asm loop writes straight into dst's backing array.
+		dst = dst[:maxDecodedSize]
+		ctx := decompress1xContext{
+			pbr:      &br,
+			out:      &dst[0],
+			outCap:   maxDecodedSize,
+			peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
+			tbl:      &d.dt.single[0],
+		}
+
+		if cpuinfo.HasBMI2() {
+			decompress1x_main_loop_bmi2(&ctx)
+		} else {
+			decompress1x_main_loop_amd64(&ctx)
+		}
+		if ctx.decoded == error_max_decoded_size_exeeded {
+			return nil, ErrMaxDecodedSizeExceeded
+		}
+
+		dst = dst[:ctx.decoded]
+	} else if br.off >= 8 {
+		// Too small for the asm loop, yet a full round of input is pending.
+		// The generic implementation decodes four symbols per round and
+		// reports the overflow; do the same so br.off < 8 holds below.
+		br.close()
+		return nil, ErrMaxDecodedSizeExceeded
+	}
+
+	// br < 8, so uint8 is fine
+	bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
+	for bitsLeft > 0 {
+		br.fill()
+		if len(dst) >= maxDecodedSize {
+			br.close()
+			return nil, ErrMaxDecodedSizeExceeded
+		}
+		v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
+		nBits := uint8(v.entry)
+		br.advance(nBits)
+		bitsLeft -= nBits
+		dst = append(dst, uint8(v.entry>>8))
+	}
+	return dst, br.close()
+}
diff --git a/huff0/decompress_amd64.s b/huff0/decompress_amd64.s
index 06287f5685..5ff3f489f3 100644
--- a/huff0/decompress_amd64.s
+++ b/huff0/decompress_amd64.s
@@ -660,3 +660,206 @@ skip_fill1003:
 	SHLQ $0x02, DX
 	MOVQ DX, 64(AX)
 	RET
+
+// func decompress1x_main_loop_amd64(ctx *decompress1xContext)
+TEXT ·decompress1x_main_loop_amd64(SB), $0-8
+	MOVQ    ctx+0(FP), CX
+	MOVQ    16(CX), DX
+	MOVQ    24(CX), BX
+	CMPQ    BX, $0x04
+	JB      error_max_decoded_size_exeeded
+	LEAQ    (DX)(BX*1), BX
+	MOVQ    (CX), SI
+	MOVQ    (SI), R8
+	MOVQ    24(SI), R9
+	MOVQ    32(SI), R10
+	MOVBQZX 40(SI), R11
+	MOVQ    32(CX), SI
+	MOVBQZX 8(CX), DI
+	JMP     loop_condition
+
+main_loop:
+	// Check if we have room for 4 bytes in the output buffer
+	LEAQ 4(DX), CX
+	CMPQ CX, BX
+	JGE  error_max_decoded_size_exeeded
+
+	// Decode 4 values
+	CMPQ R11, $0x20
+	JL   bitReader_fillFast_1_end
+	SUBQ $0x20, R11
+	SUBQ $0x04, R9
+	MOVL (R8)(R9*1), R12
+	MOVQ R11, CX
+	SHLQ CL, R12
+	ORQ  R12, R10
+
+bitReader_fillFast_1_end:
+	MOVQ    DI, CX
+	MOVQ    R10, R12
+	SHRQ    CL, R12
+	MOVW    (SI)(R12*2), CX
+	MOVB    CH, AL
+	MOVBQZX CL, CX
+	ADDQ    CX, R11
+	SHLQ    CL, R10
+	MOVQ    DI, CX
+	MOVQ    R10, R12
+	SHRQ    CL, R12
+	MOVW    (SI)(R12*2), CX
+	MOVB    CH, AH
+	MOVBQZX CL, CX
+	ADDQ    CX, R11
+	SHLQ    CL, R10
+	BSWAPL  AX
+	CMPQ    R11, $0x20
+	JL      bitReader_fillFast_2_end
+	SUBQ    $0x20, R11
+	SUBQ    $0x04, R9
+	MOVL    (R8)(R9*1), R12
+	MOVQ    R11, CX
+	SHLQ    CL, R12
+	ORQ     R12, R10
+
+bitReader_fillFast_2_end:
+	MOVQ    DI, CX
+	MOVQ    R10, R12
+	SHRQ    CL, R12
+	MOVW    (SI)(R12*2), CX
+	MOVB    CH, AH
+	MOVBQZX CL, CX
+	ADDQ    CX, R11
+	SHLQ    CL, R10
+	MOVQ    DI, CX
+	MOVQ    R10, R12
+	SHRQ    CL, R12
+	MOVW    (SI)(R12*2), CX
+	MOVB    CH, AL
+	MOVBQZX CL, CX
+	ADDQ    CX, R11
+	SHLQ    CL, R10
+	BSWAPL  AX
+
+	// Store the decoded values
+	MOVL AX, (DX)
+	ADDQ $0x04, DX
+
+loop_condition:
+	CMPQ R9, $0x08
+	JGE  main_loop
+
+	// Update ctx structure
+	MOVQ ctx+0(FP), AX
+	MOVQ DX, CX
+	MOVQ 16(AX), DX
+	SUBQ DX, CX
+	MOVQ CX, 40(AX)
+	MOVQ (AX), AX
+	MOVQ R9, 24(AX)
+	MOVQ R10, 32(AX)
+	MOVB R11, 40(AX)
+	RET
+
+	// Report error
+error_max_decoded_size_exeeded:
+	MOVQ ctx+0(FP), AX
+	MOVQ $-1, CX
+	MOVQ CX, 40(AX)
+	RET
+
+// func decompress1x_main_loop_bmi2(ctx *decompress1xContext)
+// Requires: BMI2
+TEXT ·decompress1x_main_loop_bmi2(SB), $0-8
+	MOVQ    ctx+0(FP), CX
+	MOVQ    16(CX), DX
+	MOVQ    24(CX), BX
+	CMPQ    BX, $0x04
+	JB      error_max_decoded_size_exeeded
+	LEAQ    (DX)(BX*1), BX
+	MOVQ    (CX), SI
+	MOVQ    (SI), R8
+	MOVQ    24(SI), R9
+	MOVQ    32(SI), R10
+	MOVBQZX 40(SI), R11
+	MOVQ    32(CX), SI
+	MOVBQZX 8(CX), DI
+	JMP     loop_condition
+
+main_loop:
+	// Check if we have room for 4 bytes in the output buffer
+	LEAQ 4(DX), CX
+	CMPQ CX, BX
+	JGE  error_max_decoded_size_exeeded
+
+	// Decode 4 values
+	CMPQ  R11, $0x20
+	JL    bitReader_fillFast_1_end
+	SUBQ  $0x20, R11
+	SUBQ  $0x04, R9
+	MOVL  (R8)(R9*1), CX
+	SHLXQ R11, CX, CX
+	ORQ   CX, R10
+
+bitReader_fillFast_1_end:
+	SHRXQ   DI, R10, CX
+	MOVW    (SI)(CX*2), CX
+	MOVB    CH, AL
+	MOVBQZX CL, CX
+	ADDQ    CX, R11
+	SHLXQ   CX, R10, R10
+	SHRXQ   DI, R10, CX
+	MOVW    (SI)(CX*2), CX
+	MOVB    CH, AH
+	MOVBQZX CL, CX
+	ADDQ    CX, R11
+	SHLXQ   CX, R10, R10
+	BSWAPL  AX
+	CMPQ    R11, $0x20
+	JL      bitReader_fillFast_2_end
+	SUBQ    $0x20, R11
+	SUBQ    $0x04, R9
+	MOVL    (R8)(R9*1), CX
+	SHLXQ   R11, CX, CX
+	ORQ     CX, R10
+
+bitReader_fillFast_2_end:
+	SHRXQ   DI, R10, CX
+	MOVW    (SI)(CX*2), CX
+	MOVB    CH, AH
+	MOVBQZX CL, CX
+	ADDQ    CX, R11
+	SHLXQ   CX, R10, R10
+	SHRXQ   DI, R10, CX
+	MOVW    (SI)(CX*2), CX
+	MOVB    CH, AL
+	MOVBQZX CL, CX
+	ADDQ    CX, R11
+	SHLXQ   CX, R10, R10
+	BSWAPL  AX
+
+	// Store the decoded values
+	MOVL AX, (DX)
+	ADDQ $0x04, DX
+
+loop_condition:
+	CMPQ R9, $0x08
+	JGE  main_loop
+
+	// Update ctx structure
+	MOVQ ctx+0(FP), AX
+	MOVQ DX, CX
+	MOVQ 16(AX), DX
+	SUBQ DX, CX
+	MOVQ CX, 40(AX)
+	MOVQ (AX), AX
+	MOVQ R9, 24(AX)
+	MOVQ R10, 32(AX)
+	MOVB R11, 40(AX)
+	RET
+
+	// Report error
+error_max_decoded_size_exeeded:
+	MOVQ ctx+0(FP), AX
+	MOVQ $-1, CX
+	MOVQ CX, 40(AX)
+	RET
diff --git a/huff0/decompress_generic.go b/huff0/decompress_generic.go
index 126b4d68a9..4f6f37cb2c 100644
--- a/huff0/decompress_generic.go
+++ b/huff0/decompress_generic.go
@@ -191,3 +191,105 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
 	}
 	return dst, nil
 }
+
+// Decompress1X will decompress a 1X encoded stream.
+// The cap of the output buffer will be the maximum decompressed size.
+// The length of the supplied input must match the end of a block exactly.
+func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
+	if len(d.dt.single) == 0 {
+		return nil, errors.New("no table loaded")
+	}
+	if use8BitTables && d.actualTableLog <= 8 {
+		return d.decompress1X8Bit(dst, src)
+	}
+	var br bitReaderShifted
+	err := br.init(src)
+	if err != nil {
+		return dst, err
+	}
+	maxDecodedSize := cap(dst)
+	dst = dst[:0]
+
+	// Avoid bounds check by always having full sized table.
+	const tlSize = 1 << tableLogMax
+	const tlMask = tlSize - 1
+	dt := d.dt.single[:tlSize]
+
+	// Use temp table to avoid bound checks/append penalty.
+	bufs := d.buffer()
+	buf := &bufs[0]
+	var off uint8
+
+	for br.off >= 8 {
+		br.fillFast()
+		v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
+		br.advance(uint8(v.entry))
+		buf[off+0] = uint8(v.entry >> 8)
+
+		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
+		br.advance(uint8(v.entry))
+		buf[off+1] = uint8(v.entry >> 8)
+
+		// Refill
+		br.fillFast()
+
+		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
+		br.advance(uint8(v.entry))
+		buf[off+2] = uint8(v.entry >> 8)
+
+		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
+		br.advance(uint8(v.entry))
+		buf[off+3] = uint8(v.entry >> 8)
+
+		off += 4
+		if off == 0 {
+			if len(dst)+256 > maxDecodedSize {
+				br.close()
+				d.bufs.Put(bufs)
+				return nil, ErrMaxDecodedSizeExceeded
+			}
+			dst = append(dst, buf[:]...)
+		}
+	}
+
+	if len(dst)+int(off) > maxDecodedSize {
+		d.bufs.Put(bufs)
+		br.close()
+		return nil, ErrMaxDecodedSizeExceeded
+	}
+	dst = append(dst, buf[:off]...)
+
+	// br < 8, so uint8 is fine
+	bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
+	for bitsLeft > 0 {
+		br.fill()
+		if false && br.bitsRead >= 32 {
+			if br.off >= 4 {
+				v := br.in[br.off-4:]
+				v = v[:4]
+				low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
+				br.value = (br.value << 32) | uint64(low)
+				br.bitsRead -= 32
+				br.off -= 4
+			} else {
+				for br.off > 0 {
+					br.value = (br.value << 8) | uint64(br.in[br.off-1])
+					br.bitsRead -= 8
+					br.off--
+				}
+			}
+		}
+		if len(dst) >= maxDecodedSize {
+			d.bufs.Put(bufs)
+			br.close()
+			return nil, ErrMaxDecodedSizeExceeded
+		}
+		v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
+		nBits := uint8(v.entry)
+		br.advance(nBits)
+		bitsLeft -= nBits
+		dst = append(dst, uint8(v.entry>>8))
+	}
+	d.bufs.Put(bufs)
+	return dst, br.close()
+}