Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

huff0: asm implementation of Decompress1X #596

Merged
merged 2 commits into from May 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
208 changes: 205 additions & 3 deletions huff0/_generate/gen.go
Expand Up @@ -10,6 +10,7 @@ import (
_ "github.com/klauspost/compress"

. "github.com/mmcloughlin/avo/build"
"github.com/mmcloughlin/avo/gotypes"
. "github.com/mmcloughlin/avo/operand"
"github.com/mmcloughlin/avo/reg"
)
Expand All @@ -19,9 +20,19 @@ func main() {

ConstraintExpr("amd64,!appengine,!noasm,gc")

decompress := decompress4x{}
decompress.generateProcedure("decompress4x_main_loop_amd64")
decompress.generateProcedure4x8bit("decompress4x_8b_main_loop_amd64")
{
decompress := decompress4x{}
decompress.generateProcedure("decompress4x_main_loop_amd64")
decompress.generateProcedure4x8bit("decompress4x_8b_main_loop_amd64")
}

{
decompress := decompress1x{}
decompress.generateProcedure("decompress1x_main_loop_amd64")

decompress.bmi2 = true
decompress.generateProcedure("decompress1x_main_loop_bmi2")
}

Generate()
}
Expand Down Expand Up @@ -308,3 +319,194 @@ func (d decompress4x) fillFast32(id, atLeast int, br, exhausted reg.GPVirtual) (
Label("skip_fill" + strconv.Itoa(id))
return
}

// bitReader is the code-generation-time view of a bit reader: it keeps the
// virtual registers that load fills from the in/off/value/bitsRead fields of
// the Go-side structure, together with generator-only bookkeeping.
type bitReader struct {
	// Registers assigned by load.
	in, off, value, bitsRead reg.GPVirtual

	id   int  // counter advanced by uniqId on every call
	bmi2 bool // when set, SHLXQ/SHRXQ are emitted instead of CL-based shifts
}

// uniqId bumps the reader's internal counter and returns the new value in
// decimal form. fillFast uses it to build a distinct label name per call.
func (b *bitReader) uniqId() string {
	next := b.id + 1
	b.id = next
	return strconv.Itoa(next)
}

// load allocates four fresh 64-bit virtual registers and emits the loads that
// fill them from the structure described by pointer: the base address of the
// "in" field, followed by the "off", "value" and "bitsRead" fields.
func (b *bitReader) load(pointer gotypes.Component) {
	// Allocation order (in, off, value, bitsRead) is kept stable.
	b.in, b.off = GP64(), GP64()
	b.value, b.bitsRead = GP64(), GP64()

	// "in" is loaded through Base(); the remaining fields are loaded directly.
	Load(pointer.Field("in").Base(), b.in)

	plain := []struct {
		field string
		dst   reg.GPVirtual
	}{
		{"off", b.off},
		{"value", b.value},
		{"bitsRead", b.bitsRead},
	}
	for _, p := range plain {
		Load(pointer.Field(p.field), p.dst)
	}
}

// store emits the stores that write the off, value and bitsRead registers
// back into the matching fields of the structure described by pointer.
// The "in" register is not written back.
func (b *bitReader) store(pointer gotypes.Component) {
	offField := pointer.Field("off")
	Store(b.off, offField)

	valueField := pointer.Field("value")
	Store(b.value, valueField)

	// The 8-bit view is passed explicitly: with the full-width register avo
	// reports "could not deduce mov instruction".
	bitsReadField := pointer.Field("bitsRead")
	Store(b.bitsRead.As8(), bitsReadField)
}

// fillFast emits a conditional 32-bit refill of the bit container.
//
// When bitsRead is below 32 (signed compare) the emitted code jumps straight
// to the end label. Otherwise it subtracts 32 from bitsRead and 4 from off,
// loads 32 bits from in+off, shifts them left by the updated bitsRead and ORs
// them into value. The non-BMI2 variant routes the shift count through RCX/CL.
func (b *bitReader) fillFast() {
	done := "bitReader_fillFast_" + b.uniqId() + "_end"

	// Nothing to do while fewer than 32 bits have been consumed.
	CMPQ(b.bitsRead, U8(32))
	JL(LabelRef(done))

	SUBQ(U8(32), b.bitsRead)
	SUBQ(U8(4), b.off)

	// Fetch the next four input bytes into the low half of a scratch register.
	fresh := GP64()
	src := Mem{Base: b.in, Index: b.off, Scale: 1}
	MOVL(src, fresh.As32())

	switch {
	case b.bmi2:
		SHLXQ(b.bitsRead, fresh, fresh)
	default:
		// A variable SHLQ takes its count from CL.
		MOVQ(b.bitsRead, reg.RCX)
		SHLQ(reg.CL, fresh)
	}

	ORQ(fresh, b.value)
	Label(done)
}

// peekTopBits emits code that computes value >> n into a newly allocated
// register and returns that register; value itself is left unmodified.
// Note that n is the shift count. The non-BMI2 variant overwrites RCX.
func (b *bitReader) peekTopBits(n reg.GPVirtual) reg.GPVirtual {
	top := GP64()

	if !b.bmi2 {
		// A variable SHRQ takes its count from CL, so copy value first and
		// shift the copy.
		MOVQ(n, reg.RCX)
		MOVQ(b.value, top)
		SHRQ(reg.CL, top)
		return top
	}

	SHRXQ(n, b.value, top)
	return top
}

// advance emits code that consumes n bits: bitsRead is increased by n and
// value is shifted left by n. The non-BMI2 variant moves n into RCX so that
// CL can serve as the shift count.
func (b *bitReader) advance(n reg.Register) {
	ADDQ(n, b.bitsRead)

	if !b.bmi2 {
		MOVQ(n, reg.RCX)
		SHLQ(reg.CL, b.value)
		return
	}

	SHLXQ(n, b.value, b.value)
}

// decompress1x generates the single-stream decoding procedure. Its bmi2 flag
// is copied into the bitReader used by generateProcedure, where it selects
// the BMI2 shift instructions.
type decompress1x struct{ bmi2 bool }

// generateProcedure emits, under the symbol name, the assembly main loop that
// decodes a single Huffman stream four symbols at a time.
//
// The generated function takes a pointer to a decompress1xContext. It reads
// the context fields "out", "outCap", "pbr", "tbl" and "peekBits", and writes
// "decoded": the number of bytes produced on success, or -1 when the output
// buffer has no room. On success the bit reader state behind "pbr" is updated
// as well; on the error path it is left untouched.
//
// When d.bmi2 is set, the BMI2 shift instructions are used by the bit reader.
func (d decompress1x) generateProcedure(name string) {
	Package("github.com/klauspost/compress/huff0")
	TEXT(name, 0, "func(ctx* decompress1xContext)")
	Doc(name+" is an x86 assembler implementation of Decompress1X", "")
	Pragma("noescape")

	br := bitReader{}
	br.bmi2 = d.bmi2

	buffer := GP64()
	bufferEnd := GP64() // the past-end address of buffer
	dt := GP64()
	peekBits := GP64()

	{
		ctx := Dereference(Param("ctx"))
		Load(ctx.Field("out"), buffer)

		// Refuse outputs with capacity below one 4-byte group.
		outCap := GP64()
		Load(ctx.Field("outCap"), outCap)
		CMPQ(outCap, U8(4))
		JB(LabelRef("error_max_decoded_size_exeeded"))

		LEAQ(Mem{Base: buffer, Index: outCap, Scale: 1}, bufferEnd)

		// load bitReader struct
		pbr := Dereference(ctx.Field("pbr"))
		br.load(pbr)

		Load(ctx.Field("tbl"), dt)
		Load(ctx.Field("peekBits"), peekBits)
	}

	JMP(LabelRef("loop_condition"))

	Label("main_loop")

	out := reg.AX // Fixed, as we need an 8H part

	Comment("Check if we have room for 4 bytes in the output buffer")
	{
		tmp := GP64()
		LEAQ(Mem{Base: buffer, Disp: 4}, tmp)
		CMPQ(tmp, bufferEnd)
		JGE(LabelRef("error_max_decoded_size_exeeded"))
	}

	decompress := func(id int, out reg.Register) {
		d.decompress(id, &br, peekBits, dt, out)
	}

	// Symbols 0 and 1 land in AL/AH; BSWAPL moves them to the top two bytes.
	// Symbols 2 and 3 then go into AH/AL, and the second BSWAPL leaves all
	// four bytes in memory order (symbol 0 in the lowest byte).
	Comment("Decode 4 values")
	br.fillFast()
	decompress(0, out.As8L())
	decompress(1, out.As8H())
	BSWAPL(out.As32())

	br.fillFast()
	decompress(2, out.As8H())
	decompress(3, out.As8L())
	BSWAPL(out.As32())

	Comment("Store the decoded values")
	MOVL(out.As32(), Mem{Base: buffer})
	ADDQ(U8(4), buffer)

	// Keep looping while off is at least 8.
	Label("loop_condition")
	CMPQ(br.off, U8(8))
	JGE(LabelRef("main_loop"))

	Comment("Update ctx structure")
	{
		// calculate decoded as current `out` - initial `out`
		ctx := Dereference(Param("ctx"))
		decoded := GP64()
		tmp := GP64()
		MOVQ(buffer, decoded)
		Load(ctx.Field("out"), tmp)
		SUBQ(tmp, decoded)
		Store(decoded, ctx.Field("decoded"))

		pbr := Dereference(ctx.Field("pbr"))
		br.store(pbr)
	}

	RET()

	Comment("Report error")
	Label("error_max_decoded_size_exeeded")
	{
		ctx := Dereference(Param("ctx"))
		tmp := GP64()
		MOVQ(I64(-1), tmp)
		Store(tmp, ctx.Field("decoded"))
	}

	// The error path needs its own return: a TEXT block has to end in a jump
	// such as RET, and nothing is emitted after this point.
	RET()
}

// decompress emits the code that decodes one symbol: it looks up the 16-bit
// table entry at dt[value >> peekBits], copies the entry's high byte into out
// and advances br by the entry's low byte.
//
// RCX is used as a fixed scratch register and is overwritten. The id
// parameter is not used by the emitted code.
func (d decompress1x) decompress(id int, br *bitReader, peekBits, dt reg.GPVirtual, out reg.Register) {
	// Fixed register, because its 8H part is required below.
	entry := reg.RCX

	// entry = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
	index := br.peekTopBits(peekBits)
	slot := Mem{Base: dt, Index: index, Scale: 2}
	MOVW(slot, entry.As16())

	// buf[id] = uint8(entry >> 8)
	MOVB(entry.As8H(), out)

	// br.advance(uint8(entry)): zero-extend the low byte and use it as the
	// number of bits to consume.
	MOVBQZX(entry.As8L(), entry)
	br.advance(entry)
}
102 changes: 0 additions & 102 deletions huff0/decompress.go
Expand Up @@ -236,108 +236,6 @@ func (d *Decoder) buffer() *[4][256]byte {
return &[4][256]byte{}
}

// Decompress1X will decompress a 1X encoded stream.
// The cap of the output buffer will be the maximum decompressed size.
// The length of the supplied input must match the end of a block exactly.
//
// An error is returned when no decoding table is loaded, when the bit reader
// cannot be initialised from src, or when it reports a problem on close.
// ErrMaxDecodedSizeExceeded is returned when the output would grow past
// cap(dst). Streams whose table log is at most 8 are handed to
// decompress1X8Bit when use8BitTables is enabled.
func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
	if len(d.dt.single) == 0 {
		return nil, errors.New("no table loaded")
	}
	if use8BitTables && d.actualTableLog <= 8 {
		return d.decompress1X8Bit(dst, src)
	}
	var br bitReaderShifted
	err := br.init(src)
	if err != nil {
		return dst, err
	}
	maxDecodedSize := cap(dst)
	dst = dst[:0]

	// Avoid bounds check by always having full sized table.
	const tlSize = 1 << tableLogMax
	const tlMask = tlSize - 1
	dt := d.dt.single[:tlSize]

	// Use temp table to avoid bound checks/append penalty.
	bufs := d.buffer()
	buf := &bufs[0]
	var off uint8

	// Fast path: decode four symbols per iteration while at least 8 input
	// bytes remain, refilling the bit container after every second symbol.
	for br.off >= 8 {
		br.fillFast()
		v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
		br.advance(uint8(v.entry))
		buf[off+0] = uint8(v.entry >> 8)

		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
		br.advance(uint8(v.entry))
		buf[off+1] = uint8(v.entry >> 8)

		// Refill
		br.fillFast()

		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
		br.advance(uint8(v.entry))
		buf[off+2] = uint8(v.entry >> 8)

		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
		br.advance(uint8(v.entry))
		buf[off+3] = uint8(v.entry >> 8)

		// off is a uint8, so it wraps to 0 once the 256-byte temp buffer is
		// full; that is the moment to flush it to dst.
		off += 4
		if off == 0 {
			if len(dst)+256 > maxDecodedSize {
				br.close()
				d.bufs.Put(bufs)
				return nil, ErrMaxDecodedSizeExceeded
			}
			dst = append(dst, buf[:]...)
		}
	}

	// Flush whatever is left in the temp buffer.
	if len(dst)+int(off) > maxDecodedSize {
		d.bufs.Put(bufs)
		br.close()
		return nil, ErrMaxDecodedSizeExceeded
	}
	dst = append(dst, buf[:off]...)

	// Tail: decode one symbol at a time until every remaining bit is used.
	// br < 8, so uint8 is fine
	bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
	for bitsLeft > 0 {
		br.fill()
		if len(dst) >= maxDecodedSize {
			d.bufs.Put(bufs)
			br.close()
			return nil, ErrMaxDecodedSizeExceeded
		}
		v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
		nBits := uint8(v.entry)
		br.advance(nBits)
		bitsLeft -= nBits
		dst = append(dst, uint8(v.entry>>8))
	}
	d.bufs.Put(bufs)
	return dst, br.close()
}

// decompress1X8Bit will decompress a 1X encoded stream with tablelog <= 8.
// The cap of the output buffer will be the maximum decompressed size.
// The length of the supplied input must match the end of a block exactly.
Expand Down