diff --git a/aws/corehandlers/handlers.go b/aws/corehandlers/handlers.go index d95a5eb540..36a915efea 100644 --- a/aws/corehandlers/handlers.go +++ b/aws/corehandlers/handlers.go @@ -178,7 +178,7 @@ func handleSendError(r *request.Request, err error) { var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) { if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 { // this may be replaced by an UnmarshalError handler - r.Error = awserr.New("UnknownError", "unknown error", nil) + r.Error = awserr.New("UnknownError", "unknown error", r.Error) } }} diff --git a/private/model/api/eventstream_tmpl.go b/private/model/api/eventstream_tmpl.go index ef72b225f4..1f3c58f311 100644 --- a/private/model/api/eventstream_tmpl.go +++ b/private/model/api/eventstream_tmpl.go @@ -214,6 +214,14 @@ func (es *{{ $esapi.Name }}) waitStreamPartClose() { es.inputWriter = inputWriter } + // Closes the input-pipe writer + func (es *{{ $esapi.Name }}) closeInputPipe() error { + if es.inputWriter != nil { + return es.inputWriter.Close() + } + return nil + } + // Send writes the event to the stream blocking until the event is written. // Returns an error if the event was not written. // @@ -400,8 +408,8 @@ func (es *{{ $esapi.Name }}) safeClose() { case <-t.C: case <-writeCloseDone: } - if es.inputWriter != nil { - es.inputWriter.Close() + if err := es.closeInputPipe(); err != nil { + es.err.SetError(err) } {{- end }} diff --git a/private/model/api/operation.go b/private/model/api/operation.go index 0df3f52b93..4f9cfcbff5 100644 --- a/private/model/api/operation.go +++ b/private/model/api/operation.go @@ -248,6 +248,15 @@ func (c *{{ .API.StructName }}) {{ .ExportedName }}Request(` + {{- if $inputStream }} req.Handlers.Sign.PushFront(es.setupInputPipe) + req.Handlers.UnmarshalError.PushBackNamed(request.NamedHandler{ + Name: "InputPipeCloser", + Fn: func (r *request.Request) { + err := es.closeInputPipe() + if err != nil { + r.Error = awserr.New(eventstreamapi.InputWriterCloseErrorCode, err.Error(), r.Error) + } + }, + }) req.Handlers.Build.PushBack(request.WithSetRequestHeaders(map[string]string{ "Content-Type": "application/vnd.amazon.eventstream", "X-Amz-Content-Sha256": "STREAMING-AWS4-HMAC-SHA256-EVENTS", diff --git a/private/protocol/eventstream/eventstreamapi/error.go b/private/protocol/eventstream/eventstreamapi/error.go index 34c2e89d53..0a63340e41 100644 --- a/private/protocol/eventstream/eventstreamapi/error.go +++ b/private/protocol/eventstream/eventstreamapi/error.go @@ -5,6 +5,10 @@ import ( "sync" ) +// InputWriterCloseErrorCode is used to denote an error occurred +// while closing the event stream input writer. +const InputWriterCloseErrorCode = "EventStreamInputWriterCloseError" + type messageError struct { code string msg string diff --git a/private/protocol/eventstream/eventstreamapi/writer.go b/private/protocol/eventstream/eventstreamapi/writer.go index 10a3823dfa..7d7a793528 100644 --- a/private/protocol/eventstream/eventstreamapi/writer.go +++ b/private/protocol/eventstream/eventstreamapi/writer.go @@ -61,49 +61,3 @@ func (w *EventWriter) marshal(event Marshaler) (eventstream.Message, error) { msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType)) return msg, nil } - -//type EventEncoder struct { -// encoder Encoder -// ppayloadMarshaler protocol.PayloadMarshaler -// eventTypeFor func(Marshaler) (string, error) -//} -// -//func (e EventEncoder) Encode(event Marshaler) error { -// msg, err := e.marshal(event) -// if err != nil { -// return err -// } -// -// return w.encoder.Encode(msg) -//} -// -//func (e EventEncoder) marshal(event Marshaler) (eventstream.Message, error) { -// eventType, err := w.eventTypeFor(event) -// if err != nil { -// return eventstream.Message{}, err -// } -// -// msg, err := event.MarshalEvent(w.payloadMarshaler) -// if err != nil { -// return eventstream.Message{}, err -// } -// -// msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType)) -// return msg, nil -//} -// -//func (w *EventWriter) marshal(event Marshaler) (eventstream.Message, error) { -// eventType, err := w.eventTypeFor(event) -// if err != nil { -// return eventstream.Message{}, err -// } -// -// msg, err := event.MarshalEvent(w.payloadMarshaler) -// if err != nil { -// return eventstream.Message{}, err -// } -// -// msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType)) -// return msg, nil -//} -// diff --git a/service/lexruntimev2/api.go b/service/lexruntimev2/api.go index 5193f499ab..85c0c26fa2 100644 --- a/service/lexruntimev2/api.go +++ b/service/lexruntimev2/api.go @@ -552,6 +552,15 @@ func (c *LexRuntimeV2) StartConversationRequest(input *StartConversationInput) ( output.eventStream = es req.Handlers.Sign.PushFront(es.setupInputPipe) + req.Handlers.UnmarshalError.PushBackNamed(request.NamedHandler{ + Name: "InputPipeCloser", + Fn: func(r *request.Request) { + err := es.closeInputPipe() + if err != nil { + r.Error = awserr.New(eventstreamapi.InputWriterCloseErrorCode, err.Error(), r.Error) + } + }, + }) req.Handlers.Build.PushBack(request.WithSetRequestHeaders(map[string]string{ "Content-Type": "application/vnd.amazon.eventstream", "X-Amz-Content-Sha256": "STREAMING-AWS4-HMAC-SHA256-EVENTS", @@ -715,6 +724,14 @@ func (es *StartConversationEventStream) setupInputPipe(r *request.Request) { es.inputWriter = inputWriter } +// Closes the input-pipe writer +func (es *StartConversationEventStream) closeInputPipe() error { + if es.inputWriter != nil { + return es.inputWriter.Close() + } + return nil +} + // Send writes the event to the stream blocking until the event is written. // Returns an error if the event was not written. // @@ -838,8 +855,8 @@ func (es *StartConversationEventStream) safeClose() { case <-t.C: case <-writeCloseDone: } - if es.inputWriter != nil { - es.inputWriter.Close() + if err := es.closeInputPipe(); err != nil { + es.err.SetError(err) } es.Reader.Close() diff --git a/service/transcribestreamingservice/api.go b/service/transcribestreamingservice/api.go index 0d9bba80bc..d01b8782d9 100644 --- a/service/transcribestreamingservice/api.go +++ b/service/transcribestreamingservice/api.go @@ -69,6 +69,15 @@ func (c *TranscribeStreamingService) StartMedicalStreamTranscriptionRequest(inpu output.eventStream = es req.Handlers.Sign.PushFront(es.setupInputPipe) + req.Handlers.UnmarshalError.PushBackNamed(request.NamedHandler{ + Name: "InputPipeCloser", + Fn: func(r *request.Request) { + err := es.closeInputPipe() + if err != nil { + r.Error = awserr.New(eventstreamapi.InputWriterCloseErrorCode, err.Error(), r.Error) + } + }, + }) req.Handlers.Build.PushBack(request.WithSetRequestHeaders(map[string]string{ "Content-Type": "application/vnd.amazon.eventstream", "X-Amz-Content-Sha256": "STREAMING-AWS4-HMAC-SHA256-EVENTS", @@ -245,6 +254,14 @@ func (es *StartMedicalStreamTranscriptionEventStream) setupInputPipe(r *request. es.inputWriter = inputWriter } +// Closes the input-pipe writer +func (es *StartMedicalStreamTranscriptionEventStream) closeInputPipe() error { + if es.inputWriter != nil { + return es.inputWriter.Close() + } + return nil +} + // Send writes the event to the stream blocking until the event is written. // Returns an error if the event was not written. // @@ -358,8 +375,8 @@ func (es *StartMedicalStreamTranscriptionEventStream) safeClose() { case <-t.C: case <-writeCloseDone: } - if es.inputWriter != nil { - es.inputWriter.Close() + if err := es.closeInputPipe(); err != nil { + es.err.SetError(err) } es.Reader.Close() @@ -431,6 +448,15 @@ func (c *TranscribeStreamingService) StartStreamTranscriptionRequest(input *Star output.eventStream = es req.Handlers.Sign.PushFront(es.setupInputPipe) + req.Handlers.UnmarshalError.PushBackNamed(request.NamedHandler{ + Name: "InputPipeCloser", + Fn: func(r *request.Request) { + err := es.closeInputPipe() + if err != nil { + r.Error = awserr.New(eventstreamapi.InputWriterCloseErrorCode, err.Error(), r.Error) + } + }, + }) req.Handlers.Build.PushBack(request.WithSetRequestHeaders(map[string]string{ "Content-Type": "application/vnd.amazon.eventstream", "X-Amz-Content-Sha256": "STREAMING-AWS4-HMAC-SHA256-EVENTS", @@ -617,6 +643,14 @@ func (es *StartStreamTranscriptionEventStream) setupInputPipe(r *request.Request es.inputWriter = inputWriter } +// Closes the input-pipe writer +func (es *StartStreamTranscriptionEventStream) closeInputPipe() error { + if es.inputWriter != nil { + return es.inputWriter.Close() + } + return nil +} + // Send writes the event to the stream blocking until the event is written. // Returns an error if the event was not written. // @@ -730,8 +764,8 @@ func (es *StartStreamTranscriptionEventStream) safeClose() { case <-t.C: case <-writeCloseDone: } - if es.inputWriter != nil { - es.inputWriter.Close() + if err := es.closeInputPipe(); err != nil { + es.err.SetError(err) } es.Reader.Close() diff --git a/service/transcribestreamingservice/unit_test.go b/service/transcribestreamingservice/unit_test.go new file mode 100644 index 0000000000..7e1d29046f --- /dev/null +++ b/service/transcribestreamingservice/unit_test.go @@ -0,0 +1,69 @@ +// +build go1.10 + +package transcribestreamingservice + +import ( + "bytes" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" +) + +type roundTripFunc func(req *http.Request) *http.Response + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req), nil +} + +func newTestClient(fn roundTripFunc) *http.Client { + return &http.Client{ + Transport: fn, + } +} + +func TestStartStreamTranscription_Error(t *testing.T) { + cfg := &aws.Config{ + Region: aws.String("us-west-2"), + Credentials: credentials.AnonymousCredentials, + HTTPClient: newTestClient(func(req *http.Request) *http.Response { + return &http.Response{ + StatusCode: http.StatusBadRequest, + Body: ioutil.NopCloser(bytes.NewReader([]byte("{ \"code\" : \"BadRequestException\" }"))), + Header: http.Header{}, + } + }), + } + sess, err := session.NewSession(cfg) + + svc := New(sess) + resp, err := svc.StartStreamTranscription(&StartStreamTranscriptionInput{ + LanguageCode: aws.String(LanguageCodeEnUs), + MediaEncoding: aws.String(MediaEncodingPcm), + MediaSampleRateHertz: aws.Int64(int64(16000)), + }) + if err == nil { + t.Fatalf("expect error, got none") + } else { + if e, a := "BadRequestException", err.Error(); !strings.Contains(a, e) { + t.Fatalf("expected error to be %v, got %v", e, a) + } + } + + n, err := resp.GetStream().inputWriter.Write([]byte("text")) + if err == nil { + t.Fatalf("expected error stating write on closed pipe, got none") + } + + if e, a := "write on closed pipe", err.Error(); !strings.Contains(a, e) { + t.Fatalf("expected error to contain %v, got error as %v", e, a) + } + + if e, a := 0, n; e != a { + t.Fatalf("expected %d bytes to be written on inputWriter, but %v bytes were written", e, a) + } +}