Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport: allow InTapHandle to return status errors #4365

Merged
merged 5 commits into from May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 32 additions & 0 deletions internal/transport/controlbuf.go
Expand Up @@ -20,13 +20,17 @@ package transport

import (
"bytes"
"errors"
"fmt"
"runtime"
"strconv"
"sync"
"sync/atomic"

"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/status"
)

var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
Expand Down Expand Up @@ -128,6 +132,14 @@ type cleanupStream struct {

func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM

type earlyAbortStream struct {
streamID uint32
contentSubtype string
status *status.Status
}

func (*earlyAbortStream) isTransportResponseFrame() bool { return false }

type dataFrame struct {
streamID uint32
endStream bool
Expand Down Expand Up @@ -749,6 +761,24 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
return nil
}

func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error {
if l.side == clientSide {
return errors.New("earlyAbortStream not handled on client")
}

headerFields := []hpack.HeaderField{
{Name: ":status", Value: "200"},
{Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)},
{Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))},
{Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())},
}

if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil {
return err
}
return nil
}

func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error {
if l.side == clientSide {
l.draining = true
Expand Down Expand Up @@ -787,6 +817,8 @@ func (l *loopyWriter) handle(i interface{}) error {
return l.registerStreamHandler(i)
case *cleanupStream:
return l.cleanupStreamHandler(i)
case *earlyAbortStream:
return l.earlyAbortStreamHandler(i)
case *incomingGoAway:
return l.incomingGoAwayHandler(i)
case *dataFrame:
Expand Down
39 changes: 19 additions & 20 deletions internal/transport/http2_server.go
Expand Up @@ -356,26 +356,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
if state.data.statsTrace != nil {
s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace)
}
if t.inTapHandle != nil {
var err error
info := &tap.Info{
FullMethodName: state.data.method,
}
s.ctx, err = t.inTapHandle(s.ctx, info)
if err != nil {
if logger.V(logLevel) {
logger.Warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
}
t.controlBuf.put(&cleanupStream{
streamID: s.id,
rst: true,
rstCode: http2.ErrCodeRefusedStream,
onWrite: func() {},
})
s.cancel()
return false
}
}
t.mu.Lock()
if t.state != reachable {
t.mu.Unlock()
Expand Down Expand Up @@ -417,6 +397,25 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.cancel()
return false
}
if t.inTapHandle != nil {
var err error
if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: state.data.method}); err != nil {
t.mu.Unlock()
if logger.V(logLevel) {
logger.Infof("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
}
stat, ok := status.FromError(err)
if !ok {
stat = status.New(codes.PermissionDenied, err.Error())
}
t.controlBuf.put(&earlyAbortStream{
streamID: s.id,
contentSubtype: s.contentSubtype,
status: stat,
})
return false
}
}
t.activeStreams[streamID] = s
if len(t.activeStreams) == 1 {
t.idle = time.Time{}
Expand Down
5 changes: 5 additions & 0 deletions server.go
Expand Up @@ -376,6 +376,11 @@ func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOptio

// InTapHandle returns a ServerOption that sets the tap handle for all the server
// transport to be created. Only one can be installed.
//
// Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func InTapHandle(h tap.ServerInHandle) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
if o.inTapHandle != nil {
Expand Down
16 changes: 8 additions & 8 deletions tap/tap.go
Expand Up @@ -37,16 +37,16 @@ type Info struct {
// TODO: More to be added.
}

// ServerInHandle defines the function which runs before a new stream is created
// on the server side. If it returns a non-nil error, the stream will not be
// created and a RST_STREAM will be sent back to the client with REFUSED_STREAM.
// The client will receive an RPC error "code = Unavailable, desc = stream
// terminated by RST_STREAM with error code: REFUSED_STREAM".
// ServerInHandle defines the function which runs before a new stream is
// created on the server side. If it returns a non-nil error, the stream will
// not be created and an error will be returned to the client. If the error
// returned is a status error, that status code and message will be used,
// otherwise PermissionDenied will be the code and err.Error() will be the
// message.
//
// It's intended to be used in situations where you don't want to waste the
// resources to accept the new stream (e.g. rate-limiting). And the content of
// the error will be ignored and won't be sent back to the client. For other
// general usages, please use interceptors.
// resources to accept the new stream (e.g. rate-limiting). For other general
// usages, please use interceptors.
//
// Note that it is executed in the per-connection I/O goroutine(s) instead of
// per-RPC goroutine. Therefore, users should NOT have any
Expand Down
167 changes: 99 additions & 68 deletions test/end2end_test.go
Expand Up @@ -2427,10 +2427,13 @@ type myTap struct {

func (t *myTap) handle(ctx context.Context, info *tap.Info) (context.Context, error) {
if info != nil {
if info.FullMethodName == "/grpc.testing.TestService/EmptyCall" {
switch info.FullMethodName {
case "/grpc.testing.TestService/EmptyCall":
t.cnt++
} else if info.FullMethodName == "/grpc.testing.TestService/UnaryCall" {
case "/grpc.testing.TestService/UnaryCall":
return nil, fmt.Errorf("tap error")
case "/grpc.testing.TestService/FullDuplexCall":
return nil, status.Errorf(codes.FailedPrecondition, "test custom error")
}
}
return ctx, nil
Expand Down Expand Up @@ -2470,8 +2473,15 @@ func testTap(t *testing.T, e env) {
ResponseSize: 45,
Payload: payload,
}
if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.Unavailable {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.Unavailable)
if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.PermissionDenied {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.PermissionDenied)
}
str, err := tc.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("Unexpected error creating stream: %v", err)
}
if _, err := str.Recv(); status.Code(err) != codes.FailedPrecondition {
t.Fatalf("FullDuplexCall Recv() = _, %v, want _, %s", err, codes.FailedPrecondition)
}
}

Expand Down Expand Up @@ -3559,66 +3569,77 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) {
}
}

// Tests that the client transparently retries correctly when receiving a
// RST_STREAM with code REFUSED_STREAM.
func (s) TestTransparentRetry(t *testing.T) {
for _, e := range listTestEnv() {
if e.name == "handler-tls" {
// Fails with RST_STREAM / FLOW_CONTROL_ERROR
continue
}
testTransparentRetry(t, e)
}
}

// This test makes sure RPCs are retried times when they receive a RST_STREAM
// with the REFUSED_STREAM error code, which the InTapHandle provokes.
func testTransparentRetry(t *testing.T, e env) {
te := newTest(t, e)
attempts := 0
successAttempt := 2
te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
attempts++
if attempts < successAttempt {
return nil, errors.New("not now")
}
return ctx, nil
}
te.startServer(&testServer{security: e.security})
defer te.tearDown()

cc := te.clientConn()
tsc := testpb.NewTestServiceClient(cc)
testCases := []struct {
successAttempt int
failFast bool
errCode codes.Code
failFast bool
errCode codes.Code
}{{
successAttempt: 1,
// success attempt: 1, (stream ID 1)
}, {
successAttempt: 2,
// success attempt: 2, (stream IDs 3, 5)
}, {
successAttempt: 3,
errCode: codes.Unavailable,
// no success attempt (stream IDs 7, 9)
errCode: codes.Unavailable,
}, {
successAttempt: 1,
failFast: true,
// success attempt: 1 (stream ID 11),
failFast: true,
}, {
successAttempt: 2,
failFast: true,
// success attempt: 2 (stream IDs 13, 15),
failFast: true,
}, {
successAttempt: 3,
failFast: true,
errCode: codes.Unavailable,
// no success attempt (stream IDs 17, 19)
failFast: true,
errCode: codes.Unavailable,
}}
for _, tc := range testCases {
attempts = 0
successAttempt = tc.successAttempt

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := tsc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(!tc.failFast))
cancel()
if status.Code(err) != tc.errCode {
t.Errorf("%+v: tsc.EmptyCall(_, _) = _, %v, want _, Code=%v", tc, err, tc.errCode)
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen. Err: %v", err)
}
defer lis.Close()
server := &httpServer{
headerFields: [][]string{{
":status", "200",
"content-type", "application/grpc",
"grpc-status", "0",
}},
refuseStream: func(i uint32) bool {
switch i {
case 1, 5, 11, 15: // these stream IDs succeed
return false
}
return true // these are refused
},
}
server.start(t, lis)
cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
if err != nil {
t.Fatalf("failed to dial due to err: %v", err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

client := testpb.NewTestServiceClient(cc)

for i, tc := range testCases {
stream, err := client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("error creating stream due to err: %v", err)
}
code := func(err error) codes.Code {
if err == io.EOF {
return codes.OK
}
return status.Code(err)
}
if _, err := stream.Recv(); code(err) != tc.errCode {
t.Fatalf("%v: stream.Recv() = _, %v, want error code: %v", i, err, tc.errCode)
}

}
}

Expand Down Expand Up @@ -7043,6 +7064,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingMoreThanTwoHeaders(t *testing.T) {

type httpServer struct {
headerFields [][]string
refuseStream func(uint32) bool
}

func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error {
Expand Down Expand Up @@ -7090,24 +7112,33 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) {
writer.Flush() // necessary since client is expecting preface before declaring connection fully setup.

var sid uint32
// Read frames until a header is received.
// Loop until conn is closed and framer returns io.EOF
for {
frame, err := framer.ReadFrame()
if err != nil {
t.Errorf("Error at server-side while reading frame. Err: %v", err)
return
}
if hframe, ok := frame.(*http2.HeadersFrame); ok {
sid = hframe.Header().StreamID
break
// Read frames until a header is received.
for {
frame, err := framer.ReadFrame()
if err != nil {
if err != io.EOF {
t.Errorf("Error at server-side while reading frame. Err: %v", err)
}
return
}
if hframe, ok := frame.(*http2.HeadersFrame); ok {
sid = hframe.Header().StreamID
if s.refuseStream == nil || !s.refuseStream(sid) {
break
}
framer.WriteRSTStream(sid, http2.ErrCodeRefusedStream)
writer.Flush()
}
}
}
for i, headers := range s.headerFields {
if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil {
t.Errorf("Error at server-side while writing headers. Err: %v", err)
return
for i, headers := range s.headerFields {
if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil {
t.Errorf("Error at server-side while writing headers. Err: %v", err)
return
}
writer.Flush()
}
writer.Flush()
}
}()
}
Expand Down