From 067260897c2485427135041867fa56bbb121eaeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Sat, 9 Mar 2024 05:33:48 +0000 Subject: [PATCH] PUA So this turns out to be less easy than I thought a few hours ago Flatbuffers is pretty bad, the lua side hasn't been touched for 3 years yet I found a silly typo just flipping through the code: https://github.com/google/flatbuffers/pull/8251 In particular, they seem to assume string.pack/string.unpack will exist, but in 5.1 it doesn't, with LuaJIT they use ffi which means that works out for most 5.1 users Solution: we're going to shim our own go implementation of the flatbuffers lua library --- flow/connectors/kafka/kafka.go | 151 ++++++++++++++++------------ flow/pua/flatbuffers.go | 42 ++++++++ flow/pua/flatbuffers_binaryarray.go | 134 ++++++++++++++++++++++++ flow/pua/userdata.go | 32 ++++++ 4 files changed, 294 insertions(+), 65 deletions(-) create mode 100644 flow/pua/flatbuffers.go create mode 100644 flow/pua/flatbuffers_binaryarray.go create mode 100644 flow/pua/userdata.go diff --git a/flow/connectors/kafka/kafka.go b/flow/connectors/kafka/kafka.go index d3b9930cd..0e092a473 100644 --- a/flow/connectors/kafka/kafka.go +++ b/flow/connectors/kafka/kafka.go @@ -2,8 +2,10 @@ package connkafka import ( "context" - "encoding/json" + "errors" "fmt" + "reflect" + "strings" "sync" "github.com/twmb/franz-go/pkg/kgo" @@ -14,6 +16,8 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/logger" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" + "github.com/PeerDB-io/peer-flow/pua" ) type KafkaConnector struct { @@ -83,21 +87,54 @@ func (c *KafkaConnector) SyncFlowCleanup(ctx context.Context, jobName string) er return c.pgMetadata.DropMetadata(ctx, jobName) } -type LuaUserDataType[T any] struct{ Name string } +var ( + LuaRecord = pua.LuaUserDataType[model.Record]{Name: "peerdb_row"} + LuaQValue = pua.LuaUserDataType[qvalue.QValue]{Name: "peerdb_qvalue"} +) + +func RegisterTypes(ls *lua.LState) { + mt := ls.NewTypeMetatable(LuaRecord.Name) + ls.SetField(mt, "__index", ls.NewFunction(LuaRecordIndex)) + + mt = ls.NewTypeMetatable(LuaQValue.Name) + ls.SetField(mt, "__index", ls.NewFunction(LuaQValueIndex)) + ls.SetField(mt, "__len", ls.NewFunction(LuaQValueLen)) +} -var LuaRecord = LuaUserDataType[model.Record]{Name: "peerdb_row"} +func LuaRecordIndex(ls *lua.LState) int { + record, key := LuaRecord.StartIndex(ls) -func RegisterTypes(L *lua.LState) { - mt := L.NewTypeMetatable(LuaRecord.Name) - L.SetField(mt, "__index", LuaRecordGet) + recordItems := record.GetItems() + qv, err := recordItems.GetValueByColName(key) + if err != nil { + ls.Error(lua.LString(err.Error()), 0) + } + + ls.Push(LuaQValue.New(ls, qv)) + return 1 } -func (udt *LuaUserDataType[T]) Construct(ls *lua.LState, val T) *lua.LUserData { - return &lua.LUserData{ - Value: val, - Env: ls.Env, - Metatable: ls.GetTypeMetatable(udt.Name), +func LuaQValueIndex(ls *lua.LState) int { + qv, key := LuaQValue.StartIndex(ls) + if key == "kind" { + ls.Push(lua.LString(qv.Kind)) + return 1 } + return 0 +} + +func LuaQValueLen(ls *lua.LState) int { + qv := LuaQValue.StartMeta(ls) + str, ok := qv.Value.(string) + if ok { + ls.Push(lua.LNumber(len(str))) + return 1 + } + if strings.HasPrefix(string(qv.Kind), "array_") { + ls.Push(lua.LNumber(reflect.ValueOf(qv.Value).Len())) + return 1 + } + return 0 } func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) { @@ -119,20 +156,42 @@ func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecords tableNameRowsMapping := make(map[string]uint32) var fn *lua.LFunction - var state *lua.LState + var ls *lua.LState if req.Script != "" { - state = lua.NewState(lua.Options{SkipOpenLibs: true}) - defer state.Close() - state.SetContext(wgCtx) - state.DoString(req.Script) + ls = lua.NewState(lua.Options{SkipOpenLibs: true}) + defer ls.Close() + ls.SetContext(wgCtx) + for _, pair := range []struct { + n string + f lua.LGFunction + }{ + {lua.LoadLibName, lua.OpenPackage}, // Must be first + {lua.BaseLibName, lua.OpenBase}, + {lua.TabLibName, lua.OpenTable}, + {lua.StringLibName, lua.OpenString}, + {lua.MathLibName, lua.OpenMath}, + } { + ls.Push(ls.NewFunction(pair.f)) + ls.Push(lua.LString(pair.n)) + err := ls.PCall(1, 0, nil) + if err != nil { + return nil, fmt.Errorf("failed to initialize Lua runtime: %w", err) + } + } + ls.PreloadModule("flatbuffers", pua.FlatBuffers_Loader) + ls.PreloadModule("flatbuffers.binaryarray", pua.FlatBuffers_BinaryArray_Loader) + err := ls.DoString(req.Script) + if err != nil { + return nil, fmt.Errorf("error while executing script: %w", err) + } var ok bool - fn, ok = state.GetGlobal("onRow").(*lua.LFunction) + fn, ok = ls.GetGlobal("onRow").(*lua.LFunction) if !ok { - return nil, fmt.Errorf("Script should define `onRow` function") + return nil, errors.New("script should define `onRow` function") } } else { - return nil, fmt.Errorf("Kafka mirror must have script") + return nil, errors.New("kafka mirror must have script") } for record := range req.Records.GetRecords() { @@ -140,53 +199,15 @@ func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecords return nil, err } topic := record.GetDestinationTableName() - switch typedRecord := record.(type) { - case *model.InsertRecord: - insertData := KafkaRecord{ - Old: nil, - New: typedRecord.Items.Values, - } - - state.Push(KafkaRecord) - state.Call() - lfn(insertData) - insertJSON, err := json.Marshal(insertData) - if err != nil { - return nil, fmt.Errorf("failed to serialize insert data to JSON: %w", err) - } - - wg.Add(1) - c.client.Produce(wgCtx, &kgo.Record{Topic: topic, Value: insertJSON}, produceCb) - case *model.UpdateRecord: - updateData := KafkaRecord{ - Old: typedRecord.OldItems.Values, - New: typedRecord.NewItems.Values, - } - updateJSON, err := json.Marshal(updateData) - if err != nil { - return nil, fmt.Errorf("failed to serialize update data to JSON: %w", err) - } - - wg.Add(1) - c.client.Produce(wgCtx, &kgo.Record{Topic: topic, Value: updateJSON}, produceCb) - case *model.DeleteRecord: - deleteData := KafkaRecord{ - Old: typedRecord.Items.Values, - New: nil, - } - deleteJSON, err := json.Marshal(deleteData) - if err != nil { - return nil, fmt.Errorf("failed to serialize delete data to JSON: %w", err) - } - - wg.Add(1) - c.client.Produce(wgCtx, &kgo.Record{Topic: topic, Value: deleteJSON}, produceCb) - default: - // TODO ignore - unknownErr := fmt.Errorf("record type %T not supported in Kafka flow connector", typedRecord) - wgErr(unknownErr) - return nil, unknownErr + ls.Push(fn) + ls.Push(LuaRecord.New(ls, record)) + err := ls.PCall(1, 1, nil) + if err != nil { + return nil, fmt.Errorf("script failed: %w", err) } + value := ls.CheckString(-1) + wg.Add(1) + c.client.Produce(wgCtx, &kgo.Record{Topic: topic, Value: []byte(value)}, produceCb) numRecords += 1 tableNameRowsMapping[topic] += 1 diff --git a/flow/pua/flatbuffers.go b/flow/pua/flatbuffers.go new file mode 100644 index 000000000..ca7d32de2 --- /dev/null +++ b/flow/pua/flatbuffers.go @@ -0,0 +1,42 @@ +package pua + +import ( + "github.com/yuin/gopher-lua" +) + +/* +local m = {} + +m.Builder = require("flatbuffers.builder").New +m.N = require("flatbuffers.numTypes") +m.view = require("flatbuffers.view") +m.binaryArray = require("flatbuffers.binaryarray") + +return m +*/ + +func requireHelper(ls *lua.LState, m *lua.LTable, require lua.LValue, name string, path string) { + ls.Push(require) + ls.Push(lua.LString(path)) + ls.Call(1, 1) + ls.SetField(m, name, ls.Get(-1)) + ls.Pop(1) +} + +func FlatBuffers_Loader(ls *lua.LState) int { + m := ls.NewTable() + require := ls.GetGlobal("require") + ls.Push(require) + ls.Push(lua.LString("flatbuffers.builder")) + ls.Call(1, 1) + builder := ls.GetTable(ls.Get(-1), lua.LString("New")) + ls.SetField(m, "builder", builder) + ls.Pop(1) + + requireHelper(ls, m, require, "N", "flatbuffers.numTypes") + requireHelper(ls, m, require, "view", "flatbuffers.view") + requireHelper(ls, m, require, "binaryArray", "flatbuffers.binaryarray") + + ls.Push(m) + return 1 +} diff --git a/flow/pua/flatbuffers_binaryarray.go b/flow/pua/flatbuffers_binaryarray.go new file mode 100644 index 000000000..8b9018b26 --- /dev/null +++ b/flow/pua/flatbuffers_binaryarray.go @@ -0,0 +1,134 @@ +package pua + +import ( + "github.com/yuin/gopher-lua" +) + +type BinaryArray struct { + data []byte +} + +var LuaBinaryArray = LuaUserDataType[BinaryArray]{Name: "flatbuffers_binaryarray"} + +func FlatBuffers_BinaryArray_Loader(ls *lua.LState) int { + m := ls.NewTable() + ls.SetField(m, "New", ls.NewFunction(BinaryArrayNew)) + ls.SetField(m, "Pack", ls.NewFunction(BinaryArrayPack)) + ls.SetField(m, "Unpack", ls.NewFunction(BinaryArrayUnpack)) + + mt := ls.NewTable() + ls.SetField(mt, "__index", ls.NewFunction(BinaryArrayIndex)) + ls.SetField(mt, "__len", ls.NewFunction(BinaryArrayLen)) + ls.SetField(mt, "Slice", ls.NewFunction(BinaryArraySlice)) + ls.SetField(mt, "Grow", ls.NewFunction(BinaryArrayGrow)) + ls.SetField(mt, "Pad", ls.NewFunction(BinaryArrayPad)) + ls.SetField(mt, "Set", ls.NewFunction(BinaryArraySet)) + + ls.Push(m) + return 1 +} + +func BinaryArrayNew(ls *lua.LState) int { + lval := ls.Get(-1) + var ba BinaryArray + switch val := lval.(type) { + case lua.LString: + ba = BinaryArray{ + data: []byte(val), + } + case lua.LNumber: + ba = BinaryArray{ + data: make([]byte, int(val)), + } + default: + ls.Error(lua.LString("Expect a integer size value or string to construct a binary array"), 0) + return 0 + } + ls.Push(LuaBinaryArray.New(ls, ba)) + return 1 +} + +func BinaryArrayLen(ls *lua.LState) int { + ba := LuaBinaryArray.StartMeta(ls) + ls.Push(lua.LNumber(len(ba.data))) + return 1 +} + +func BinaryArrayIndex(ls *lua.LState) int { + ba, key := LuaBinaryArray.StartIndex(ls) + switch key { + case "size": + ls.Push(lua.LNumber(len(ba.data))) + case "str": + ls.Push(lua.LString(ba.data)) + case "data": + ls.Error(lua.LString("binaryArray data property inaccessible"), 0) + return 0 + default: + ls.Push(ls.GetField(LuaBinaryArray.Metatable(ls), key)) + } + return 1 +} + +func BinaryArraySlice(ls *lua.LState) int { + var startPos, endPos int + ba := LuaBinaryArray.StartMeta(ls) + if luaStartPos, ok := ls.Get(2).(lua.LNumber); ok { + startPos = max(int(luaStartPos), 0) + } else { + startPos = 0 + } + if luaEndPos, ok := ls.Get(3).(lua.LNumber); ok { + endPos = min(int(luaEndPos), len(ba.data)) + } else { + endPos = len(ba.data) + } + ls.Push(lua.LString(ba.data[startPos:endPos])) + return 1 +} + +func BinaryArrayGrow(ls *lua.LState) int { + ba := LuaBinaryArray.StartMeta(ls) + newsize := int(ls.CheckNumber(2)) + if newsize > len(ba.data) { + newdata := make([]byte, newsize) + copy(newdata[newsize-len(ba.data):], ba.data) + ba.data = newdata + } + return 0 +} + +func BinaryArrayPad(ls *lua.LState) int { + ba := LuaBinaryArray.StartMeta(ls) + n := int(ls.CheckNumber(2)) + startPos := int(ls.CheckNumber(3)) + for i := range n { + ba.data[startPos+i] = 0 + } + return 0 +} + +func BinaryArraySet(ls *lua.LState) int { + ba := LuaBinaryArray.StartMeta(ls) + idx := int(ls.CheckNumber(3)) + value := ls.Get(2) + if num, ok := value.(lua.LNumber); ok { + ba.data[idx] = byte(num) + } + if str, ok := value.(lua.LString); ok { + ba.data[idx] = str[0] + } + return 0 +} + +// (fmt, ...) +// return string.pack(fmt, ...) +func BinaryArrayPack(ls *lua.LState) int { + panic("TODO") +} + +// fmt, s, pos +// return string.unpack(fmt, ba.data, pos+1) +func BinaryArrayUnpack(ls *lua.LState) int { + panic("TODO") +} diff --git a/flow/pua/userdata.go b/flow/pua/userdata.go new file mode 100644 index 000000000..3676f1356 --- /dev/null +++ b/flow/pua/userdata.go @@ -0,0 +1,32 @@ +package pua + +import ( + "github.com/yuin/gopher-lua" +) + +type LuaUserDataType[T any] struct{ Name string } + +func (udt *LuaUserDataType[T]) New(ls *lua.LState, val T) *lua.LUserData { + return &lua.LUserData{ + Value: val, + Env: ls.Env, + Metatable: udt.Metatable(ls), + } +} + +func (udt *LuaUserDataType[T]) Metatable(ls *lua.LState) lua.LValue { + return ls.GetTypeMetatable(udt.Name) +} + +func (udt *LuaUserDataType[T]) StartMeta(ls *lua.LState) T { + lrecord := ls.CheckUserData(1) + val, ok := lrecord.Value.(T) + if !ok { + ls.Error(lua.LString("Invalid "+udt.Name), 0) + } + return val +} + +func (udt *LuaUserDataType[T]) StartIndex(ls *lua.LState) (T, string) { + return udt.StartMeta(ls), ls.CheckString(2) +}