Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for type-safe Start* function #468

Merged
merged 10 commits into from Dec 4, 2022
8 changes: 4 additions & 4 deletions lambda/entry.go
Expand Up @@ -17,20 +17,20 @@ import (
// - handler must be a function
// - handler may take between 0 and two arguments.
// - if there are two arguments, the first argument must satisfy the "context.Context" interface.
// - handler may return between 0 and two arguments.
// - if there are two return values, the second argument must be an error.
// - handler may return between 0 and two values.
// - if there are two return values, the second return value must be an error.
// - if there is one return value it must be an error.
//
// Valid function signatures:
//
// func ()
// func (TIn)
// func () error
// func (TIn) error
// func () (TOut, error)
// func (TIn) (TOut, error)
// func (context.Context) error
// func (context.Context, TIn)
// func (context.Context, TIn) error
// func (context.Context) (TOut, error)
// func (context.Context, TIn) (TOut, error)
//
// Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library.
Expand Down
23 changes: 23 additions & 0 deletions lambda/entry_generic.go
@@ -0,0 +1,23 @@
//go:build go1.18
// +build go1.18
logandavies181 marked this conversation as resolved.
Show resolved Hide resolved

// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved

package lambda

import (
"context"
)

// HandlerFunc represents a valid input with two arguments and two returns as described by Start
type HandlerFunc[TIn, TOut any] interface {
func(context.Context, TIn) (TOut, error)
}

// StartHandlerFunc is the same as StartWithOptions except that it takes a generic input
// so that the function signature can be validated at compile time.
//
// Currently only the `func (context.Context, TIn) (TOut, error)` variant is supported
func StartHandlerFunc[TIn any, TOut any, H HandlerFunc[TIn, TOut]](handler H, options ...Option) {
start(newHandler(handler, options...))
}
36 changes: 36 additions & 0 deletions lambda/entry_generic_test.go
@@ -0,0 +1,36 @@
//go:build go1.18
// +build go1.18

// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved

package lambda

import (
"context"
"fmt"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
)

func TestStartHandlerFunc(t *testing.T) {
actual := "unexpected"
logFatalf = func(format string, v ...interface{}) {
actual = fmt.Sprintf(format, v...)
}

f := func(context.Context, any) (any, error) { return 1, nil }
StartHandlerFunc(f)

assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual)

handlerType := reflect.TypeOf(f)

handlerTakesContext, err := validateArguments(handlerType)
assert.NoError(t, err)
assert.True(t, handlerTakesContext)

err = validateReturns(handlerType)
assert.NoError(t, err)
}
25 changes: 25 additions & 0 deletions lambda/handler.go
Expand Up @@ -99,6 +99,31 @@ func WithEnableSIGTERM(callbacks ...func()) Option {
})
}

// ValidateHandlerFunc validates the handler against the criteria for Start and returns an error
// if the criteria are not met
func ValidateHandlerFunc(handlerFunc interface{}) error {
logandavies181 marked this conversation as resolved.
Show resolved Hide resolved
if handlerFunc == nil {
return errors.New("handler is nil")
}

handlerType := reflect.TypeOf(handlerFunc)
if handlerType.Kind() != reflect.Func {
return fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func)
}

_, err := validateArguments(handlerType)
if err != nil {
return err
}

err = validateReturns(handlerType)
if err != nil {
return err
}

return nil
}

func validateArguments(handler reflect.Type) (bool, error) {
handlerTakesContext := false
if handler.NumIn() > 2 {
Expand Down
59 changes: 58 additions & 1 deletion lambda/handler_test.go
Expand Up @@ -75,11 +75,68 @@ func TestInvalidHandlers(t *testing.T) {
}
for i, testCase := range testCases {
testCase := testCase
t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
t.Run(fmt.Sprintf("testCase[%d] %s part 1", i, testCase.name), func(t *testing.T) {
lambdaHandler := NewHandler(testCase.handler)
_, err := lambdaHandler.Invoke(context.TODO(), make([]byte, 0))
assert.Equal(t, testCase.expected, err)
})

t.Run(fmt.Sprintf("testCase[%d] %s part 2", i, testCase.name), func(t *testing.T) {
err := ValidateHandlerFunc(testCase.handler)
assert.Equal(t, testCase.expected, err)
})
}
}

func TestValidateHandlerFuncValidHandlers(t *testing.T) {
testCases := []struct {
name string
handler interface{}
}{
{
name: "0 arg 0 return",
handler: func() {},
},
{
name: "0 arg, 1 returns",
handler: func() error { return nil },
},
{
name: "1 arg, 0 returns",
handler: func(any) {},
},
{
name: "1 arg, 1 returns",
handler: func(any) error { return nil },
},
{
name: "0 arg, 2 returns",
handler: func() (any, error) { return 1, nil },
},
{
name: "1 arg, 2 returns",
handler: func(any) (any, error) { return 1, nil },
},
{
name: "2 arg, 0 returns",
handler: func(context.Context, any) {},
},
{
name: "2 arg, 1 returns",
handler: func(context.Context, any) error { return nil },
},
{
name: "2 arg, 2 returns",
handler: func(context.Context, any) (any, error) { return 1, nil },
},
}

for i, testCase := range testCases {
testCase := testCase
t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
err := ValidateHandlerFunc(testCase.handler)
assert.Nil(t, err)
})
}
}

Expand Down