Skip to content

Commit

Permalink
[skip ci] WiP: avo implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
WojciechMula committed May 15, 2022
1 parent bf3e875 commit 5f366e2
Show file tree
Hide file tree
Showing 3 changed files with 386 additions and 14 deletions.
179 changes: 178 additions & 1 deletion huff0/_generate/gen.go
Expand Up @@ -10,6 +10,7 @@ import (
_ "github.com/klauspost/compress"

. "github.com/mmcloughlin/avo/build"
"github.com/mmcloughlin/avo/gotypes"
. "github.com/mmcloughlin/avo/operand"
"github.com/mmcloughlin/avo/reg"
)
Expand All @@ -19,7 +20,7 @@ func main() {

ConstraintExpr("amd64,!appengine,!noasm,gc")

{
if true {
decompress := decompress4x{}
decompress.generateProcedure("decompress4x_main_loop_amd64")
decompress.generateProcedure4x8bit("decompress4x_8b_main_loop_amd64")
Expand All @@ -28,6 +29,9 @@ func main() {
{
decompress := decompress1x{}
decompress.generateProcedure("decompress1x_main_loop_amd64")

decompress.bmi2 = true
decompress.generateProcedure("decompress1x_main_loop_bmi2")
}

Generate()
Expand Down Expand Up @@ -316,7 +320,78 @@ func (d decompress4x) fillFast32(id, atLeast int, br, exhausted reg.GPVirtual) (
return
}

type bitReader struct {
bmi2 bool
in reg.GPVirtual
off reg.GPVirtual
value reg.GPVirtual
bitsRead reg.GPVirtual

id int
}

func (b *bitReader) uniqId() string {
b.id += 1
return strconv.Itoa(b.id)
}

func (b *bitReader) load(pointer gotypes.Component) {
b.in = GP64()
b.off = GP64()
b.value = GP64()
b.bitsRead = GP64()

Load(pointer.Field("in").Base(), b.in)
Load(pointer.Field("off"), b.off)
Load(pointer.Field("value"), b.value)
Load(pointer.Field("bitsRead"), b.bitsRead)
}

func (b *bitReader) store(pointer gotypes.Component) {
Store(b.off, pointer.Field("off"))
Store(b.value, pointer.Field("value"))
// Note: explicit As8(), without this avo reports: "could not deduce mov instruction"
Store(b.bitsRead.As8(), pointer.Field("bitsRead"))
}

func (b *bitReader) fillFast() {
label := "bitReader_fillFast_" + b.uniqId() + "_end"
CMPQ(b.bitsRead, U8(32))
JL(LabelRef(label))

tmp := GP64()
MOVL(Mem{Base: b.in, Index: b.off, Scale: 1}, tmp.As32())
ORQ(tmp, b.value)
SUBQ(U8(32), b.bitsRead)
SUBQ(U8(4), b.off)
Label(label)
}

func (b *bitReader) peekTopBits(n reg.GPVirtual) reg.GPVirtual {
res := GP64()
if b.bmi2 {
SHRXQ(n, b.value, res)
} else {
MOVQ(n, reg.RCX)
MOVQ(b.value, res)
SHRQ(reg.CL, res)
}

return res
}

func (b *bitReader) advance(n reg.Register) {
ADDQ(n, b.bitsRead)
if b.bmi2 {
SHLXQ(n, b.value, b.value)
} else {
MOVQ(n, reg.RCX)
SHLQ(reg.CL, b.value)
}
}

type decompress1x struct {
bmi2 bool
}

func (d decompress1x) generateProcedure(name string) {
Expand All @@ -325,5 +400,107 @@ func (d decompress1x) generateProcedure(name string) {
Doc(name+" is an x86 assembler implementation of Decompress1X", "")
Pragma("noescape")

br := bitReader{}
br.bmi2 = d.bmi2

buffer := GP64()
bufferEnd := GP64() // the past-end address of buffer
dt := GP64()
peekBits := GP64()

{
ctx := Dereference(Param("ctx"))
Load(ctx.Field("out"), buffer)

outCap := GP64()
Load(ctx.Field("outCap"), outCap)
CMPQ(outCap, U8(4))
JB(LabelRef("error_max_decoded_size_exeeded"))

LEAQ(Mem{Base: buffer, Index: outCap, Scale: 1}, bufferEnd)

// load bitReader struct
pbr := Dereference(ctx.Field("pbr"))
br.load(pbr)

Load(ctx.Field("tbl"), dt)
Load(ctx.Field("peekBits"), peekBits)
}

JMP(LabelRef("loop_condition"))

Label("main_loop")

out := reg.AX // Fixed, as we need an 8H part

Comment("Check if we have room for 4 bytes in the output buffer")
{
tmp := GP64()
LEAQ(Mem{Base: buffer, Disp: 4}, tmp)
CMPQ(tmp, bufferEnd)
JGE(LabelRef("error_max_decoded_size_exeeded"))
}

decompress := func(id int, out reg.Register) {
d.decompress(id, &br, peekBits, dt, out)
}

Comment("Decode 4 values")
br.fillFast()
decompress(0, out.As8L())
decompress(1, out.As8H())
BSWAPL(out.As32())

br.fillFast()
decompress(2, out.As8H())
decompress(3, out.As8L())
BSWAPL(out.As32())

Comment("Store the decoded values")
MOVL(out.As32(), Mem{Base: buffer})
ADDQ(U8(4), buffer)

Label("loop_condition")
CMPQ(br.off, U8(8))
JGE(LabelRef("main_loop"))

Comment("Update ctx structure")
{
// calculate decoded as current `out` - initial `out`
ctx := Dereference(Param("ctx"))
decoded := GP64()
tmp := GP64()
MOVQ(buffer, decoded)
Load(ctx.Field("out"), tmp)
SUBQ(tmp, decoded)
Store(decoded, ctx.Field("decoded"))

pbr := Dereference(ctx.Field("pbr"))
br.store(pbr)
}

RET()

Comment("Report error")
Label("error_max_decoded_size_exeeded")
{
ctx := Dereference(Param("ctx"))
tmp := GP64()
MOVQ(I64(-1), tmp)
Store(tmp, ctx.Field("decoded"))
}
}

func (d decompress1x) decompress(id int, br *bitReader, peekBits, dt reg.GPVirtual, out reg.Register) {
// v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
k := br.peekTopBits(peekBits)
v := reg.RCX // Fixed, as we need 8H part
MOVQ(Mem{Base: dt, Index: k, Scale: 1}, v)

// buf[id] = uint8(v.entry >> 8)
MOVB(v.As8H(), out)

// br.advance(uint8(v.entry))
MOVBQZX(v.As8L(), v)
br.advance(v)
}
30 changes: 17 additions & 13 deletions huff0/decompress_amd64.go
Expand Up @@ -159,9 +159,10 @@ type decompress1xContext struct {
outCap int
tbl *dEntrySingle
decoded int
limit *byte
}

const error_max_decoded_size_exeeded = -1

// Decompress1X will decompress a 1X encoded stream.
// The cap of the output buffer will be the maximum decompressed size.
// The length of the supplied input must match the end of a block exactly.
Expand All @@ -178,24 +179,27 @@ func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
return dst, err
}
maxDecodedSize := cap(dst)
dst = dst[:maxDecodedSize]

const tlSize = 1 << tableLogMax
const tlMask = tlSize - 1

ctx := decompress1xContext{
pbr: &br,
out: &dst[0],
outCap: maxDecodedSize,
tbl: &d.dt.single[0],
/* XXX: complete */
}
if maxDecodedSize >= 4 {
ctx := decompress1xContext{
pbr: &br,
out: &dst[0],
outCap: maxDecodedSize,
peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
tbl: &d.dt.single[0],
}

decompress1x_main_loop_amd64(&ctx)
if ctx.decoded == -1 {
return nil, ErrMaxDecodedSizeExceeded
}
decompress1x_main_loop_amd64(&ctx)
if ctx.decoded == error_max_decoded_size_exeeded {
return nil, ErrMaxDecodedSizeExceeded
}

dst = dst[:ctx.decoded]
dst = dst[:ctx.decoded]
}

// br < 8, so uint8 is fine
bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
Expand Down

0 comments on commit 5f366e2

Please sign in to comment.