Skip to content

Commit

Permalink
Add WithEnableSIGTERM option (#457)
Browse files Browse the repository at this point in the history
* WIP adaption of my scratch code

* implement as a handler Option

* fix linter errors

* io -> ioutil to keep older CI running

* add test case, attempt to configure github action to install runtime interface emulator

* fix truncated url

* -L

* please the race detector

* contsrain testcase to go 1.15+

* please the linter

* -v the tests

* add test variant that checks that sigterm isn't enabled by default

* Update tests.yml
  • Loading branch information
bmoffatt committed Jul 28, 2022
1 parent 8bc331d commit 93199f7
Show file tree
Hide file tree
Showing 6 changed files with 308 additions and 1 deletion.
6 changes: 5 additions & 1 deletion .github/workflows/tests.yml
Expand Up @@ -25,11 +25,15 @@ jobs:

- run: go version

- name: install lambda runtime interface emulator
run: curl -L -o /usr/local/bin/aws-lambda-rie https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie-x86_64
- run: chmod +x /usr/local/bin/aws-lambda-rie

- name: Check out code into the Go module directory
uses: actions/checkout@v2

- name: go test
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./...

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
Expand Down
90 changes: 90 additions & 0 deletions lambda/extensions_api_client.go
@@ -0,0 +1,90 @@
package lambda

import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
)

const (
headerExtensionName = "Lambda-Extension-Name"
headerExtensionIdentifier = "Lambda-Extension-Identifier"
extensionAPIVersion = "2020-01-01"
)

type extensionAPIEventType string

const (
extensionInvokeEvent extensionAPIEventType = "INVOKE" //nolint:deadcode,unused,varcheck
extensionShutdownEvent extensionAPIEventType = "SHUTDOWN" //nolint:deadcode,unused,varcheck
)

type extensionAPIClient struct {
baseURL string
httpClient *http.Client
}

func newExtensionAPIClient(address string) *extensionAPIClient {
client := &http.Client{
Timeout: 0, // connections to the extensions API are never expected to time out
}
endpoint := "http://" + address + "/" + extensionAPIVersion + "/extension/"
return &extensionAPIClient{
baseURL: endpoint,
httpClient: client,
}
}

func (c *extensionAPIClient) register(name string, events ...extensionAPIEventType) (string, error) {
url := c.baseURL + "register"
body, _ := json.Marshal(struct {
Events []extensionAPIEventType `json:"events"`
}{
Events: events,
})

req, _ := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
req.Header.Add(headerExtensionName, name)
res, err := c.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to register extension: %v", err)
}
defer res.Body.Close()
_, _ = io.Copy(ioutil.Discard, res.Body)

if res.StatusCode != http.StatusOK {
return "", fmt.Errorf("failed to register extension, got response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode))
}

return res.Header.Get(headerExtensionIdentifier), nil
}

type extensionEventResponse struct {
EventType extensionAPIEventType
// ... the rest not implemented
}

func (c *extensionAPIClient) next(id string) (response extensionEventResponse, err error) {
url := c.baseURL + "event/next"

req, _ := http.NewRequest(http.MethodGet, url, nil)
req.Header.Add(headerExtensionIdentifier, id)
res, err := c.httpClient.Do(req)
if err != nil {
err = fmt.Errorf("failed to get extension event: %v", err)
return
}
defer res.Body.Close()
_, _ = io.Copy(ioutil.Discard, res.Body)

if res.StatusCode != http.StatusOK {
err = fmt.Errorf("failed to register extension, got response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode))
return
}

err = json.NewDecoder(res.Body).Decode(&response)
return
}
25 changes: 25 additions & 0 deletions lambda/handler.go
Expand Up @@ -23,6 +23,8 @@ type handlerOptions struct {
jsonResponseEscapeHTML bool
jsonResponseIndentPrefix string
jsonResponseIndentValue string
enableSIGTERM bool
sigtermCallbacks []func()
}

type Option func(*handlerOptions)
Expand Down Expand Up @@ -73,6 +75,26 @@ func WithSetIndent(prefix, indent string) Option {
})
}

// WithEnableSIGTERM enables SIGTERM behavior within the Lambda platform on container spindown.
// SIGKILL will occur ~500ms after SIGTERM.
// Optionally, an array of callback functions to run on SIGTERM may be provided.
//
// Usage:
// lambda.StartWithOptions(
// func (event any) (any error) {
// return event, nil
// },
// lambda.WithEnableSIGTERM(func() {
// log.Print("function container shutting down...")
// })
// )
func WithEnableSIGTERM(callbacks ...func()) Option {
return Option(func(h *handlerOptions) {
h.sigtermCallbacks = append(h.sigtermCallbacks, callbacks...)
h.enableSIGTERM = true
})
}

func validateArguments(handler reflect.Type) (bool, error) {
handlerTakesContext := false
if handler.NumIn() > 2 {
Expand Down Expand Up @@ -139,6 +161,9 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
for _, option := range options {
option(h)
}
if h.enableSIGTERM {
enableSIGTERM(h.sigtermCallbacks)
}
h.Handler = reflectHandler(handlerFunc, h)
return h
}
Expand Down
53 changes: 53 additions & 0 deletions lambda/sigterm.go
@@ -0,0 +1,53 @@
// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.

package lambda

import (
"log"
"os"
"os/signal"
"syscall"
)

// enableSIGTERM configures an optional list of sigtermHandlers to run on process shutdown.
// This non-default behavior is enabled within Lambda using the extensions API.
func enableSIGTERM(sigtermHandlers []func()) {
// for fun, we'll also optionally register SIGTERM handlers
if len(sigtermHandlers) > 0 {
signaled := make(chan os.Signal, 1)
signal.Notify(signaled, syscall.SIGTERM)
go func() {
<-signaled
for _, f := range sigtermHandlers {
f()
}
}()
}

// detect if we're actually running within Lambda
endpoint := os.Getenv("AWS_LAMBDA_RUNTIME_API")
if endpoint == "" {
log.Print("WARNING! AWS_LAMBDA_RUNTIME_API environment variable not found. Skipping attempt to register internal extension...")
return
}

// Now to do the AWS Lambda specific stuff.
// The default Lambda behavior is for functions to get SIGKILL at the end of lifetime, or after a timeout.
// Any use of the Lambda extension register API enables SIGTERM to be sent to the function process before the SIGKILL.
// We'll register an extension that does not listen for any lifecycle events named "GoLangEnableSIGTERM".
// The API will respond with an ID we need to pass in future requests.
client := newExtensionAPIClient(endpoint)
id, err := client.register("GoLangEnableSIGTERM")
if err != nil {
log.Printf("WARNING! Failed to register internal extension! SIGTERM events may not be enabled! err: %v", err)
return
}

// We didn't actually register for any events, but we need to call /next anyways to let the API know we're done initalizing.
// Because we didn't register for any events, /next will never return, so we'll do this in a go routine that is doomed to stay blocked.
go func() {
_, err := client.next(id)
log.Printf("WARNING! Reached expected unreachable code! Extension /next call expected to block forever! err: %v", err)
}()

}
93 changes: 93 additions & 0 deletions lambda/sigterm_test.go
@@ -0,0 +1,93 @@
//go:build go1.15
// +build go1.15

package lambda

import (
"io/ioutil"
"net/http"
"os"
"os/exec"
"path"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations"
)

func TestEnableSigterm(t *testing.T) {
if _, err := exec.LookPath("aws-lambda-rie"); err != nil {
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err)
}

testDir := t.TempDir()

// compile our handler, it'll always run to timeout ensuring the SIGTERM is triggered by aws-lambda-rie
handlerBuild := exec.Command("go", "build", "-o", path.Join(testDir, "sigterm.handler"), "./testdata/sigterm.go")
handlerBuild.Stderr = os.Stderr
handlerBuild.Stdout = os.Stderr
require.NoError(t, handlerBuild.Run())

for name, opts := range map[string]struct {
envVars []string
assertLogs func(t *testing.T, logs string)
}{
"baseline": {
assertLogs: func(t *testing.T, logs string) {
assert.NotContains(t, logs, "Hello SIGTERM!")
assert.NotContains(t, logs, "I've been TERMINATED!")
},
},
"sigterm enabled": {
envVars: []string{"ENABLE_SIGTERM=please"},
assertLogs: func(t *testing.T, logs string) {
assert.Contains(t, logs, "Hello SIGTERM!")
assert.Contains(t, logs, "I've been TERMINATED!")
},
},
} {
t.Run(name, func(t *testing.T) {
// run the runtime interface emulator, capture the logs for assertion
cmd := exec.Command("aws-lambda-rie", "sigterm.handler")
cmd.Env = append([]string{
"PATH=" + testDir,
"AWS_LAMBDA_FUNCTION_TIMEOUT=2",
}, opts.envVars...)
cmd.Stderr = os.Stderr
stdout, err := cmd.StdoutPipe()
require.NoError(t, err)
var logs string
done := make(chan interface{}) // closed on completion of log flush
go func() {
logBytes, err := ioutil.ReadAll(stdout)
require.NoError(t, err)
logs = string(logBytes)
close(done)
}()
require.NoError(t, cmd.Start())
t.Cleanup(func() { _ = cmd.Process.Kill() })

// give a moment for the port to bind
time.Sleep(500 * time.Millisecond)

client := &http.Client{Timeout: 5 * time.Second} // http client timeout to prevent case from hanging on aws-lambda-rie
resp, err := client.Post(rieInvokeAPI, "application/json", strings.NewReader("{}"))
require.NoError(t, err)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, string(body), "Task timed out after 2.00 seconds")

require.NoError(t, cmd.Process.Kill()) // now ensure the logs are drained
<-done
t.Logf("stdout:\n%s", logs)
opts.assertLogs(t, logs)
})
}
}
42 changes: 42 additions & 0 deletions lambda/testdata/sigterm.go
@@ -0,0 +1,42 @@
package main

import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"

"github.com/aws/aws-lambda-go/lambda"
)

func init() {
// conventional SIGTERM callback
signaled := make(chan os.Signal, 1)
signal.Notify(signaled, syscall.SIGTERM)
go func() {
<-signaled
fmt.Println("I've been TERMINATED!")
}()

}

func main() {
// lambda option to enable sigterm, plus optional extra sigterm callbacks
sigtermOption := lambda.WithEnableSIGTERM(func() {
fmt.Println("Hello SIGTERM!")
})
handlerOptions := []lambda.Option{}
if os.Getenv("ENABLE_SIGTERM") != "" {
handlerOptions = append(handlerOptions, sigtermOption)
}
lambda.StartWithOptions(
func(ctx context.Context) {
deadline, _ := ctx.Deadline()
<-time.After(time.Until(deadline) + time.Second)
panic("unreachable line reached!")
},
handlerOptions...,
)
}

0 comments on commit 93199f7

Please sign in to comment.