Skip to content

Commit

Permalink
transport: allow InTapHandle to return status errors (#4365)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfawley committed May 7, 2021
1 parent aff517b commit 328b1d1
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 96 deletions.
32 changes: 32 additions & 0 deletions internal/transport/controlbuf.go
Expand Up @@ -20,13 +20,17 @@ package transport

import (
"bytes"
"errors"
"fmt"
"runtime"
"strconv"
"sync"
"sync/atomic"

"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/status"
)

var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
Expand Down Expand Up @@ -128,6 +132,14 @@ type cleanupStream struct {

func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM

type earlyAbortStream struct {
streamID uint32
contentSubtype string
status *status.Status
}

func (*earlyAbortStream) isTransportResponseFrame() bool { return false }

type dataFrame struct {
streamID uint32
endStream bool
Expand Down Expand Up @@ -749,6 +761,24 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
return nil
}

func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error {
if l.side == clientSide {
return errors.New("earlyAbortStream not handled on client")
}

headerFields := []hpack.HeaderField{
{Name: ":status", Value: "200"},
{Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)},
{Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))},
{Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())},
}

if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil {
return err
}
return nil
}

func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error {
if l.side == clientSide {
l.draining = true
Expand Down Expand Up @@ -787,6 +817,8 @@ func (l *loopyWriter) handle(i interface{}) error {
return l.registerStreamHandler(i)
case *cleanupStream:
return l.cleanupStreamHandler(i)
case *earlyAbortStream:
return l.earlyAbortStreamHandler(i)
case *incomingGoAway:
return l.incomingGoAwayHandler(i)
case *dataFrame:
Expand Down
39 changes: 19 additions & 20 deletions internal/transport/http2_server.go
Expand Up @@ -356,26 +356,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
if state.data.statsTrace != nil {
s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace)
}
if t.inTapHandle != nil {
var err error
info := &tap.Info{
FullMethodName: state.data.method,
}
s.ctx, err = t.inTapHandle(s.ctx, info)
if err != nil {
if logger.V(logLevel) {
logger.Warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
}
t.controlBuf.put(&cleanupStream{
streamID: s.id,
rst: true,
rstCode: http2.ErrCodeRefusedStream,
onWrite: func() {},
})
s.cancel()
return false
}
}
t.mu.Lock()
if t.state != reachable {
t.mu.Unlock()
Expand Down Expand Up @@ -417,6 +397,25 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.cancel()
return false
}
if t.inTapHandle != nil {
var err error
if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: state.data.method}); err != nil {
t.mu.Unlock()
if logger.V(logLevel) {
logger.Infof("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
}
stat, ok := status.FromError(err)
if !ok {
stat = status.New(codes.PermissionDenied, err.Error())
}
t.controlBuf.put(&earlyAbortStream{
streamID: s.id,
contentSubtype: s.contentSubtype,
status: stat,
})
return false
}
}
t.activeStreams[streamID] = s
if len(t.activeStreams) == 1 {
t.idle = time.Time{}
Expand Down
5 changes: 5 additions & 0 deletions server.go
Expand Up @@ -418,6 +418,11 @@ func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOptio

// InTapHandle returns a ServerOption that sets the tap handle for all the server
// transport to be created. Only one can be installed.
//
// Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func InTapHandle(h tap.ServerInHandle) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
if o.inTapHandle != nil {
Expand Down
16 changes: 8 additions & 8 deletions tap/tap.go
Expand Up @@ -37,16 +37,16 @@ type Info struct {
// TODO: More to be added.
}

// ServerInHandle defines the function which runs before a new stream is created
// on the server side. If it returns a non-nil error, the stream will not be
// created and a RST_STREAM will be sent back to the client with REFUSED_STREAM.
// The client will receive an RPC error "code = Unavailable, desc = stream
// terminated by RST_STREAM with error code: REFUSED_STREAM".
// ServerInHandle defines the function which runs before a new stream is
// created on the server side. If it returns a non-nil error, the stream will
// not be created and an error will be returned to the client. If the error
// returned is a status error, that status code and message will be used,
// otherwise PermissionDenied will be the code and err.Error() will be the
// message.
//
// It's intended to be used in situations where you don't want to waste the
// resources to accept the new stream (e.g. rate-limiting). And the content of
// the error will be ignored and won't be sent back to the client. For other
// general usages, please use interceptors.
// resources to accept the new stream (e.g. rate-limiting). For other general
// usages, please use interceptors.
//
// Note that it is executed in the per-connection I/O goroutine(s) instead of
// per-RPC goroutine. Therefore, users should NOT have any
Expand Down
167 changes: 99 additions & 68 deletions test/end2end_test.go
Expand Up @@ -2507,10 +2507,13 @@ type myTap struct {

func (t *myTap) handle(ctx context.Context, info *tap.Info) (context.Context, error) {
if info != nil {
if info.FullMethodName == "/grpc.testing.TestService/EmptyCall" {
switch info.FullMethodName {
case "/grpc.testing.TestService/EmptyCall":
t.cnt++
} else if info.FullMethodName == "/grpc.testing.TestService/UnaryCall" {
case "/grpc.testing.TestService/UnaryCall":
return nil, fmt.Errorf("tap error")
case "/grpc.testing.TestService/FullDuplexCall":
return nil, status.Errorf(codes.FailedPrecondition, "test custom error")
}
}
return ctx, nil
Expand Down Expand Up @@ -2550,8 +2553,15 @@ func testTap(t *testing.T, e env) {
ResponseSize: 45,
Payload: payload,
}
if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.Unavailable {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.Unavailable)
if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.PermissionDenied {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.PermissionDenied)
}
str, err := tc.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("Unexpected error creating stream: %v", err)
}
if _, err := str.Recv(); status.Code(err) != codes.FailedPrecondition {
t.Fatalf("FullDuplexCall Recv() = _, %v, want _, %s", err, codes.FailedPrecondition)
}
}

Expand Down Expand Up @@ -3639,66 +3649,77 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) {
}
}

// Tests that the client transparently retries correctly when receiving a
// RST_STREAM with code REFUSED_STREAM.
func (s) TestTransparentRetry(t *testing.T) {
for _, e := range listTestEnv() {
if e.name == "handler-tls" {
// Fails with RST_STREAM / FLOW_CONTROL_ERROR
continue
}
testTransparentRetry(t, e)
}
}

// This test makes sure RPCs are retried times when they receive a RST_STREAM
// with the REFUSED_STREAM error code, which the InTapHandle provokes.
func testTransparentRetry(t *testing.T, e env) {
te := newTest(t, e)
attempts := 0
successAttempt := 2
te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
attempts++
if attempts < successAttempt {
return nil, errors.New("not now")
}
return ctx, nil
}
te.startServer(&testServer{security: e.security})
defer te.tearDown()

cc := te.clientConn()
tsc := testpb.NewTestServiceClient(cc)
testCases := []struct {
successAttempt int
failFast bool
errCode codes.Code
failFast bool
errCode codes.Code
}{{
successAttempt: 1,
// success attempt: 1, (stream ID 1)
}, {
successAttempt: 2,
// success attempt: 2, (stream IDs 3, 5)
}, {
successAttempt: 3,
errCode: codes.Unavailable,
// no success attempt (stream IDs 7, 9)
errCode: codes.Unavailable,
}, {
successAttempt: 1,
failFast: true,
// success attempt: 1 (stream ID 11),
failFast: true,
}, {
successAttempt: 2,
failFast: true,
// success attempt: 2 (stream IDs 13, 15),
failFast: true,
}, {
successAttempt: 3,
failFast: true,
errCode: codes.Unavailable,
// no success attempt (stream IDs 17, 19)
failFast: true,
errCode: codes.Unavailable,
}}
for _, tc := range testCases {
attempts = 0
successAttempt = tc.successAttempt

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := tsc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(!tc.failFast))
cancel()
if status.Code(err) != tc.errCode {
t.Errorf("%+v: tsc.EmptyCall(_, _) = _, %v, want _, Code=%v", tc, err, tc.errCode)
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen. Err: %v", err)
}
defer lis.Close()
server := &httpServer{
headerFields: [][]string{{
":status", "200",
"content-type", "application/grpc",
"grpc-status", "0",
}},
refuseStream: func(i uint32) bool {
switch i {
case 1, 5, 11, 15: // these stream IDs succeed
return false
}
return true // these are refused
},
}
server.start(t, lis)
cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
if err != nil {
t.Fatalf("failed to dial due to err: %v", err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

client := testpb.NewTestServiceClient(cc)

for i, tc := range testCases {
stream, err := client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("error creating stream due to err: %v", err)
}
code := func(err error) codes.Code {
if err == io.EOF {
return codes.OK
}
return status.Code(err)
}
if _, err := stream.Recv(); code(err) != tc.errCode {
t.Fatalf("%v: stream.Recv() = _, %v, want error code: %v", i, err, tc.errCode)
}

}
}

Expand Down Expand Up @@ -7191,6 +7212,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingMoreThanTwoHeaders(t *testing.T) {

type httpServer struct {
headerFields [][]string
refuseStream func(uint32) bool
}

func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error {
Expand Down Expand Up @@ -7238,24 +7260,33 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) {
writer.Flush() // necessary since client is expecting preface before declaring connection fully setup.

var sid uint32
// Read frames until a header is received.
// Loop until conn is closed and framer returns io.EOF
for {
frame, err := framer.ReadFrame()
if err != nil {
t.Errorf("Error at server-side while reading frame. Err: %v", err)
return
}
if hframe, ok := frame.(*http2.HeadersFrame); ok {
sid = hframe.Header().StreamID
break
// Read frames until a header is received.
for {
frame, err := framer.ReadFrame()
if err != nil {
if err != io.EOF {
t.Errorf("Error at server-side while reading frame. Err: %v", err)
}
return
}
if hframe, ok := frame.(*http2.HeadersFrame); ok {
sid = hframe.Header().StreamID
if s.refuseStream == nil || !s.refuseStream(sid) {
break
}
framer.WriteRSTStream(sid, http2.ErrCodeRefusedStream)
writer.Flush()
}
}
}
for i, headers := range s.headerFields {
if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil {
t.Errorf("Error at server-side while writing headers. Err: %v", err)
return
for i, headers := range s.headerFields {
if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil {
t.Errorf("Error at server-side while writing headers. Err: %v", err)
return
}
writer.Flush()
}
writer.Flush()
}
}()
}
Expand Down

0 comments on commit 328b1d1

Please sign in to comment.