Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support the lambda.norpc tag on the go1.x runtime #456

Merged
merged 6 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 1 addition & 11 deletions lambda/entry.go
Expand Up @@ -4,7 +4,6 @@ package lambda

import (
"context"
"errors"
"log"
"os"
)
Expand Down Expand Up @@ -70,20 +69,11 @@ type startFunction struct {
}

var (
// This allows users to save a little bit of coldstart time in the download, by the dependencies brought in for RPC support.
// The tradeoff is dropping compatibility with the go1.x runtime, functions must be "Custom Runtime" instead.
// To drop the rpc dependencies, compile with `-tags lambda.norpc`
rpcStartFunction = &startFunction{
env: "_LAMBDA_SERVER_PORT",
f: func(_ string, _ Handler) error {
return errors.New("_LAMBDA_SERVER_PORT was present but the function was compiled without RPC support")
},
}
runtimeAPIStartFunction = &startFunction{
env: "AWS_LAMBDA_RUNTIME_API",
f: startRuntimeAPILoop,
}
startFunctions = []*startFunction{rpcStartFunction, runtimeAPIStartFunction}
startFunctions = []*startFunction{runtimeAPIStartFunction}

// This allows end to end testing of the Start functions, by tests overwriting this function to keep the program alive
logFatalf = log.Fatalf
Expand Down
59 changes: 0 additions & 59 deletions lambda/entry_test.go
Expand Up @@ -4,15 +4,11 @@ package lambda

import (
"context"
"fmt"
"log"
"net"
"net/rpc"
"os"
"strings"
"testing"

"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/stretchr/testify/assert"
)

Expand All @@ -35,58 +31,3 @@ func TestStartRuntimeAPIWithContext(t *testing.T) {

assert.Equal(t, expected, actual)
}

func TestStartRPCWithContext(t *testing.T) {
expected := "expected"
actual := "unexpected"
port := getFreeTCPPort()
os.Setenv("_LAMBDA_SERVER_PORT", fmt.Sprintf("%d", port))
defer os.Unsetenv("_LAMBDA_SERVER_PORT")
go StartWithContext(context.WithValue(context.Background(), ctxTestKey{}, expected), func(ctx context.Context) error {
actual, _ = ctx.Value(ctxTestKey{}).(string)
return nil
})

var client *rpc.Client
var pingResponse messages.PingResponse
var invokeResponse messages.InvokeResponse
var err error
for {
client, err = rpc.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
continue
}
break
}
for {
if err := client.Call("Function.Ping", &messages.PingRequest{}, &pingResponse); err != nil {
continue
}
break
}
if err := client.Call("Function.Invoke", &messages.InvokeRequest{}, &invokeResponse); err != nil {
t.Logf("error invoking function: %v", err)
}

assert.Equal(t, expected, actual)
}

func getFreeTCPPort() int {
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatal("getFreeTCPPort failed: ", err)
}
defer l.Close()

return l.Addr().(*net.TCPAddr).Port
}

func TestStartNotInLambda(t *testing.T) {
actual := "unexpected"
logFatalf = func(format string, v ...interface{}) {
actual = fmt.Sprintf(format, v...)
}

Start(func() error { return nil })
assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual)
}
23 changes: 23 additions & 0 deletions lambda/entry_with_no_rpc_test.go
@@ -0,0 +1,23 @@
// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.

//go:build lambda.norpc
// +build lambda.norpc

package lambda

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestStartNotInLambda(t *testing.T) {
actual := "unexpected"
logFatalf = func(format string, v ...interface{}) {
actual = fmt.Sprintf(format, v...)
}

Start(func() error { return nil })
assert.Equal(t, "expected AWS Lambda environment variables [AWS_LAMBDA_RUNTIME_API] are not defined", actual)
}
74 changes: 74 additions & 0 deletions lambda/entry_with_rpc_test.go
@@ -0,0 +1,74 @@
// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved

//go:build !lambda.norpc
// +build !lambda.norpc

package lambda

import (
"context"
"fmt"
"log"
"net"
"net/rpc"
"os"
"testing"

"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/stretchr/testify/assert"
)

func TestStartRPCWithContext(t *testing.T) {
expected := "expected"
actual := "unexpected"
port := getFreeTCPPort()
os.Setenv("_LAMBDA_SERVER_PORT", fmt.Sprintf("%d", port))
defer os.Unsetenv("_LAMBDA_SERVER_PORT")
go StartWithContext(context.WithValue(context.Background(), ctxTestKey{}, expected), func(ctx context.Context) error {
actual, _ = ctx.Value(ctxTestKey{}).(string)
return nil
})

var client *rpc.Client
var pingResponse messages.PingResponse
var invokeResponse messages.InvokeResponse
var err error
for {
client, err = rpc.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
continue
}
break
}
for {
if err := client.Call("Function.Ping", &messages.PingRequest{}, &pingResponse); err != nil {
continue
}
break
}
if err := client.Call("Function.Invoke", &messages.InvokeRequest{}, &invokeResponse); err != nil {
t.Logf("error invoking function: %v", err)
}

assert.Equal(t, expected, actual)
}

func getFreeTCPPort() int {
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatal("getFreeTCPPort failed: ", err)
}
defer l.Close()

return l.Addr().(*net.TCPAddr).Port
}

func TestStartNotInLambda(t *testing.T) {
actual := "unexpected"
logFatalf = func(format string, v ...interface{}) {
actual = fmt.Sprintf(format, v...)
}

Start(func() error { return nil })
assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual)
}
119 changes: 78 additions & 41 deletions lambda/invoke_loop.go
Expand Up @@ -3,101 +3,138 @@
package lambda

import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"strconv"
"time"

"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/aws/aws-lambda-go/lambdacontext"
)

const (
msPerS = int64(time.Second / time.Millisecond)
nsPerMS = int64(time.Millisecond / time.Nanosecond)
)

// TODO: replace with time.UnixMillis after dropping version <1.17 from CI workflows
func unixMS(ms int64) time.Time {
return time.Unix(ms/msPerS, (ms%msPerS)*nsPerMS)
}

// startRuntimeAPILoop will return an error if handling a particular invoke resulted in a non-recoverable error
func startRuntimeAPILoop(api string, handler Handler) error {
client := newRuntimeAPIClient(api)
function := NewFunction(handler)
h := newHandler(handler)
for {
invoke, err := client.next()
if err != nil {
return err
}

err = handleInvoke(invoke, function)
if err != nil {
if err = handleInvoke(invoke, h); err != nil {
return err
}
}
}

// handleInvoke returns an error if the function panics, or some other non-recoverable error occurred
func handleInvoke(invoke *invoke, function *Function) error {
functionRequest, err := convertInvokeRequest(invoke)
func handleInvoke(invoke *invoke, handler *handlerOptions) error {
// set the deadline
deadline, err := parseDeadline(invoke)
if err != nil {
return fmt.Errorf("unexpected error occurred when parsing the invoke: %v", err)
return reportFailure(invoke, lambdaErrorResponse(err))
}
ctx, cancel := context.WithDeadline(handler.baseContext, deadline)
defer cancel()

functionResponse := &messages.InvokeResponse{}
if err := function.Invoke(functionRequest, functionResponse); err != nil {
return fmt.Errorf("unexpected error occurred when invoking the handler: %v", err)
// set the invoke metadata values
lc := lambdacontext.LambdaContext{
AwsRequestID: invoke.id,
InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN),
}

if functionResponse.Error != nil {
errorPayload := safeMarshal(functionResponse.Error)
log.Printf("%s", errorPayload)
if err := invoke.failure(errorPayload, contentTypeJSON); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err)
if err := parseClientContext(invoke, &lc.ClientContext); err != nil {
return reportFailure(invoke, lambdaErrorResponse(err))
}
if err := parseCognitoIdentity(invoke, &lc.Identity); err != nil {
return reportFailure(invoke, lambdaErrorResponse(err))
}
ctx = lambdacontext.NewContext(ctx, &lc)

// set the trace id
traceID := invoke.headers.Get(headerTraceID)
os.Setenv("_X_AMZN_TRACE_ID", traceID)
// nolint:staticcheck
ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID)

// call the handler, marshal any returned error
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke)
if invokeErr != nil {
if err := reportFailure(invoke, invokeErr); err != nil {
return err
}
if functionResponse.Error.ShouldExit {
if invokeErr.ShouldExit {
return fmt.Errorf("calling the handler function resulted in a panic, the process should exit")
}
return nil
}

if err := invoke.success(functionResponse.Payload, contentTypeJSON); err != nil {
if err := invoke.success(response, contentTypeJSON); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err)
}

return nil
}

// convertInvokeRequest converts an invoke from the Runtime API, and unpacks it to be compatible with the shape of a `lambda.Function` InvokeRequest.
func convertInvokeRequest(invoke *invoke) (*messages.InvokeRequest, error) {
deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse contents of header: %s", headerDeadlineMS)
func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error {
errorPayload := safeMarshal(invokeErr)
log.Printf("%s", errorPayload)
if err := invoke.failure(errorPayload, contentTypeJSON); err != nil {
return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err)
}
deadlineS := deadlineEpochMS / msPerS
deadlineNS := (deadlineEpochMS % msPerS) * nsPerMS
return nil
}

res := &messages.InvokeRequest{
InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN),
XAmznTraceId: invoke.headers.Get(headerTraceID),
RequestId: invoke.id,
Deadline: messages.InvokeRequest_Timestamp{
Seconds: deadlineS,
Nanos: deadlineNS,
},
Payload: invoke.payload,
func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) {
defer func() {
if err := recover(); err != nil {
invokeErr = lambdaPanicResponse(err)
}
}()
response, err := handler(ctx, payload)
if err != nil {
return nil, lambdaErrorResponse(err)
}
return response, nil
}

clientContextJSON := invoke.headers.Get(headerClientContext)
if clientContextJSON != "" {
res.ClientContext = []byte(clientContextJSON)
func parseDeadline(invoke *invoke) (time.Time, error) {
deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64)
if err != nil {
return time.Time{}, fmt.Errorf("failed to parse deadline: %v", err)
}
return unixMS(deadlineEpochMS), nil
}

func parseCognitoIdentity(invoke *invoke, out *lambdacontext.CognitoIdentity) error {
cognitoIdentityJSON := invoke.headers.Get(headerCognitoIdentity)
if cognitoIdentityJSON != "" {
if err := json.Unmarshal([]byte(invoke.headers.Get(headerCognitoIdentity)), res); err != nil {
return nil, fmt.Errorf("failed to unmarshal cognito identity json: %v", err)
if err := json.Unmarshal([]byte(cognitoIdentityJSON), out); err != nil {
return fmt.Errorf("failed to unmarshal cognito identity json: %v", err)
}
}
return nil
}

return res, nil
func parseClientContext(invoke *invoke, out *lambdacontext.ClientContext) error {
clientContextJSON := invoke.headers.Get(headerClientContext)
if clientContextJSON != "" {
if err := json.Unmarshal([]byte(clientContextJSON), out); err != nil {
return fmt.Errorf("failed to unmarshal client context json: %v", err)
}
}
return nil
}

func safeMarshal(v interface{}) []byte {
Expand Down