Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metadata: validate metadata keys and values #4886

Merged
merged 26 commits into from Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions stream.go
Expand Up @@ -1481,8 +1481,7 @@ func (ss *serverStream) SetTrailer(md metadata.MD) {
return
}
if err := imetadata.Validate(md); err != nil {
logger.Errorf("stream: failed to set trailer, err: %v", err)
return
logger.Errorf("stream: failed to validate md when setting trailer, err: %v", err)
}
ss.s.SetTrailer(md)
}
Expand Down
54 changes: 26 additions & 28 deletions test/metadata_test.go
Expand Up @@ -20,6 +20,7 @@ package test

import (
"context"
"fmt"
"io"
"reflect"
"testing"
Expand All @@ -38,22 +39,27 @@ func (s) TestInvalidMetadata(t *testing.T) {
tests := []struct {
md metadata.MD
want error
recv error
}{
{
md: map[string][]string{string(rune(0x19)): {"testVal"}},
want: status.Error(codes.Internal, "header key \"\\x19\" contains illegal characters not in [0-9a-z-_.]"),
recv: status.Error(codes.Internal, "invalid header field name \"\\x19\""),
},
{
md: map[string][]string{"test": {string(rune(0x19))}},
want: status.Error(codes.Internal, "header key \"test\" contains value with non-printable ASCII characters"),
recv: status.Error(codes.Internal, "invalid header field value \"\\x19\""),
},
{
md: map[string][]string{"test-bin": {string(rune(0x19))}},
want: nil,
recv: io.EOF,
},
{
md: map[string][]string{"test": {"value"}},
want: nil,
recv: io.EOF,
},
}

Expand All @@ -63,28 +69,20 @@ func (s) TestInvalidMetadata(t *testing.T) {
return &testpb.Empty{}, nil
},
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
for {
_, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
test := tests[testNum]
testNum = testNum + 1
if err := stream.SetHeader(test.md); !reflect.DeepEqual(test.want, err) {
t.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
}
if err := stream.SendHeader(test.md); !reflect.DeepEqual(test.want, err) {
t.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
}
stream.SetTrailer(test.md)
err = stream.Send(&testpb.StreamingOutputCallResponse{})
if err != nil {
return err
}
_, err := stream.Recv()
if err != nil {
return err
}
test := tests[testNum]
testNum = testNum + 1
Patrick0308 marked this conversation as resolved.
Show resolved Hide resolved
if err := stream.SetHeader(test.md); !reflect.DeepEqual(test.want, err) {
return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
}
if err := stream.SendHeader(test.md); !reflect.DeepEqual(test.want, err) {
return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
}
stream.SetTrailer(test.md)
return nil
},
}
if err := ss.Start(nil); err != nil {
Expand All @@ -98,23 +96,23 @@ func (s) TestInvalidMetadata(t *testing.T) {

ctx = metadata.NewOutgoingContext(ctx, test.md)
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(test.want, err) {
Patrick0308 marked this conversation as resolved.
Show resolved Hide resolved
t.Fatalf("call ss.Client.EmptyCall() validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
t.Errorf("call ss.Client.EmptyCall() validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
}
}

for range tests {
// call the stream server's api to drive the server-side unit testing
for _, test := range tests {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
stream, err := ss.Client.FullDuplexCall(ctx, grpc.WaitForReady(true))
Patrick0308 marked this conversation as resolved.
Show resolved Hide resolved
defer cancel()
if err != nil {
t.Fatalf("call ss.Client.FullDuplexCall(context.Background()) will success but got err :%v", err)
t.Errorf("call ss.Client.FullDuplexCall(context.Background()) will success but got err :%v", err)
Patrick0308 marked this conversation as resolved.
Show resolved Hide resolved
}
if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Fatalf("call ss.Client stream Send(nil) will success but got err :%v", err)
t.Errorf("call ss.Client stream Send(nil) will success but got err :%v", err)
}
if _, err := stream.Recv(); err != nil {
t.Fatalf("stream.Recv() = _, %v", err)
if _, err := stream.Recv(); !reflect.DeepEqual(test.recv, err) {
t.Errorf("stream.Recv() = _, get err :%v, want err :%v", err, test.recv)
}
stream.CloseSend()
}
}