Skip to content

Commit

Permalink
eventstream: adds middleware to close input-writer in case of error (#…
Browse files Browse the repository at this point in the history
…3867)

* eventream template change: add unmarshal error handler to close inputWriter io.writer in case of an error for the first connection request

* make generate changes

* hand-written unit test to test for error case io.write resource is freed

* remove dead code

* bug fix: set original error for validate response handler error

* use awserr error type for any error occuring when closing input writer for eventstream

* use ioutil.NopCloser instead of io.NopCloser to satisfy older go versions

* use anonymous creds for test

* eventstream: update template for code gen

* update test and generate service clients
  • Loading branch information
skotambkar committed Apr 27, 2021
1 parent e8bafb8 commit d265de1
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 55 deletions.
2 changes: 1 addition & 1 deletion aws/corehandlers/handlers.go
Expand Up @@ -178,7 +178,7 @@ func handleSendError(r *request.Request, err error) {
var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil)
r.Error = awserr.New("UnknownError", "unknown error", r.Error)
}
}}

Expand Down
12 changes: 10 additions & 2 deletions private/model/api/eventstream_tmpl.go
Expand Up @@ -214,6 +214,14 @@ func (es *{{ $esapi.Name }}) waitStreamPartClose() {
es.inputWriter = inputWriter
}
// Closes the input-pipe writer
func (es *{{ $esapi.Name }}) closeInputPipe() error {
if es.inputWriter != nil {
return es.inputWriter.Close()
}
return nil
}
// Send writes the event to the stream blocking until the event is written.
// Returns an error if the event was not written.
//
Expand Down Expand Up @@ -400,8 +408,8 @@ func (es *{{ $esapi.Name }}) safeClose() {
case <-t.C:
case <-writeCloseDone:
}
if es.inputWriter != nil {
es.inputWriter.Close()
if err := es.closeInputPipe(); err != nil {
es.err.SetError(err)
}
{{- end }}
Expand Down
9 changes: 9 additions & 0 deletions private/model/api/operation.go
Expand Up @@ -248,6 +248,15 @@ func (c *{{ .API.StructName }}) {{ .ExportedName }}Request(` +
{{- if $inputStream }}
req.Handlers.Sign.PushFront(es.setupInputPipe)
req.Handlers.UnmarshalError.PushBackNamed(request.NamedHandler{
Name: "InputPipeCloser",
Fn: func (r *request.Request) {
err := es.closeInputPipe()
if err != nil {
r.Error = awserr.New(eventstreamapi.InputWriterCloseErrorCode, err.Error(), r.Error)
}
},
})
req.Handlers.Build.PushBack(request.WithSetRequestHeaders(map[string]string{
"Content-Type": "application/vnd.amazon.eventstream",
"X-Amz-Content-Sha256": "STREAMING-AWS4-HMAC-SHA256-EVENTS",
Expand Down
4 changes: 4 additions & 0 deletions private/protocol/eventstream/eventstreamapi/error.go
Expand Up @@ -5,6 +5,10 @@ import (
"sync"
)

// InputWriterCloseErrorCode is used to denote an error occurred
// while closing the event stream input writer.
const InputWriterCloseErrorCode = "EventStreamInputWriterCloseError"

type messageError struct {
code string
msg string
Expand Down
46 changes: 0 additions & 46 deletions private/protocol/eventstream/eventstreamapi/writer.go
Expand Up @@ -61,49 +61,3 @@ func (w *EventWriter) marshal(event Marshaler) (eventstream.Message, error) {
msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType))
return msg, nil
}

//type EventEncoder struct {
// encoder Encoder
// ppayloadMarshaler protocol.PayloadMarshaler
// eventTypeFor func(Marshaler) (string, error)
//}
//
//func (e EventEncoder) Encode(event Marshaler) error {
// msg, err := e.marshal(event)
// if err != nil {
// return err
// }
//
// return w.encoder.Encode(msg)
//}
//
//func (e EventEncoder) marshal(event Marshaler) (eventstream.Message, error) {
// eventType, err := w.eventTypeFor(event)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg, err := event.MarshalEvent(w.payloadMarshaler)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType))
// return msg, nil
//}
//
//func (w *EventWriter) marshal(event Marshaler) (eventstream.Message, error) {
// eventType, err := w.eventTypeFor(event)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg, err := event.MarshalEvent(w.payloadMarshaler)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType))
// return msg, nil
//}
//
21 changes: 19 additions & 2 deletions service/lexruntimev2/api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 38 additions & 4 deletions service/transcribestreamingservice/api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

69 changes: 69 additions & 0 deletions service/transcribestreamingservice/unit_test.go
@@ -0,0 +1,69 @@
// +build go1.10

package transcribestreamingservice

import (
"bytes"
"io/ioutil"
"net/http"
"strings"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
)

type roundTripFunc func(req *http.Request) *http.Response

func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}

func newTestClient(fn roundTripFunc) *http.Client {
return &http.Client{
Transport: fn,
}
}

func TestStartStreamTranscription_Error(t *testing.T) {
cfg := &aws.Config{
Region: aws.String("us-west-2"),
Credentials: credentials.AnonymousCredentials,
HTTPClient: newTestClient(func(req *http.Request) *http.Response {
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(bytes.NewReader([]byte("{ \"code\" : \"BadRequestException\" }"))),
Header: http.Header{},
}
}),
}
sess, err := session.NewSession(cfg)

svc := New(sess)
resp, err := svc.StartStreamTranscription(&StartStreamTranscriptionInput{
LanguageCode: aws.String(LanguageCodeEnUs),
MediaEncoding: aws.String(MediaEncodingPcm),
MediaSampleRateHertz: aws.Int64(int64(16000)),
})
if err == nil {
t.Fatalf("expect error, got none")
} else {
if e, a := "BadRequestException", err.Error(); !strings.Contains(a, e) {
t.Fatalf("expected error to be %v, got %v", e, a)
}
}

n, err := resp.GetStream().inputWriter.Write([]byte("text"))
if err == nil {
t.Fatalf("expected error stating write on closed pipe, got none")
}

if e, a := "write on closed pipe", err.Error(); !strings.Contains(a, e) {
t.Fatalf("expected error to contain %v, got error as %v", e, a)
}

if e, a := 0, n; e != a {
t.Fatalf("expected %d bytes to be written on inputWriter, but %v bytes were written", e, a)
}
}

0 comments on commit d265de1

Please sign in to comment.