Skip to content

Commit

Permalink
huff0: add x86 asm implementation of Decompress1X
Browse files Browse the repository at this point in the history
  • Loading branch information
WojciechMula committed May 19, 2022
1 parent 3909335 commit b58e2de
Show file tree
Hide file tree
Showing 5 changed files with 592 additions and 106 deletions.
208 changes: 205 additions & 3 deletions huff0/_generate/gen.go
Expand Up @@ -10,6 +10,7 @@ import (
_ "github.com/klauspost/compress"

. "github.com/mmcloughlin/avo/build"
"github.com/mmcloughlin/avo/gotypes"
. "github.com/mmcloughlin/avo/operand"
"github.com/mmcloughlin/avo/reg"
)
Expand All @@ -19,9 +20,19 @@ func main() {

ConstraintExpr("amd64,!appengine,!noasm,gc")

decompress := decompress4x{}
decompress.generateProcedure("decompress4x_main_loop_amd64")
decompress.generateProcedure4x8bit("decompress4x_8b_main_loop_amd64")
{
decompress := decompress4x{}
decompress.generateProcedure("decompress4x_main_loop_amd64")
decompress.generateProcedure4x8bit("decompress4x_8b_main_loop_amd64")
}

{
decompress := decompress1x{}
decompress.generateProcedure("decompress1x_main_loop_amd64")

decompress.bmi2 = true
decompress.generateProcedure("decompress1x_main_loop_bmi2")
}

Generate()
}
Expand Down Expand Up @@ -308,3 +319,194 @@ func (d decompress4x) fillFast32(id, atLeast int, br, exhausted reg.GPVirtual) (
Label("skip_fill" + strconv.Itoa(id))
return
}

type bitReader struct {
in reg.GPVirtual
off reg.GPVirtual
value reg.GPVirtual
bitsRead reg.GPVirtual
id int
bmi2 bool
}

func (b *bitReader) uniqId() string {
b.id += 1
return strconv.Itoa(b.id)
}

func (b *bitReader) load(pointer gotypes.Component) {
b.in = GP64()
b.off = GP64()
b.value = GP64()
b.bitsRead = GP64()

Load(pointer.Field("in").Base(), b.in)
Load(pointer.Field("off"), b.off)
Load(pointer.Field("value"), b.value)
Load(pointer.Field("bitsRead"), b.bitsRead)
}

func (b *bitReader) store(pointer gotypes.Component) {
Store(b.off, pointer.Field("off"))
Store(b.value, pointer.Field("value"))
// Note: explicit As8(), without this avo reports: "could not deduce mov instruction"
Store(b.bitsRead.As8(), pointer.Field("bitsRead"))
}

func (b *bitReader) fillFast() {
label := "bitReader_fillFast_" + b.uniqId() + "_end"
CMPQ(b.bitsRead, U8(32))
JL(LabelRef(label))

SUBQ(U8(32), b.bitsRead)
SUBQ(U8(4), b.off)

tmp := GP64()
MOVL(Mem{Base: b.in, Index: b.off, Scale: 1}, tmp.As32())
if b.bmi2 {
SHLXQ(b.bitsRead, tmp, tmp)
} else {
MOVQ(b.bitsRead, reg.RCX)
SHLQ(reg.CL, tmp)
}
ORQ(tmp, b.value)
Label(label)
}

func (b *bitReader) peekTopBits(n reg.GPVirtual) reg.GPVirtual {
res := GP64()
if b.bmi2 {
SHRXQ(n, b.value, res)
} else {
MOVQ(n, reg.RCX)
MOVQ(b.value, res)
SHRQ(reg.CL, res)
}

return res
}

func (b *bitReader) advance(n reg.Register) {
ADDQ(n, b.bitsRead)
if b.bmi2 {
SHLXQ(n, b.value, b.value)
} else {
MOVQ(n, reg.RCX)
SHLQ(reg.CL, b.value)
}
}

type decompress1x struct {
bmi2 bool
}

func (d decompress1x) generateProcedure(name string) {
Package("github.com/klauspost/compress/huff0")
TEXT(name, 0, "func(ctx* decompress1xContext)")
Doc(name+" is an x86 assembler implementation of Decompress1X", "")
Pragma("noescape")

br := bitReader{}
br.bmi2 = d.bmi2

buffer := GP64()
bufferEnd := GP64() // the past-end address of buffer
dt := GP64()
peekBits := GP64()

{
ctx := Dereference(Param("ctx"))
Load(ctx.Field("out"), buffer)

outCap := GP64()
Load(ctx.Field("outCap"), outCap)
CMPQ(outCap, U8(4))
JB(LabelRef("error_max_decoded_size_exeeded"))

LEAQ(Mem{Base: buffer, Index: outCap, Scale: 1}, bufferEnd)

// load bitReader struct
pbr := Dereference(ctx.Field("pbr"))
br.load(pbr)

Load(ctx.Field("tbl"), dt)
Load(ctx.Field("peekBits"), peekBits)
}

JMP(LabelRef("loop_condition"))

Label("main_loop")

out := reg.AX // Fixed, as we need an 8H part

Comment("Check if we have room for 4 bytes in the output buffer")
{
tmp := GP64()
LEAQ(Mem{Base: buffer, Disp: 4}, tmp)
CMPQ(tmp, bufferEnd)
JGE(LabelRef("error_max_decoded_size_exeeded"))
}

decompress := func(id int, out reg.Register) {
d.decompress(id, &br, peekBits, dt, out)
}

Comment("Decode 4 values")
br.fillFast()
decompress(0, out.As8L())
decompress(1, out.As8H())
BSWAPL(out.As32())

br.fillFast()
decompress(2, out.As8H())
decompress(3, out.As8L())
BSWAPL(out.As32())

Comment("Store the decoded values")
MOVL(out.As32(), Mem{Base: buffer})
ADDQ(U8(4), buffer)

Label("loop_condition")
CMPQ(br.off, U8(8))
JGE(LabelRef("main_loop"))

Comment("Update ctx structure")
{
// calculate decoded as current `out` - initial `out`
ctx := Dereference(Param("ctx"))
decoded := GP64()
tmp := GP64()
MOVQ(buffer, decoded)
Load(ctx.Field("out"), tmp)
SUBQ(tmp, decoded)
Store(decoded, ctx.Field("decoded"))

pbr := Dereference(ctx.Field("pbr"))
br.store(pbr)
}

RET()

Comment("Report error")
Label("error_max_decoded_size_exeeded")
{
ctx := Dereference(Param("ctx"))
tmp := GP64()
MOVQ(I64(-1), tmp)
Store(tmp, ctx.Field("decoded"))
}
}

func (d decompress1x) decompress(id int, br *bitReader, peekBits, dt reg.GPVirtual, out reg.Register) {
// v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
k := br.peekTopBits(peekBits)
v := reg.RCX // Fixed, as we need 8H part
MOVW(Mem{Base: dt, Index: k, Scale: 2}, v.As16())

// buf[id] = uint8(v.entry >> 8)
MOVB(v.As8H(), out)

// br.advance(uint8(v.entry))
MOVBQZX(v.As8L(), v)
br.advance(v)
}
102 changes: 0 additions & 102 deletions huff0/decompress.go
Expand Up @@ -236,108 +236,6 @@ func (d *Decoder) buffer() *[4][256]byte {
return &[4][256]byte{}
}

// Decompress1X will decompress a 1X encoded stream.
// The cap of the output buffer will be the maximum decompressed size.
// The length of the supplied input must match the end of a block exactly.
func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
if len(d.dt.single) == 0 {
return nil, errors.New("no table loaded")
}
if use8BitTables && d.actualTableLog <= 8 {
return d.decompress1X8Bit(dst, src)
}
var br bitReaderShifted
err := br.init(src)
if err != nil {
return dst, err
}
maxDecodedSize := cap(dst)
dst = dst[:0]

// Avoid bounds check by always having full sized table.
const tlSize = 1 << tableLogMax
const tlMask = tlSize - 1
dt := d.dt.single[:tlSize]

// Use temp table to avoid bound checks/append penalty.
bufs := d.buffer()
buf := &bufs[0]
var off uint8

for br.off >= 8 {
br.fillFast()
v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
br.advance(uint8(v.entry))
buf[off+0] = uint8(v.entry >> 8)

v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
br.advance(uint8(v.entry))
buf[off+1] = uint8(v.entry >> 8)

// Refill
br.fillFast()

v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
br.advance(uint8(v.entry))
buf[off+2] = uint8(v.entry >> 8)

v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
br.advance(uint8(v.entry))
buf[off+3] = uint8(v.entry >> 8)

off += 4
if off == 0 {
if len(dst)+256 > maxDecodedSize {
br.close()
d.bufs.Put(bufs)
return nil, ErrMaxDecodedSizeExceeded
}
dst = append(dst, buf[:]...)
}
}

if len(dst)+int(off) > maxDecodedSize {
d.bufs.Put(bufs)
br.close()
return nil, ErrMaxDecodedSizeExceeded
}
dst = append(dst, buf[:off]...)

// br < 8, so uint8 is fine
bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
for bitsLeft > 0 {
br.fill()
if false && br.bitsRead >= 32 {
if br.off >= 4 {
v := br.in[br.off-4:]
v = v[:4]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
br.value = (br.value << 32) | uint64(low)
br.bitsRead -= 32
br.off -= 4
} else {
for br.off > 0 {
br.value = (br.value << 8) | uint64(br.in[br.off-1])
br.bitsRead -= 8
br.off--
}
}
}
if len(dst) >= maxDecodedSize {
d.bufs.Put(bufs)
br.close()
return nil, ErrMaxDecodedSizeExceeded
}
v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
nBits := uint8(v.entry)
br.advance(nBits)
bitsLeft -= nBits
dst = append(dst, uint8(v.entry>>8))
}
d.bufs.Put(bufs)
return dst, br.close()
}

// decompress1X8Bit will decompress a 1X encoded stream with tablelog <= 8.
// The cap of the output buffer will be the maximum decompressed size.
// The length of the supplied input must match the end of a block exactly.
Expand Down

0 comments on commit b58e2de

Please sign in to comment.