Skip to content

Commit

Permalink
support the lambda.norpc tag on the go1.x runtime (#456)
Browse files Browse the repository at this point in the history
* support lambda.norpc tag on go1.x and rewrite invoke_loop.go to remove use of Function RPC type

* restore pre-1.17 unix ms conversions

* more test coverage

* add quick benchmark comparision to show throughput improvement of norpc build flag

* more

* update bench.sh
  • Loading branch information
bmoffatt committed Jul 27, 2022
1 parent d8bb932 commit 8bc331d
Show file tree
Hide file tree
Showing 13 changed files with 414 additions and 178 deletions.
12 changes: 1 addition & 11 deletions lambda/entry.go
Expand Up @@ -4,7 +4,6 @@ package lambda

import (
"context"
"errors"
"log"
"os"
)
Expand Down Expand Up @@ -70,20 +69,11 @@ type startFunction struct {
}

var (
// This allows users to save a little bit of coldstart time in the download, by the dependencies brought in for RPC support.
// The tradeoff is dropping compatibility with the go1.x runtime, functions must be "Custom Runtime" instead.
// To drop the rpc dependencies, compile with `-tags lambda.norpc`
rpcStartFunction = &startFunction{
env: "_LAMBDA_SERVER_PORT",
f: func(_ string, _ Handler) error {
return errors.New("_LAMBDA_SERVER_PORT was present but the function was compiled without RPC support")
},
}
runtimeAPIStartFunction = &startFunction{
env: "AWS_LAMBDA_RUNTIME_API",
f: startRuntimeAPILoop,
}
startFunctions = []*startFunction{rpcStartFunction, runtimeAPIStartFunction}
startFunctions = []*startFunction{runtimeAPIStartFunction}

// This allows end to end testing of the Start functions, by tests overwriting this function to keep the program alive
logFatalf = log.Fatalf
Expand Down
59 changes: 0 additions & 59 deletions lambda/entry_test.go
Expand Up @@ -4,15 +4,11 @@ package lambda

import (
"context"
"fmt"
"log"
"net"
"net/rpc"
"os"
"strings"
"testing"

"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/stretchr/testify/assert"
)

Expand All @@ -35,58 +31,3 @@ func TestStartRuntimeAPIWithContext(t *testing.T) {

assert.Equal(t, expected, actual)
}

func TestStartRPCWithContext(t *testing.T) {
expected := "expected"
actual := "unexpected"
port := getFreeTCPPort()
os.Setenv("_LAMBDA_SERVER_PORT", fmt.Sprintf("%d", port))
defer os.Unsetenv("_LAMBDA_SERVER_PORT")
go StartWithContext(context.WithValue(context.Background(), ctxTestKey{}, expected), func(ctx context.Context) error {
actual, _ = ctx.Value(ctxTestKey{}).(string)
return nil
})

var client *rpc.Client
var pingResponse messages.PingResponse
var invokeResponse messages.InvokeResponse
var err error
for {
client, err = rpc.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
continue
}
break
}
for {
if err := client.Call("Function.Ping", &messages.PingRequest{}, &pingResponse); err != nil {
continue
}
break
}
if err := client.Call("Function.Invoke", &messages.InvokeRequest{}, &invokeResponse); err != nil {
t.Logf("error invoking function: %v", err)
}

assert.Equal(t, expected, actual)
}

func getFreeTCPPort() int {
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatal("getFreeTCPPort failed: ", err)
}
defer l.Close()

return l.Addr().(*net.TCPAddr).Port
}

func TestStartNotInLambda(t *testing.T) {
actual := "unexpected"
logFatalf = func(format string, v ...interface{}) {
actual = fmt.Sprintf(format, v...)
}

Start(func() error { return nil })
assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual)
}
23 changes: 23 additions & 0 deletions lambda/entry_with_no_rpc_test.go
@@ -0,0 +1,23 @@
// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.

//go:build lambda.norpc
// +build lambda.norpc

package lambda

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestStartNotInLambda(t *testing.T) {
actual := "unexpected"
logFatalf = func(format string, v ...interface{}) {
actual = fmt.Sprintf(format, v...)
}

Start(func() error { return nil })
assert.Equal(t, "expected AWS Lambda environment variables [AWS_LAMBDA_RUNTIME_API] are not defined", actual)
}
74 changes: 74 additions & 0 deletions lambda/entry_with_rpc_test.go
@@ -0,0 +1,74 @@
// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved

//go:build !lambda.norpc
// +build !lambda.norpc

package lambda

import (
"context"
"fmt"
"log"
"net"
"net/rpc"
"os"
"testing"

"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/stretchr/testify/assert"
)

func TestStartRPCWithContext(t *testing.T) {
expected := "expected"
actual := "unexpected"
port := getFreeTCPPort()
os.Setenv("_LAMBDA_SERVER_PORT", fmt.Sprintf("%d", port))
defer os.Unsetenv("_LAMBDA_SERVER_PORT")
go StartWithContext(context.WithValue(context.Background(), ctxTestKey{}, expected), func(ctx context.Context) error {
actual, _ = ctx.Value(ctxTestKey{}).(string)
return nil
})

var client *rpc.Client
var pingResponse messages.PingResponse
var invokeResponse messages.InvokeResponse
var err error
for {
client, err = rpc.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
continue
}
break
}
for {
if err := client.Call("Function.Ping", &messages.PingRequest{}, &pingResponse); err != nil {
continue
}
break
}
if err := client.Call("Function.Invoke", &messages.InvokeRequest{}, &invokeResponse); err != nil {
t.Logf("error invoking function: %v", err)
}

assert.Equal(t, expected, actual)
}

func getFreeTCPPort() int {
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatal("getFreeTCPPort failed: ", err)
}
defer l.Close()

return l.Addr().(*net.TCPAddr).Port
}

func TestStartNotInLambda(t *testing.T) {
actual := "unexpected"
logFatalf = func(format string, v ...interface{}) {
actual = fmt.Sprintf(format, v...)
}

Start(func() error { return nil })
assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual)
}
119 changes: 78 additions & 41 deletions lambda/invoke_loop.go
Expand Up @@ -3,101 +3,138 @@
package lambda

import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"strconv"
"time"

"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/aws/aws-lambda-go/lambdacontext"
)

const (
msPerS = int64(time.Second / time.Millisecond)
nsPerMS = int64(time.Millisecond / time.Nanosecond)
)

// TODO: replace with time.UnixMillis after dropping version <1.17 from CI workflows
func unixMS(ms int64) time.Time {
return time.Unix(ms/msPerS, (ms%msPerS)*nsPerMS)
}

// startRuntimeAPILoop will return an error if handling a particular invoke resulted in a non-recoverable error
func startRuntimeAPILoop(api string, handler Handler) error {
client := newRuntimeAPIClient(api)
function := NewFunction(handler)
h := newHandler(handler)
for {
invoke, err := client.next()
if err != nil {
return err
}

err = handleInvoke(invoke, function)
if err != nil {
if err = handleInvoke(invoke, h); err != nil {
return err
}
}
}

// handleInvoke returns an error if the function panics, or some other non-recoverable error occurred
func handleInvoke(invoke *invoke, function *Function) error {
functionRequest, err := convertInvokeRequest(invoke)
func handleInvoke(invoke *invoke, handler *handlerOptions) error {
// set the deadline
deadline, err := parseDeadline(invoke)
if err != nil {
return fmt.Errorf("unexpected error occurred when parsing the invoke: %v", err)
return reportFailure(invoke, lambdaErrorResponse(err))
}
ctx, cancel := context.WithDeadline(handler.baseContext, deadline)
defer cancel()

functionResponse := &messages.InvokeResponse{}
if err := function.Invoke(functionRequest, functionResponse); err != nil {
return fmt.Errorf("unexpected error occurred when invoking the handler: %v", err)
// set the invoke metadata values
lc := lambdacontext.LambdaContext{
AwsRequestID: invoke.id,
InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN),
}

if functionResponse.Error != nil {
errorPayload := safeMarshal(functionResponse.Error)
log.Printf("%s", errorPayload)
if err := invoke.failure(errorPayload, contentTypeJSON); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err)
if err := parseClientContext(invoke, &lc.ClientContext); err != nil {
return reportFailure(invoke, lambdaErrorResponse(err))
}
if err := parseCognitoIdentity(invoke, &lc.Identity); err != nil {
return reportFailure(invoke, lambdaErrorResponse(err))
}
ctx = lambdacontext.NewContext(ctx, &lc)

// set the trace id
traceID := invoke.headers.Get(headerTraceID)
os.Setenv("_X_AMZN_TRACE_ID", traceID)
// nolint:staticcheck
ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID)

// call the handler, marshal any returned error
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke)
if invokeErr != nil {
if err := reportFailure(invoke, invokeErr); err != nil {
return err
}
if functionResponse.Error.ShouldExit {
if invokeErr.ShouldExit {
return fmt.Errorf("calling the handler function resulted in a panic, the process should exit")
}
return nil
}

if err := invoke.success(functionResponse.Payload, contentTypeJSON); err != nil {
if err := invoke.success(response, contentTypeJSON); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err)
}

return nil
}

// convertInvokeRequest converts an invoke from the Runtime API, and unpacks it to be compatible with the shape of a `lambda.Function` InvokeRequest.
func convertInvokeRequest(invoke *invoke) (*messages.InvokeRequest, error) {
deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse contents of header: %s", headerDeadlineMS)
func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error {
errorPayload := safeMarshal(invokeErr)
log.Printf("%s", errorPayload)
if err := invoke.failure(errorPayload, contentTypeJSON); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err)
}
deadlineS := deadlineEpochMS / msPerS
deadlineNS := (deadlineEpochMS % msPerS) * nsPerMS
return nil
}

res := &messages.InvokeRequest{
InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN),
XAmznTraceId: invoke.headers.Get(headerTraceID),
RequestId: invoke.id,
Deadline: messages.InvokeRequest_Timestamp{
Seconds: deadlineS,
Nanos: deadlineNS,
},
Payload: invoke.payload,
func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) {
defer func() {
if err := recover(); err != nil {
invokeErr = lambdaPanicResponse(err)
}
}()
response, err := handler(ctx, payload)
if err != nil {
return nil, lambdaErrorResponse(err)
}
return response, nil
}

clientContextJSON := invoke.headers.Get(headerClientContext)
if clientContextJSON != "" {
res.ClientContext = []byte(clientContextJSON)
func parseDeadline(invoke *invoke) (time.Time, error) {
deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64)
if err != nil {
return time.Time{}, fmt.Errorf("failed to parse deadline: %v", err)
}
return unixMS(deadlineEpochMS), nil
}

func parseCognitoIdentity(invoke *invoke, out *lambdacontext.CognitoIdentity) error {
cognitoIdentityJSON := invoke.headers.Get(headerCognitoIdentity)
if cognitoIdentityJSON != "" {
if err := json.Unmarshal([]byte(invoke.headers.Get(headerCognitoIdentity)), res); err != nil {
return nil, fmt.Errorf("failed to unmarshal cognito identity json: %v", err)
if err := json.Unmarshal([]byte(cognitoIdentityJSON), out); err != nil {
return fmt.Errorf("failed to unmarshal cognito identity json: %v", err)
}
}
return nil
}

return res, nil
func parseClientContext(invoke *invoke, out *lambdacontext.ClientContext) error {
clientContextJSON := invoke.headers.Get(headerClientContext)
if clientContextJSON != "" {
if err := json.Unmarshal([]byte(clientContextJSON), out); err != nil {
return fmt.Errorf("failed to unmarshal client context json: %v", err)
}
}
return nil
}

func safeMarshal(v interface{}) []byte {
Expand Down

0 comments on commit 8bc331d

Please sign in to comment.