From 3a6e7d9e468dea7c33ae06464f8f1274583d7f5d Mon Sep 17 00:00:00 2001 From: kevin Date: Sun, 17 Jul 2022 12:10:22 +0800 Subject: [PATCH 1/2] feat: verify RpcPath on startup --- gateway/config.go | 4 +-- gateway/internal/descriptorsource.go | 34 ++++++++++++++++++++ gateway/internal/descriptorsource_test.go | 29 +++++++++++++++++ gateway/{ => internal}/headerbuilder.go | 5 +-- gateway/{ => internal}/headerbuilder_test.go | 6 ++-- gateway/{ => internal}/requestparser.go | 25 +++++++------- gateway/{ => internal}/requestparser_test.go | 14 ++++---- gateway/readme.md | 4 +-- gateway/server.go | 21 ++++++++++-- 9 files changed, 112 insertions(+), 30 deletions(-) create mode 100644 gateway/internal/descriptorsource.go create mode 100644 gateway/internal/descriptorsource_test.go rename gateway/{ => internal}/headerbuilder.go (76%) rename gateway/{ => internal}/headerbuilder_test.go (76%) rename gateway/{ => internal}/requestparser.go (85%) rename gateway/{ => internal}/requestparser_test.go (82%) diff --git a/gateway/config.go b/gateway/config.go index 684daa599cb8..9c7b1a78b051 100644 --- a/gateway/config.go +++ b/gateway/config.go @@ -21,8 +21,8 @@ type ( Method string // Path is the HTTP path. Path string - // Rpc is the gRPC rpc method, with format of package.service/method - Rpc string + // RpcPath is the gRPC rpc method, with format of package.service/method + RpcPath string } // upstream is the configuration for upstream. diff --git a/gateway/internal/descriptorsource.go b/gateway/internal/descriptorsource.go new file mode 100644 index 000000000000..7925ece4f34f --- /dev/null +++ b/gateway/internal/descriptorsource.go @@ -0,0 +1,34 @@ +package internal + +import ( + "fmt" + + "github.com/fullstorydev/grpcurl" + "github.com/jhump/protoreflect/desc" +) + +// GetMethods returns all methods of the given grpcurl.DescriptorSource. +func GetMethods(source grpcurl.DescriptorSource) ([]string, error) { + svcs, err := source.ListServices() + if err != nil { + return nil, err + } + + var methods []string + for _, svc := range svcs { + d, err := source.FindSymbol(svc) + if err != nil { + return nil, err + } + + switch val := d.(type) { + case *desc.ServiceDescriptor: + svcMethods := val.GetMethods() + for _, method := range svcMethods { + methods = append(methods, fmt.Sprintf("%s/%s", svc, method.GetName())) + } + } + } + + return methods, nil +} diff --git a/gateway/internal/descriptorsource_test.go b/gateway/internal/descriptorsource_test.go new file mode 100644 index 000000000000..8c06607732af --- /dev/null +++ b/gateway/internal/descriptorsource_test.go @@ -0,0 +1,29 @@ +package internal + +import ( + "encoding/base64" + "io/ioutil" + "os" + "testing" + + "github.com/fullstorydev/grpcurl" + "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/hash" +) + +const b64pb = `CpgBCgtoZWxsby5wcm90bxIFaGVsbG8iHQoHUmVxdWVzdBISCgRwaW5nGAEgASgJUgRwaW5nIh4KCFJlc3BvbnNlEhIKBHBvbmcYASABKAlSBHBvbmcyMAoFSGVsbG8SJwoEUGluZxIOLmhlbGxvLlJlcXVlc3QaDy5oZWxsby5SZXNwb25zZUIJWgcuL2hlbGxvYgZwcm90bzM=` + +func TestGetMethods(t *testing.T) { + tmpfile, err := ioutil.TempFile(os.TempDir(), hash.Md5Hex([]byte(b64pb))) + assert.Nil(t, err) + b, err := base64.StdEncoding.DecodeString(b64pb) + assert.Nil(t, err) + assert.Nil(t, ioutil.WriteFile(tmpfile.Name(), b, os.ModeTemporary)) + defer os.Remove(tmpfile.Name()) + + source, err := grpcurl.DescriptorSourceFromProtoSets(tmpfile.Name()) + assert.Nil(t, err) + methods, err := GetMethods(source) + assert.Nil(t, err) + assert.EqualValues(t, []string{"hello.Hello/Ping"}, methods) +} diff --git a/gateway/headerbuilder.go b/gateway/internal/headerbuilder.go similarity index 76% rename from gateway/headerbuilder.go rename to gateway/internal/headerbuilder.go index 918f53d6085a..f5fa6ec603f0 100644 --- a/gateway/headerbuilder.go +++ b/gateway/internal/headerbuilder.go @@ -1,4 +1,4 @@ -package gateway +package internal import ( "fmt" @@ -11,7 +11,8 @@ const ( metadataPrefix = "gateway-" ) -func buildHeaders(header http.Header) []string { +// BuildHeaders builds the headers for the gateway from HTTP headers. +func BuildHeaders(header http.Header) []string { var headers []string for k, v := range header { diff --git a/gateway/headerbuilder_test.go b/gateway/internal/headerbuilder_test.go similarity index 76% rename from gateway/headerbuilder_test.go rename to gateway/internal/headerbuilder_test.go index 32efd254d4f4..883fd2ca4214 100644 --- a/gateway/headerbuilder_test.go +++ b/gateway/internal/headerbuilder_test.go @@ -1,4 +1,4 @@ -package gateway +package internal import ( "net/http/httptest" @@ -10,12 +10,12 @@ import ( func TestBuildHeadersNoValue(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) req.Header.Add("a", "b") - assert.Nil(t, buildHeaders(req.Header)) + assert.Nil(t, BuildHeaders(req.Header)) } func TestBuildHeadersWithValues(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) req.Header.Add("grpc-metadata-a", "b") req.Header.Add("grpc-metadata-b", "b") - assert.EqualValues(t, []string{"gateway-A:b", "gateway-B:b"}, buildHeaders(req.Header)) + assert.EqualValues(t, []string{"gateway-A:b", "gateway-B:b"}, BuildHeaders(req.Header)) } diff --git a/gateway/requestparser.go b/gateway/internal/requestparser.go similarity index 85% rename from gateway/requestparser.go rename to gateway/internal/requestparser.go index 4de6d06ca1fe..344931d76023 100644 --- a/gateway/requestparser.go +++ b/gateway/internal/requestparser.go @@ -1,4 +1,4 @@ -package gateway +package internal import ( "bytes" @@ -11,17 +11,8 @@ import ( "github.com/zeromicro/go-zero/rest/pathvar" ) -func buildJsonRequestParser(m map[string]interface{}, resolver jsonpb.AnyResolver) ( - grpcurl.RequestParser, error) { - var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(m); err != nil { - return nil, err - } - - return grpcurl.NewJSONRequestParser(&buf, resolver), nil -} - -func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) { +// NewRequestParser creates a new request parser from the given http.Request and resolver. +func NewRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) { vars := pathvar.Vars(r) params, err := httpx.GetFormValues(r) if err != nil { @@ -50,3 +41,13 @@ func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.Req return buildJsonRequestParser(m, resolver) } + +func buildJsonRequestParser(m map[string]interface{}, resolver jsonpb.AnyResolver) ( + grpcurl.RequestParser, error) { + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(m); err != nil { + return nil, err + } + + return grpcurl.NewJSONRequestParser(&buf, resolver), nil +} diff --git a/gateway/requestparser_test.go b/gateway/internal/requestparser_test.go similarity index 82% rename from gateway/requestparser_test.go rename to gateway/internal/requestparser_test.go index 3207e072c208..ede871f09ce2 100644 --- a/gateway/requestparser_test.go +++ b/gateway/internal/requestparser_test.go @@ -1,4 +1,4 @@ -package gateway +package internal import ( "net/http/httptest" @@ -11,7 +11,7 @@ import ( func TestNewRequestParserNoVar(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) - parser, err := newRequestParser(req, nil) + parser, err := NewRequestParser(req, nil) assert.Nil(t, err) assert.NotNil(t, parser) } @@ -19,14 +19,14 @@ func TestNewRequestParserNoVar(t *testing.T) { func TestNewRequestParserWithVars(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) req = pathvar.WithVars(req, map[string]string{"a": "b"}) - parser, err := newRequestParser(req, nil) + parser, err := NewRequestParser(req, nil) assert.Nil(t, err) assert.NotNil(t, parser) } func TestNewRequestParserNoVarWithBody(t *testing.T) { req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"}`)) - parser, err := newRequestParser(req, nil) + parser, err := NewRequestParser(req, nil) assert.Nil(t, err) assert.NotNil(t, parser) } @@ -34,7 +34,7 @@ func TestNewRequestParserNoVarWithBody(t *testing.T) { func TestNewRequestParserWithVarsWithBody(t *testing.T) { req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"}`)) req = pathvar.WithVars(req, map[string]string{"c": "d"}) - parser, err := newRequestParser(req, nil) + parser, err := NewRequestParser(req, nil) assert.Nil(t, err) assert.NotNil(t, parser) } @@ -42,14 +42,14 @@ func TestNewRequestParserWithVarsWithBody(t *testing.T) { func TestNewRequestParserWithVarsWithWrongBody(t *testing.T) { req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"`)) req = pathvar.WithVars(req, map[string]string{"c": "d"}) - parser, err := newRequestParser(req, nil) + parser, err := NewRequestParser(req, nil) assert.NotNil(t, err) assert.Nil(t, parser) } func TestNewRequestParserWithForm(t *testing.T) { req := httptest.NewRequest("GET", "/val?a=b", nil) - parser, err := newRequestParser(req, nil) + parser, err := NewRequestParser(req, nil) assert.Nil(t, err) assert.NotNil(t, parser) } diff --git a/gateway/readme.md b/gateway/readme.md index ea87b1bcca59..1fef0bec0c53 100644 --- a/gateway/readme.md +++ b/gateway/readme.md @@ -35,7 +35,7 @@ Upstreams: Mapping: - Method: get Path: /pingHello/:ping - Rpc: hello.Hello/Ping + RpcPath: hello.Hello/Ping - Grpc: Endpoints: - localhost:8081 @@ -43,7 +43,7 @@ Upstreams: Mapping: - Method: post Path: /pingWorld - Rpc: world.World/Ping + RpcPath: world.World/Ping ``` ## Generate ProtoSet files diff --git a/gateway/server.go b/gateway/server.go index c95cdd172871..bb916b80a732 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "fmt" "net/http" "strings" "time" @@ -11,6 +12,7 @@ import ( "github.com/jhump/protoreflect/grpcreflect" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/mr" + "github.com/zeromicro/go-zero/gateway/internal" "github.com/zeromicro/go-zero/rest" "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/zrpc" @@ -58,8 +60,23 @@ func (s *Server) build() error { return } + methods, err := internal.GetMethods(source) + if err != nil { + cancel(err) + return + } + + methodSet := make(map[string]struct{}) + for _, m := range methods { + methodSet[m] = struct{}{} + } resolver := grpcurl.AnyResolverFromDescriptorSource(source) for _, m := range up.Mapping { + if _, ok := methodSet[m.RpcPath]; !ok { + cancel(fmt.Errorf("rpc method %s not found", m.RpcPath)) + return + } + writer.Write(rest.Route{ Method: strings.ToUpper(m.Method), Path: m.Path, @@ -82,7 +99,7 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A Formatter: grpcurl.NewJSONFormatter(true, grpcurl.AnyResolverFromDescriptorSource(source)), } - parser, err := newRequestParser(r, resolver) + parser, err := internal.NewRequestParser(r, resolver) if err != nil { httpx.Error(w, err) return @@ -90,7 +107,7 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A ctx, can := context.WithTimeout(r.Context(), s.timeout) defer can() - if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.Rpc, buildHeaders(r.Header), + if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.RpcPath, internal.BuildHeaders(r.Header), handler, parser.Next); err != nil { httpx.Error(w, err) } From 12d23783f9082bef84b75de3f1fa3ae900558612 Mon Sep 17 00:00:00 2001 From: kevin Date: Sun, 17 Jul 2022 12:20:58 +0800 Subject: [PATCH 2/2] feat: support http header Grpc-Timeout --- gateway/internal/timeout.go | 19 +++++++++++++++++++ gateway/internal/timeout_test.go | 22 ++++++++++++++++++++++ gateway/server.go | 3 ++- 3 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 gateway/internal/timeout.go create mode 100644 gateway/internal/timeout_test.go diff --git a/gateway/internal/timeout.go b/gateway/internal/timeout.go new file mode 100644 index 000000000000..afe5ca8b31ce --- /dev/null +++ b/gateway/internal/timeout.go @@ -0,0 +1,19 @@ +package internal + +import ( + "net/http" + "time" +) + +const grpcTimeoutHeader = "Grpc-Timeout" + +// GetTimeout returns the timeout from the header, if not set, returns the default timeout. +func GetTimeout(header http.Header, defaultTimeout time.Duration) time.Duration { + if timeout := header.Get(grpcTimeoutHeader); len(timeout) > 0 { + if t, err := time.ParseDuration(timeout); err == nil { + return t + } + } + + return defaultTimeout +} diff --git a/gateway/internal/timeout_test.go b/gateway/internal/timeout_test.go new file mode 100644 index 000000000000..bf681e9d9ae5 --- /dev/null +++ b/gateway/internal/timeout_test.go @@ -0,0 +1,22 @@ +package internal + +import ( + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetTimeout(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(grpcTimeoutHeader, "1s") + timeout := GetTimeout(req.Header, time.Second*5) + assert.Equal(t, time.Second, timeout) +} + +func TestGetTimeoutDefault(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + timeout := GetTimeout(req.Header, time.Second*5) + assert.Equal(t, time.Second*5, timeout) +} diff --git a/gateway/server.go b/gateway/server.go index bb916b80a732..daf090483c66 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -105,7 +105,8 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A return } - ctx, can := context.WithTimeout(r.Context(), s.timeout) + timeout := internal.GetTimeout(r.Header, s.timeout) + ctx, can := context.WithTimeout(r.Context(), timeout) defer can() if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.RpcPath, internal.BuildHeaders(r.Header), handler, parser.Next); err != nil {