Skip to content

Commit

Permalink
PUA
Browse files Browse the repository at this point in the history
So this turns out to be less easy than I thought a few hours ago

Flatbuffers is pretty bad: the Lua side hasn't been touched for 3 years, yet I found a silly typo just by flipping through the code:
google/flatbuffers#8251

In particular, they seem to assume string.pack/string.unpack will exist, but in Lua 5.1 they don't.
With LuaJIT they use FFI instead, which works out for most 5.1 users.

Solution: we're going to shim our own Go implementation of the flatbuffers Lua library.
  • Loading branch information
serprex committed Mar 9, 2024
1 parent 2dacc7a commit 0672608
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 65 deletions.
151 changes: 86 additions & 65 deletions flow/connectors/kafka/kafka.go
Expand Up @@ -2,8 +2,10 @@ package connkafka

import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"
"sync"

"github.com/twmb/franz-go/pkg/kgo"
Expand All @@ -14,6 +16,8 @@ import (
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/logger"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/PeerDB-io/peer-flow/pua"
)

type KafkaConnector struct {
Expand Down Expand Up @@ -83,21 +87,54 @@ func (c *KafkaConnector) SyncFlowCleanup(ctx context.Context, jobName string) er
return c.pgMetadata.DropMetadata(ctx, jobName)
}

type LuaUserDataType[T any] struct{ Name string }
// Userdata type descriptors for the values exposed to the Lua onRow script:
// LuaRecord boxes a whole change record, LuaQValue boxes a single column value.
var (
	LuaRecord = pua.LuaUserDataType[model.Record]{Name: "peerdb_row"}
	LuaQValue = pua.LuaUserDataType[qvalue.QValue]{Name: "peerdb_qvalue"}
)

// RegisterTypes installs the metatables backing the peerdb_row and
// peerdb_qvalue userdata types into the given Lua state.
func RegisterTypes(ls *lua.LState) {
	rowMT := ls.NewTypeMetatable(LuaRecord.Name)
	ls.SetField(rowMT, "__index", ls.NewFunction(LuaRecordIndex))

	qvMT := ls.NewTypeMetatable(LuaQValue.Name)
	ls.SetField(qvMT, "__index", ls.NewFunction(LuaQValueIndex))
	ls.SetField(qvMT, "__len", ls.NewFunction(LuaQValueLen))
}

var LuaRecord = LuaUserDataType[model.Record]{Name: "peerdb_row"}
func LuaRecordIndex(ls *lua.LState) int {
record, key := LuaRecord.StartIndex(ls)

func RegisterTypes(L *lua.LState) {
mt := L.NewTypeMetatable(LuaRecord.Name)
L.SetField(mt, "__index", LuaRecordGet)
recordItems := record.GetItems()
qv, err := recordItems.GetValueByColName(key)
if err != nil {
ls.Error(lua.LString(err.Error()), 0)
}

ls.Push(LuaQValue.New(ls, qv))
return 1
}

func (udt *LuaUserDataType[T]) Construct(ls *lua.LState, val T) *lua.LUserData {
return &lua.LUserData{
Value: val,
Env: ls.Env,
Metatable: ls.GetTypeMetatable(udt.Name),
// LuaQValueIndex implements __index for peerdb_qvalue userdata.
// Only the "kind" field is exposed; any other key yields nil.
func LuaQValueIndex(ls *lua.LState) int {
	qv, key := LuaQValue.StartIndex(ls)
	if key != "kind" {
		return 0
	}
	ls.Push(lua.LString(qv.Kind))
	return 1
}

// LuaQValueLen implements __len for peerdb_qvalue userdata: byte length for
// string values, element count for array_* kinds, nil for everything else.
func LuaQValueLen(ls *lua.LState) int {
	qv := LuaQValue.StartMeta(ls)
	if str, isStr := qv.Value.(string); isStr {
		// len on a Go string counts bytes, matching Lua's # on strings.
		ls.Push(lua.LNumber(len(str)))
		return 1
	}
	if !strings.HasPrefix(string(qv.Kind), "array_") {
		return 0
	}
	ls.Push(lua.LNumber(reflect.ValueOf(qv.Value).Len()))
	return 1
}

func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
Expand All @@ -119,74 +156,58 @@ func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecords
tableNameRowsMapping := make(map[string]uint32)

var fn *lua.LFunction
var state *lua.LState
var ls *lua.LState
if req.Script != "" {
state = lua.NewState(lua.Options{SkipOpenLibs: true})
defer state.Close()
state.SetContext(wgCtx)
state.DoString(req.Script)
ls = lua.NewState(lua.Options{SkipOpenLibs: true})
defer ls.Close()
ls.SetContext(wgCtx)
for _, pair := range []struct {
n string
f lua.LGFunction
}{
{lua.LoadLibName, lua.OpenPackage}, // Must be first
{lua.BaseLibName, lua.OpenBase},
{lua.TabLibName, lua.OpenTable},
{lua.StringLibName, lua.OpenString},
{lua.MathLibName, lua.OpenMath},
} {
ls.Push(ls.NewFunction(pair.f))
ls.Push(lua.LString(pair.n))
err := ls.PCall(1, 0, nil)
if err != nil {
return nil, fmt.Errorf("failed to initialize Lua runtime: %w", err)
}
}
ls.PreloadModule("flatbuffers", pua.FlatBuffers_Loader)
ls.PreloadModule("flatbuffers.binaryarray", pua.FlatBuffers_BinaryArray_Loader)
err := ls.DoString(req.Script)
if err != nil {
return nil, fmt.Errorf("error while executing script: %w", err)
}

var ok bool
fn, ok = state.GetGlobal("onRow").(*lua.LFunction)
fn, ok = ls.GetGlobal("onRow").(*lua.LFunction)
if !ok {
return nil, fmt.Errorf("Script should define `onRow` function")
return nil, errors.New("script should define `onRow` function")
}
} else {
return nil, fmt.Errorf("Kafka mirror must have script")
return nil, errors.New("kafka mirror must have script")
}

for record := range req.Records.GetRecords() {
if err := wgCtx.Err(); err != nil {
return nil, err
}
topic := record.GetDestinationTableName()
switch typedRecord := record.(type) {
case *model.InsertRecord:
insertData := KafkaRecord{
Old: nil,
New: typedRecord.Items.Values,
}

state.Push(KafkaRecord)
state.Call()
lfn(insertData)
insertJSON, err := json.Marshal(insertData)
if err != nil {
return nil, fmt.Errorf("failed to serialize insert data to JSON: %w", err)
}

wg.Add(1)
c.client.Produce(wgCtx, &kgo.Record{Topic: topic, Value: insertJSON}, produceCb)
case *model.UpdateRecord:
updateData := KafkaRecord{
Old: typedRecord.OldItems.Values,
New: typedRecord.NewItems.Values,
}
updateJSON, err := json.Marshal(updateData)
if err != nil {
return nil, fmt.Errorf("failed to serialize update data to JSON: %w", err)
}

wg.Add(1)
c.client.Produce(wgCtx, &kgo.Record{Topic: topic, Value: updateJSON}, produceCb)
case *model.DeleteRecord:
deleteData := KafkaRecord{
Old: typedRecord.Items.Values,
New: nil,
}
deleteJSON, err := json.Marshal(deleteData)
if err != nil {
return nil, fmt.Errorf("failed to serialize delete data to JSON: %w", err)
}

wg.Add(1)
c.client.Produce(wgCtx, &kgo.Record{Topic: topic, Value: deleteJSON}, produceCb)
default:
// TODO ignore
unknownErr := fmt.Errorf("record type %T not supported in Kafka flow connector", typedRecord)
wgErr(unknownErr)
return nil, unknownErr
ls.Push(fn)
ls.Push(LuaRecord.New(ls, record))
err := ls.PCall(1, 1, nil)
if err != nil {
return nil, fmt.Errorf("script failed: %w", err)
}
value := ls.CheckString(-1)
wg.Add(1)
c.client.Produce(wgCtx, &kgo.Record{Topic: topic, Value: []byte(value)}, produceCb)

numRecords += 1
tableNameRowsMapping[topic] += 1
Expand Down
42 changes: 42 additions & 0 deletions flow/pua/flatbuffers.go
@@ -0,0 +1,42 @@
package pua

import (
"github.com/yuin/gopher-lua"
)

/*
local m = {}
m.Builder = require("flatbuffers.builder").New
m.N = require("flatbuffers.numTypes")
m.view = require("flatbuffers.view")
m.binaryArray = require("flatbuffers.binaryarray")
return m
*/

// requireHelper evaluates require(path) in the Lua state and stores the
// resulting module on m under the given field name.
func requireHelper(ls *lua.LState, m *lua.LTable, require lua.LValue, name string, path string) {
	ls.Push(require)
	ls.Push(lua.LString(path))
	ls.Call(1, 1)
	mod := ls.Get(-1)
	ls.Pop(1)
	ls.SetField(m, name, mod)
}

// FlatBuffers_Loader is the Lua module loader for "flatbuffers". It builds the
// module table by requiring the submodules, mirroring the upstream Lua source
// quoted in the comment above.
func FlatBuffers_Loader(ls *lua.LState) int {
	m := ls.NewTable()
	require := ls.GetGlobal("require")

	// m.Builder = require("flatbuffers.builder").New
	ls.Push(require)
	ls.Push(lua.LString("flatbuffers.builder"))
	ls.Call(1, 1)
	builder := ls.GetTable(ls.Get(-1), lua.LString("New"))
	// The FlatBuffers Lua API exposes this as `Builder` (capitalized), as the
	// module spec comment documents; registering it lowercase would break
	// scripts calling flatbuffers.Builder(...).
	ls.SetField(m, "Builder", builder)
	ls.Pop(1)

	requireHelper(ls, m, require, "N", "flatbuffers.numTypes")
	requireHelper(ls, m, require, "view", "flatbuffers.view")
	requireHelper(ls, m, require, "binaryArray", "flatbuffers.binaryarray")

	ls.Push(m)
	return 1
}
134 changes: 134 additions & 0 deletions flow/pua/flatbuffers_binaryarray.go
@@ -0,0 +1,134 @@
package pua

import (
"github.com/yuin/gopher-lua"
)

// BinaryArray is a mutable byte buffer backing the Lua
// "flatbuffers.binaryarray" module, standing in for the string.pack-based
// buffers that Lua 5.1 lacks.
type BinaryArray struct {
	// data holds the buffer contents; hidden from Lua (see BinaryArrayIndex).
	data []byte
}

// LuaBinaryArray is the userdata type descriptor used to box BinaryArray
// values handed to Lua.
var LuaBinaryArray = LuaUserDataType[BinaryArray]{Name: "flatbuffers_binaryarray"}

// FlatBuffers_BinaryArray_Loader is the Lua module loader for
// "flatbuffers.binaryarray". It returns a table exposing New/Pack/Unpack and
// registers the metatable used by BinaryArray userdata.
func FlatBuffers_BinaryArray_Loader(ls *lua.LState) int {
	m := ls.NewTable()
	ls.SetField(m, "New", ls.NewFunction(BinaryArrayNew))
	ls.SetField(m, "Pack", ls.NewFunction(BinaryArrayPack))
	ls.SetField(m, "Unpack", ls.NewFunction(BinaryArrayUnpack))

	// Register the metatable under the type name so lookups via
	// GetTypeMetatable (used when constructing userdata and by
	// LuaBinaryArray.Metatable in BinaryArrayIndex) can find it. A plain
	// ls.NewTable() here would be populated and then discarded, leaving
	// BinaryArray instances without any metatable.
	mt := ls.NewTypeMetatable(LuaBinaryArray.Name)
	ls.SetField(mt, "__index", ls.NewFunction(BinaryArrayIndex))
	ls.SetField(mt, "__len", ls.NewFunction(BinaryArrayLen))
	ls.SetField(mt, "Slice", ls.NewFunction(BinaryArraySlice))
	ls.SetField(mt, "Grow", ls.NewFunction(BinaryArrayGrow))
	ls.SetField(mt, "Pad", ls.NewFunction(BinaryArrayPad))
	ls.SetField(mt, "Set", ls.NewFunction(BinaryArraySet))

	ls.Push(m)
	return 1
}

// BinaryArrayNew constructs a BinaryArray from either a Lua string (copied as
// the initial contents) or an integer size (zero-filled buffer of that size),
// and pushes it as userdata. Any other argument type raises a Lua error.
func BinaryArrayNew(ls *lua.LState) int {
	lval := ls.Get(-1)
	var ba BinaryArray
	switch val := lval.(type) {
	case lua.LString:
		ba = BinaryArray{
			data: []byte(val),
		}
	case lua.LNumber:
		ba = BinaryArray{
			data: make([]byte, int(val)),
		}
	default:
		// Fixed grammar of the original message ("Expect a integer size value").
		ls.Error(lua.LString("expected an integer size or a string to construct a binary array"), 0)
		return 0
	}
	ls.Push(LuaBinaryArray.New(ls, ba))
	return 1
}

// BinaryArrayLen implements __len, returning the buffer size in bytes.
func BinaryArrayLen(ls *lua.LState) int {
	arr := LuaBinaryArray.StartMeta(ls)
	size := len(arr.data)
	ls.Push(lua.LNumber(size))
	return 1
}

// BinaryArrayIndex implements __index: "size" and "str" are computed
// properties, "data" is deliberately inaccessible from Lua, and any other key
// falls through to the metatable's methods (Slice/Grow/Pad/Set).
func BinaryArrayIndex(ls *lua.LState) int {
	arr, key := LuaBinaryArray.StartIndex(ls)
	if key == "size" {
		ls.Push(lua.LNumber(len(arr.data)))
		return 1
	}
	if key == "str" {
		ls.Push(lua.LString(arr.data))
		return 1
	}
	if key == "data" {
		ls.Error(lua.LString("binaryArray data property inaccessible"), 0)
		return 0
	}
	ls.Push(ls.GetField(LuaBinaryArray.Metatable(ls), key))
	return 1
}

// BinaryArraySlice pushes ba.data[startPos:endPos] as a Lua string. A missing
// or negative start defaults to 0; a missing or oversized end defaults to the
// buffer length.
func BinaryArraySlice(ls *lua.LState) int {
	arr := LuaBinaryArray.StartMeta(ls)

	from := 0
	if n, isNum := ls.Get(2).(lua.LNumber); isNum && int(n) > 0 {
		from = int(n)
	}

	to := len(arr.data)
	if n, isNum := ls.Get(3).(lua.LNumber); isNum && int(n) < to {
		to = int(n)
	}

	ls.Push(lua.LString(arr.data[from:to]))
	return 1
}

// BinaryArrayGrow grows the buffer to newsize bytes, right-aligning the
// existing contents (flatbuffers builders fill buffers back-to-front).
// No-op when newsize does not exceed the current size.
func BinaryArrayGrow(ls *lua.LState) int {
	ba := LuaBinaryArray.StartMeta(ls)
	newsize := int(ls.CheckNumber(2))
	if newsize > len(ba.data) {
		newdata := make([]byte, newsize)
		// Copy the old bytes to the tail of the new buffer.
		copy(newdata[newsize-len(ba.data):], ba.data)
		// NOTE(review): if StartMeta returns the BinaryArray by value, this
		// reassignment only mutates a local copy and the userdata keeps its
		// old (smaller) buffer — confirm StartMeta's return type, or write
		// the grown BinaryArray back into the userdata's Value.
		ba.data = newdata
	}
	return 0
}

// BinaryArrayPad zeroes n bytes of the buffer starting at startPos.
// Works through the shared backing array, so the mutation is visible to Lua.
func BinaryArrayPad(ls *lua.LState) int {
	arr := LuaBinaryArray.StartMeta(ls)
	count := int(ls.CheckNumber(2))
	offset := int(ls.CheckNumber(3))
	for i := offset; i < offset+count; i++ {
		arr.data[i] = 0
	}
	return 0
}

// BinaryArraySet stores one byte at idx. The value may be a Lua number
// (truncated to a byte) or a Lua string (first byte used); other value types
// are silently ignored, as before.
func BinaryArraySet(ls *lua.LState) int {
	ba := LuaBinaryArray.StartMeta(ls)
	idx := int(ls.CheckNumber(3))
	switch value := ls.Get(2).(type) {
	case lua.LNumber:
		ba.data[idx] = byte(value)
	case lua.LString:
		// Guard: indexing value[0] on an empty string would panic the Go
		// runtime; raise a Lua error instead.
		if len(value) == 0 {
			ls.Error(lua.LString("cannot set binaryarray byte from empty string"), 0)
			return 0
		}
		ba.data[idx] = value[0]
	}
	return 0
}

// BinaryArrayPack is a placeholder for Lua 5.3's string.pack:
//
//	(fmt, ...) -> string.pack(fmt, ...)
//
// TODO: unimplemented — calling this panics.
func BinaryArrayPack(ls *lua.LState) int {
	panic("TODO")
}

// BinaryArrayUnpack is a placeholder for Lua 5.3's string.unpack:
//
//	(fmt, s, pos) -> string.unpack(fmt, ba.data, pos+1)
//
// TODO: unimplemented — calling this panics.
func BinaryArrayUnpack(ls *lua.LState) int {
	panic("TODO")
}

0 comments on commit 0672608

Please sign in to comment.