Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eventstream: adds middleware to close input-writer in case of error #3867

Merged
merged 10 commits into from Apr 27, 2021
2 changes: 1 addition & 1 deletion aws/corehandlers/handlers.go
Expand Up @@ -178,7 +178,7 @@ func handleSendError(r *request.Request, err error) {
var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil)
r.Error = awserr.New("UnknownError", "unknown error", r.Error)
}
}}

Expand Down
12 changes: 10 additions & 2 deletions private/model/api/eventstream_tmpl.go
Expand Up @@ -214,6 +214,14 @@ func (es *{{ $esapi.Name }}) waitStreamPartClose() {
es.inputWriter = inputWriter
}

// Closes the input-pipe writer
func (es *{{ $esapi.Name }}) closeInputPipe() error {
if es.inputWriter != nil {
return es.inputWriter.Close()
}
return nil
}

// Send writes the event to the stream blocking until the event is written.
// Returns an error if the event was not written.
//
Expand Down Expand Up @@ -400,8 +408,8 @@ func (es *{{ $esapi.Name }}) safeClose() {
case <-t.C:
case <-writeCloseDone:
}
if es.inputWriter != nil {
es.inputWriter.Close()
if err := es.closeInputPipe(); err != nil {
es.err.SetError(err)
}
{{- end }}

Expand Down
9 changes: 9 additions & 0 deletions private/model/api/operation.go
Expand Up @@ -248,6 +248,15 @@ func (c *{{ .API.StructName }}) {{ .ExportedName }}Request(` +
{{- if $inputStream }}

req.Handlers.Sign.PushFront(es.setupInputPipe)
req.Handlers.UnmarshalError.PushBackNamed(request.NamedHandler{
Name: "InputPipeCloser",
Fn: func (r *request.Request) {
err := es.closeInputPipe()
if err != nil {
r.Error = awserr.New(eventstreamapi.InputWriterCloseErrorCode, err.Error(), r.Error)
}
},
})
req.Handlers.Build.PushBack(request.WithSetRequestHeaders(map[string]string{
"Content-Type": "application/vnd.amazon.eventstream",
"X-Amz-Content-Sha256": "STREAMING-AWS4-HMAC-SHA256-EVENTS",
Expand Down
4 changes: 4 additions & 0 deletions private/protocol/eventstream/eventstreamapi/error.go
Expand Up @@ -5,6 +5,10 @@ import (
"sync"
)

// InputWriterCloseErrorCode is used to denote an error occurred
// while closing the event stream input writer.
const InputWriterCloseErrorCode = "EventStreamInputWriterCloseError"

type messageError struct {
code string
msg string
Expand Down
46 changes: 0 additions & 46 deletions private/protocol/eventstream/eventstreamapi/writer.go
Expand Up @@ -61,49 +61,3 @@ func (w *EventWriter) marshal(event Marshaler) (eventstream.Message, error) {
msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType))
return msg, nil
}

//type EventEncoder struct {
// encoder Encoder
// ppayloadMarshaler protocol.PayloadMarshaler
// eventTypeFor func(Marshaler) (string, error)
//}
//
//func (e EventEncoder) Encode(event Marshaler) error {
// msg, err := e.marshal(event)
// if err != nil {
// return err
// }
//
// return w.encoder.Encode(msg)
//}
//
//func (e EventEncoder) marshal(event Marshaler) (eventstream.Message, error) {
// eventType, err := w.eventTypeFor(event)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg, err := event.MarshalEvent(w.payloadMarshaler)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType))
// return msg, nil
//}
//
//func (w *EventWriter) marshal(event Marshaler) (eventstream.Message, error) {
// eventType, err := w.eventTypeFor(event)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg, err := event.MarshalEvent(w.payloadMarshaler)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType))
// return msg, nil
//}
//
21 changes: 19 additions & 2 deletions service/lexruntimev2/api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 38 additions & 4 deletions service/transcribestreamingservice/api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

69 changes: 69 additions & 0 deletions service/transcribestreamingservice/unit_test.go
@@ -0,0 +1,69 @@
// +build go1.10

package transcribestreamingservice

import (
"bytes"
"io/ioutil"
"net/http"
"strings"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
)

type roundTripFunc func(req *http.Request) *http.Response

func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}

func newTestClient(fn roundTripFunc) *http.Client {
return &http.Client{
Transport: fn,
}
}

func TestStartStreamTranscription_Error(t *testing.T) {
cfg := &aws.Config{
Region: aws.String("us-west-2"),
Credentials: credentials.AnonymousCredentials,
HTTPClient: newTestClient(func(req *http.Request) *http.Response {
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(bytes.NewReader([]byte("{ \"code\" : \"BadRequestException\" }"))),
Header: http.Header{},
}
}),
}
sess, err := session.NewSession(cfg)

svc := New(sess)
resp, err := svc.StartStreamTranscription(&StartStreamTranscriptionInput{
LanguageCode: aws.String(LanguageCodeEnUs),
MediaEncoding: aws.String(MediaEncodingPcm),
MediaSampleRateHertz: aws.Int64(int64(16000)),
})
if err == nil {
t.Fatalf("expect error, got none")
} else {
if e, a := "BadRequestException", err.Error(); !strings.Contains(a, e) {
t.Fatalf("expected error to be %v, got %v", e, a)
}
}

n, err := resp.GetStream().inputWriter.Write([]byte("text"))
if err == nil {
t.Fatalf("expected error stating write on closed pipe, got none")
}

if e, a := "write on closed pipe", err.Error(); !strings.Contains(a, e) {
t.Fatalf("expected error to contain %v, got error as %v", e, a)
}

if e, a := 0, n; e != a {
t.Fatalf("expected %d bytes to be written on inputWriter, but %v bytes were written", e, a)
}
}