Skip to content

Commit

Permalink
wazevo(amd64): lower Fabs, Fneg, Fadd, Fsub, Fdiv, Fmul, Fsqrt (tetra…
Browse files Browse the repository at this point in the history
…telabs#1965)

Signed-off-by: Edoardo Vacchi <evacchi@users.noreply.github.com>
Co-authored-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
  • Loading branch information
evacchi and mathetake committed Jan 27, 2024
1 parent 43c2b2a commit bc5917a
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 2 deletions.
2 changes: 2 additions & 0 deletions internal/engine/wazevo/backend/isa/amd64/instr.go
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,7 @@ var defKinds = [instrMax]defKind{
imm: defKindOp2,
unaryRmR: defKindOp2,
xmmUnaryRmR: defKindOp2,
xmmRmR: defKindNone,
mov64MR: defKindOp2,
movsxRmR: defKindOp2,
movzxRmR: defKindOp2,
Expand Down Expand Up @@ -1641,6 +1642,7 @@ var useKinds = [instrMax]useKind{
imm: useKindNone,
unaryRmR: useKindOp1,
xmmUnaryRmR: useKindOp1,
xmmRmR: useKindOp1Op2Reg,
mov64MR: useKindOp1,
movzxRmR: useKindOp1,
movsxRmR: useKindOp1,
Expand Down
1 change: 1 addition & 0 deletions internal/engine/wazevo/backend/isa/amd64/instr_encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ func (i *instruction) encode(c backend.Compiler) (needsLabelResolution bool) {

case div:
panic("TODO")

case mulHi:
var prefix legacyPrefixes
rex := rexInfo(0)
Expand Down
118 changes: 118 additions & 0 deletions internal/engine/wazevo/backend/isa/amd64/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,14 @@ func (m *machine) LowerInstr(instr *ssa.Instruction) {
m.lowerCtz(instr)
case ssa.OpcodePopcnt:
m.lowerUnaryRmR(instr, unaryRmROpcodePopcnt)
case ssa.OpcodeFadd, ssa.OpcodeFsub, ssa.OpcodeFmul, ssa.OpcodeFdiv:
m.lowerXmmRmR(instr)
case ssa.OpcodeFabs:
m.lowerFabsFneg(instr)
case ssa.OpcodeFneg:
m.lowerFabsFneg(instr)
case ssa.OpcodeSqrt:
m.lowerSqrt(instr)
case ssa.OpcodeUndefined:
m.insert(m.allocateInstr().asUD2())
case ssa.OpcodeExitWithCode:
Expand Down Expand Up @@ -682,6 +690,116 @@ func (m *machine) lowerShiftR(si *ssa.Instruction, op shiftROp) {
m.copyTo(tmpDst, rd)
}

func (m *machine) lowerXmmRmR(instr *ssa.Instruction) {
x, y := instr.Arg2()
if !x.Type().IsFloat() {
panic("BUG?")
}
_64 := x.Type().Bits() == 64

var op sseOpcode
if _64 {
switch instr.Opcode() {
case ssa.OpcodeFadd:
op = sseOpcodeAddsd
case ssa.OpcodeFsub:
op = sseOpcodeSubsd
case ssa.OpcodeFmul:
op = sseOpcodeMulsd
case ssa.OpcodeFdiv:
op = sseOpcodeDivsd
default:
panic("BUG")
}
} else {
switch instr.Opcode() {
case ssa.OpcodeFadd:
op = sseOpcodeAddss
case ssa.OpcodeFsub:
op = sseOpcodeSubss
case ssa.OpcodeFmul:
op = sseOpcodeMulss
case ssa.OpcodeFdiv:
op = sseOpcodeDivss
default:
panic("BUG")
}
}

xDef, yDef := m.c.ValueDefinition(x), m.c.ValueDefinition(y)
rn := m.getOperand_Mem_Reg(yDef)
rm := m.getOperand_Reg(xDef)
rd := m.c.VRegOf(instr.Return())

// rm is being overwritten, so we first copy its value to a temp register,
// in case it is referenced again later.
tmp := m.copyToTmp(rm.r)

xmm := m.allocateInstr().asXmmRmR(op, rn, tmp)
m.insert(xmm)

m.copyTo(tmp, rd)
}

func (m *machine) lowerSqrt(instr *ssa.Instruction) {
x := instr.Arg()
if !x.Type().IsFloat() {
panic("BUG")
}
_64 := x.Type().Bits() == 64
var op sseOpcode
if _64 {
op = sseOpcodeSqrtsd
} else {
op = sseOpcodeSqrtss
}

xDef := m.c.ValueDefinition(x)
rm := m.getOperand_Mem_Reg(xDef)
rd := m.c.VRegOf(instr.Return())

xmm := m.allocateInstr().asXmmUnaryRmR(op, rm, rd)
m.insert(xmm)
}

func (m *machine) lowerFabsFneg(instr *ssa.Instruction) {
x := instr.Arg()
if !x.Type().IsFloat() {
panic("BUG")
}
_64 := x.Type().Bits() == 64
var op sseOpcode
var mask uint64
if _64 {
switch instr.Opcode() {
case ssa.OpcodeFabs:
mask, op = 0x7fffffffffffffff, sseOpcodeAndpd
case ssa.OpcodeFneg:
mask, op = 0x8000000000000000, sseOpcodeXorpd
}
} else {
switch instr.Opcode() {
case ssa.OpcodeFabs:
mask, op = 0x7fffffff, sseOpcodeAndps
case ssa.OpcodeFneg:
mask, op = 0x80000000, sseOpcodeXorps
}
}

tmp := m.c.AllocateVReg(x.Type())

xDef := m.c.ValueDefinition(x)
rm := m.getOperand_Reg(xDef)
rd := m.c.VRegOf(instr.Return())

m.lowerFconst(tmp, mask, _64)

xmm := m.allocateInstr().asXmmRmR(op, rm, tmp)
m.insert(xmm)

m.copyTo(tmp, rd)
}

func (m *machine) lowerStore(si *ssa.Instruction) {
value, ptr, offset, storeSizeInBits := si.StoreData()
rm := m.getOperand_Reg(m.c.ValueDefinition(value))
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/wazevo/backend/isa/amd64/machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ L2:
}
}

func Test_machine_lowerCtz(t *testing.T) {
func TestMachine_lowerCtz(t *testing.T) {
for _, tc := range []struct {
name string
setup func(*mockCompiler, ssa.Builder, *machine) *backend.SSAValueDefinition
Expand Down
29 changes: 29 additions & 0 deletions internal/engine/wazevo/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,35 @@ func TestE2E(t *testing.T) {
},
},
},
{
name: "float_arithm", m: testcases.FloatArithm.Module,
calls: []callCase{
{
params: []uint64{
math.Float64bits(25), math.Float64bits(5), uint64(math.Float32bits(25)), uint64(math.Float32bits(5)),
},
expResults: []uint64{
math.Float64bits(-25),
math.Float64bits(25),

math.Float64bits(5),
math.Float64bits(30),
math.Float64bits(20),
math.Float64bits(125),
math.Float64bits(5),

uint64(math.Float32bits(-25)),
uint64(math.Float32bits(25)),

uint64(math.Float32bits(5)),
uint64(math.Float32bits(30)),
uint64(math.Float32bits(20)),
uint64(math.Float32bits(125)),
uint64(math.Float32bits(5)),
},
},
},
},
{
name: "fibonacci_recursive", m: testcases.FibonacciRecursive.Module,
calls: []callCase{
Expand Down
5 changes: 5 additions & 0 deletions internal/engine/wazevo/ssa/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ func (t Type) IsInt() bool {
return t == TypeI32 || t == TypeI64
}

// IsFloat returns true if the type is a floating point type.
func (t Type) IsFloat() bool {
return t == TypeF32 || t == TypeF64
}

// Bits returns the number of bits required to represent the type.
func (t Type) Bits() byte {
switch t {
Expand Down
64 changes: 63 additions & 1 deletion internal/engine/wazevo/testcases/testcases.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ var (
}, nil),
}
ArithmReturn = TestCase{
Name: "add_sub_params_return",
Name: "arithm return",
Module: SingleFunctionModule(
wasm.FunctionType{
Params: []wasm.ValueType{i32, i32, i32, i64, i64, i64},
Expand Down Expand Up @@ -989,6 +989,68 @@ var (
wasm.OpcodeEnd,
}, []wasm.ValueType{}),
}
FloatArithm = TestCase{
Name: "float_arithm",
Module: SingleFunctionModule(wasm.FunctionType{
Params: []wasm.ValueType{f64, f64, f32, f32},
Results: []wasm.ValueType{f64, f64, f64, f64, f64, f64, f64, f32, f32, f32, f32, f32, f32, f32},
}, []byte{
wasm.OpcodeLocalGet, 0,
wasm.OpcodeF64Neg,

wasm.OpcodeLocalGet, 0,
wasm.OpcodeF64Neg,
wasm.OpcodeF64Abs,

wasm.OpcodeLocalGet, 0,
wasm.OpcodeF64Sqrt,

wasm.OpcodeLocalGet, 0,
wasm.OpcodeLocalGet, 1,
wasm.OpcodeF64Add,

wasm.OpcodeLocalGet, 0,
wasm.OpcodeLocalGet, 1,
wasm.OpcodeF64Sub,

wasm.OpcodeLocalGet, 0,
wasm.OpcodeLocalGet, 1,
wasm.OpcodeF64Mul,

wasm.OpcodeLocalGet, 0,
wasm.OpcodeLocalGet, 1,
wasm.OpcodeF64Div,

// 32-bit floats.
wasm.OpcodeLocalGet, 2,
wasm.OpcodeF32Neg,

wasm.OpcodeLocalGet, 2,
wasm.OpcodeF32Neg,
wasm.OpcodeF32Abs,

wasm.OpcodeLocalGet, 2,
wasm.OpcodeF32Sqrt,

wasm.OpcodeLocalGet, 2,
wasm.OpcodeLocalGet, 3,
wasm.OpcodeF32Add,

wasm.OpcodeLocalGet, 2,
wasm.OpcodeLocalGet, 3,
wasm.OpcodeF32Sub,

wasm.OpcodeLocalGet, 2,
wasm.OpcodeLocalGet, 3,
wasm.OpcodeF32Mul,

wasm.OpcodeLocalGet, 2,
wasm.OpcodeLocalGet, 3,
wasm.OpcodeF32Div,

wasm.OpcodeEnd,
}, []wasm.ValueType{}),
}
FloatConversions = TestCase{
Name: "float_conversions",
Module: SingleFunctionModule(wasm.FunctionType{
Expand Down

0 comments on commit bc5917a

Please sign in to comment.