Skip to content

Commit

Permalink
schema: Allow handlers to be in separate files (#1074)
Browse files Browse the repository at this point in the history
Allow handler functions for schema fields to be imported from different
files. Previously we were referencing handler functions by the function
name.

Now we hold onto the Starlark reference so we can always find the
function, even if it's imported from a different file.
  • Loading branch information
rohansingh committed May 15, 2024
1 parent 86c4032 commit 1d1639c
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 32 deletions.
8 changes: 4 additions & 4 deletions schema/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

type Generated struct {
SchemaField
starlarkHandler *starlark.Function
}

func newGenerated(
Expand All @@ -35,7 +34,7 @@ func newGenerated(
}

s := &Generated{}
s.starlarkHandler = handler
s.StarlarkHandler = handler
s.Source = source.GoString()
s.Handler = handler.Name()
s.ID = id.GoString()
Expand All @@ -56,14 +55,15 @@ func (s *Generated) AttrNames() []string {

func (s *Generated) Attr(name string) (starlark.Value, error) {
switch name {

case "source":
return starlark.String(s.Source), nil

case "handler":
return s.starlarkHandler, nil
return s.StarlarkHandler, nil

case "id":
return starlark.String(s.ID), nil

default:
return nil, nil
}
Expand Down
5 changes: 2 additions & 3 deletions schema/locationbased.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

type LocationBased struct {
SchemaField
starlarkHandler *starlark.Function
}

func newLocationBased(
Expand Down Expand Up @@ -45,7 +44,7 @@ func newLocationBased(
s.Description = desc.GoString()
s.Icon = icon.GoString()
s.Handler = handler.Name()
s.starlarkHandler = handler
s.StarlarkHandler = handler

return s, nil
}
Expand Down Expand Up @@ -76,7 +75,7 @@ func (s *LocationBased) Attr(name string) (starlark.Value, error) {
return starlark.String(s.Icon), nil

case "handler":
return s.starlarkHandler, nil
return s.StarlarkHandler, nil

default:
return nil, nil
Expand Down
7 changes: 3 additions & 4 deletions schema/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ import (

type OAuth2 struct {
SchemaField
starlarkHandler *starlark.Function
starlarkScopes *starlark.List
starlarkScopes *starlark.List
}

func newOAuth2(
Expand Down Expand Up @@ -52,7 +51,7 @@ func newOAuth2(
s.Description = desc.GoString()
s.Icon = icon.GoString()
s.Handler = handler.Name()
s.starlarkHandler = handler
s.StarlarkHandler = handler
s.ClientID = clientID.GoString()
s.AuthorizationEndpoint = authEndpoint.GoString()
s.starlarkScopes = scopes
Expand Down Expand Up @@ -109,7 +108,7 @@ func (s *OAuth2) Attr(name string) (starlark.Value, error) {
return starlark.String(s.Icon), nil

case "handler":
return s.starlarkHandler, nil
return s.StarlarkHandler, nil

case "client_id":
return starlark.String(s.ClientID), nil
Expand Down
24 changes: 17 additions & 7 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type Schema struct {
// SchemaField represents an item in the config used to confgure an applet.
type SchemaField struct {
Type string `json:"type" validate:"required,oneof=color datetime dropdown generated location locationbased onoff radio text typeahead oauth2 oauth1 png notification"`
ID string `json:"id" validate:"required"`
ID string `json:"id" validate:"required,excludesall=$"`
Name string `json:"name,omitempty" validate:"required_for=datetime dropdown location locationbased onoff radio text typeahead png"`
Description string `json:"description,omitempty"`
Icon string `json:"icon,omitempty" validate:"forbidden_for=generated"`
Expand All @@ -47,8 +47,9 @@ type SchemaField struct {
Palette []string `json:"palette,omitempty"`
Sounds []SchemaSound `json:"sounds,omitempty" validate:"required_for=notification,dive"`

Source string `json:"source,omitempty" validate:"required_for=generated"`
Handler string `json:"handler,omitempty" validate:"required_for=generated locationbased typeahead oauth2"`
Source string `json:"source,omitempty" validate:"required_for=generated"`
Handler string `json:"handler,omitempty" validate:"required_for=generated locationbased typeahead oauth2"`
StarlarkHandler *starlark.Function `json:"-"`

ClientID string `json:"client_id,omitempty" validate:"required_for=oauth2"`
AuthorizationEndpoint string `json:"authorization_endpoint,omitempty" validate:"required_for=oauth2"`
Expand Down Expand Up @@ -131,6 +132,8 @@ func FromStarlark(
}
}
} else {
// this is a legacy path, where the schema was just a dict
// instead of a StarlarkSchema object
schemaTree, err := unmarshalStarlark(val)
if err != nil {
return nil, err
Expand All @@ -156,22 +159,27 @@ func FromStarlark(
}

for i, schemaField := range schema.Fields {
if schemaField.Handler != "" {
handlerValue, found := globals[schemaField.Handler]
if !found {
var handlerFun *starlark.Function
if schemaField.StarlarkHandler != nil {
handlerFun = schemaField.StarlarkHandler
} else if schemaField.Handler != "" {
handlerValue, ok := globals[schemaField.Handler]
if !ok {
return nil, fmt.Errorf(
"field %d references non-existent handler \"%s\"",
i,
schemaField.Handler)
}

handlerFun, ok := handlerValue.(*starlark.Function)
handlerFun, ok = handlerValue.(*starlark.Function)
if !ok {
return nil, fmt.Errorf(
"field %d references \"%s\" which is not a function",
i, schemaField.Handler)
}
}

if handlerFun != nil {
var handlerType HandlerReturnType
switch schemaField.Type {
case "locationbased":
Expand All @@ -190,6 +198,8 @@ func FromStarlark(
i, schemaField.Type)
}

// prepend the field ID to the handler name to avoid conflicts
schemaField.Handler = fmt.Sprintf("%s$%s", schemaField.ID, schemaField.Handler)
schema.Handlers[schemaField.Handler] = SchemaHandler{Function: handlerFun, ReturnType: handlerType}
}
}
Expand Down
77 changes: 66 additions & 11 deletions schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ def main():
assert.Error(t, err)

// They're identified by function name
jsonSchema, err := app.CallSchemaHandler(context.Background(), "generate_schema", "foobar")
jsonSchema, err := app.CallSchemaHandler(context.Background(), "generatedid$generate_schema", "foobar")
assert.NoError(t, err)

var s schema.Schema
Expand Down Expand Up @@ -712,10 +712,10 @@ def main():
app, err := loadApp(code)
assert.NoError(t, err)

_, err = app.CallSchemaHandler(context.Background(), "generate_schema", "win")
_, err = app.CallSchemaHandler(context.Background(), "generatedid$generate_schema", "win")
assert.NoError(t, err)

_, err = app.CallSchemaHandler(context.Background(), "generate_schema", "fail")
_, err = app.CallSchemaHandler(context.Background(), "generatedid$generate_schema", "fail")
assert.Error(t, err)
}

Expand Down Expand Up @@ -801,7 +801,7 @@ def main():
app, err := loadApp(code)
assert.NoError(t, err)

stringValue, err := app.CallSchemaHandler(context.Background(), "handle_location", "fart")
stringValue, err := app.CallSchemaHandler(context.Background(), "locationbasedid$handle_location", "fart")
assert.NoError(t, err)
assert.Equal(t, "[{\"display\":\"\",\"text\":\"Your only option is\",\"value\":\"fart\"}]", stringValue)
}
Expand Down Expand Up @@ -856,7 +856,7 @@ def main():
app, err := loadApp(code)
assert.NoError(t, err)

stringValue, err := app.CallSchemaHandler(context.Background(), "handle_typeahead", "farts")
stringValue, err := app.CallSchemaHandler(context.Background(), "typeaheadid$handle_typeahead", "farts")
assert.NoError(t, err)
assert.Equal(t, "[{\"display\":\"\",\"text\":\"You searched for\",\"value\":\"farts\"}]", stringValue)
}
Expand Down Expand Up @@ -915,7 +915,7 @@ def main():
app, err := loadApp(code)
assert.NoError(t, err)

stringValue, err := app.CallSchemaHandler(context.Background(), "oauth2handler", "farts")
stringValue, err := app.CallSchemaHandler(context.Background(), "oauth2id$oauth2handler", "farts")
assert.NoError(t, err)
assert.Equal(t, "a-refresh-token", stringValue)
}
Expand Down Expand Up @@ -1050,7 +1050,7 @@ def get_schema():
icon = "football",
),
schema.Generated(
id = "generated shouldnt need id, but still does",
id = "generatedid",
source = "with_borough",
handler = build_boroughs,
),
Expand All @@ -1065,14 +1065,14 @@ def main():
assert.NoError(t, err)
assert.NotNil(t, app)

data, err := app.CallSchemaHandler(context.Background(), "build_boroughs", "false")
data, err := app.CallSchemaHandler(context.Background(), "generatedid$build_boroughs", "false")
assert.NoError(t, err)
var schema schema.Schema
assert.NoError(t, json.Unmarshal([]byte(data), &schema))
assert.Equal(t, "1", schema.Version)
assert.Equal(t, 0, len(schema.Fields))

data, err = app.CallSchemaHandler(context.Background(), "build_boroughs", "true")
data, err = app.CallSchemaHandler(context.Background(), "generatedid$build_boroughs", "true")
assert.NoError(t, err)
assert.NoError(t, json.Unmarshal([]byte(data), &schema))
assert.Equal(t, 1, len(schema.Fields))
Expand Down Expand Up @@ -1113,7 +1113,7 @@ def get_schema():
icon = "football",
),
schema.Generated(
id = "generated shouldnt need id, but still does",
id = "generatedid",
source = "select_station",
handler = get_station_selector,
),
Expand All @@ -1136,7 +1136,7 @@ def main():
assert.NoError(t, err)
assert.NotNil(t, app)

data, err := app.CallSchemaHandler(context.Background(), "get_station_selector", "true")
data, err := app.CallSchemaHandler(context.Background(), "generatedid$get_station_selector", "true")
assert.NoError(t, err)
var s schema.Schema
assert.NoError(t, json.Unmarshal([]byte(data), &s))
Expand All @@ -1153,3 +1153,58 @@ def main():
assert.Equal(t, "L08", options[0].Value)
assert.Equal(t, "3rd", options[1].Value)
}

func TestSchemaWithHandlerInDifferentFile(t *testing.T) {
handlerFile := `
load("schema.star", "schema")
def get_stations(loc):
return [
schema.Option(display="Bedford (L)", value = "L08"),
schema.Option(display="3rd Ave (L)", value = "3rd"),
]
`

mainFile := `
load("schema.star", "schema")
load("handler.star", "get_stations")
def get_schema():
return schema.Schema(
version = "1",
fields = [
schema.LocationBased(
id = "station",
name = "Station",
desc = "Pick a station!",
icon = "train",
handler = get_stations,
),
],
handlers = [
schema.Handler(
handler = get_stations,
type = schema.HandlerType.Options,
),
]
)
def main():
return None
`

vfs := fstest.MapFS{
"handler.star": &fstest.MapFile{Data: []byte(handlerFile)},
"test.star": &fstest.MapFile{Data: []byte(mainFile)},
}
app, err := runtime.NewAppletFromFS("test", vfs)
require.NoError(t, err)

data, err := app.CallSchemaHandler(context.Background(), "get_stations", "locationdata")
var options []schema.SchemaOption
assert.NoError(t, err)
assert.NoError(t, json.Unmarshal([]byte(data), &options))
assert.Equal(t, 2, len(options))
assert.Equal(t, "L08", options[0].Value)
assert.Equal(t, "3rd", options[1].Value)
}
5 changes: 2 additions & 3 deletions schema/typeahead.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

type Typeahead struct {
SchemaField
starlarkHandler *starlark.Function
}

func newTypeahead(
Expand Down Expand Up @@ -45,7 +44,7 @@ func newTypeahead(
s.Description = desc.GoString()
s.Icon = icon.GoString()
s.Handler = handler.Name()
s.starlarkHandler = handler
s.StarlarkHandler = handler

return s, nil
}
Expand Down Expand Up @@ -76,7 +75,7 @@ func (s *Typeahead) Attr(name string) (starlark.Value, error) {
return starlark.String(s.Icon), nil

case "handler":
return s.starlarkHandler, nil
return s.StarlarkHandler, nil

default:
return nil, nil
Expand Down

0 comments on commit 1d1639c

Please sign in to comment.