Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: verify RpcPath on startup #2159

Merged
merged 2 commits into from Jul 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions gateway/config.go
Expand Up @@ -21,8 +21,8 @@ type (
Method string
// Path is the HTTP path.
Path string
// Rpc is the gRPC rpc method, with format of package.service/method
Rpc string
// RpcPath is the gRPC rpc method, with format of package.service/method
RpcPath string
}

// upstream is the configuration for upstream.
Expand Down
34 changes: 34 additions & 0 deletions gateway/internal/descriptorsource.go
@@ -0,0 +1,34 @@
package internal

import (
"fmt"

"github.com/fullstorydev/grpcurl"
"github.com/jhump/protoreflect/desc"
)

// GetMethods returns all methods of the given grpcurl.DescriptorSource.
func GetMethods(source grpcurl.DescriptorSource) ([]string, error) {
svcs, err := source.ListServices()
if err != nil {
return nil, err
}

var methods []string
for _, svc := range svcs {
d, err := source.FindSymbol(svc)
if err != nil {
return nil, err
}

switch val := d.(type) {
case *desc.ServiceDescriptor:
svcMethods := val.GetMethods()
for _, method := range svcMethods {
methods = append(methods, fmt.Sprintf("%s/%s", svc, method.GetName()))
}
}
}

return methods, nil
}
29 changes: 29 additions & 0 deletions gateway/internal/descriptorsource_test.go
@@ -0,0 +1,29 @@
package internal

import (
"encoding/base64"
"io/ioutil"
"os"
"testing"

"github.com/fullstorydev/grpcurl"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/hash"
)

const b64pb = `CpgBCgtoZWxsby5wcm90bxIFaGVsbG8iHQoHUmVxdWVzdBISCgRwaW5nGAEgASgJUgRwaW5nIh4KCFJlc3BvbnNlEhIKBHBvbmcYASABKAlSBHBvbmcyMAoFSGVsbG8SJwoEUGluZxIOLmhlbGxvLlJlcXVlc3QaDy5oZWxsby5SZXNwb25zZUIJWgcuL2hlbGxvYgZwcm90bzM=`

func TestGetMethods(t *testing.T) {
tmpfile, err := ioutil.TempFile(os.TempDir(), hash.Md5Hex([]byte(b64pb)))
assert.Nil(t, err)
b, err := base64.StdEncoding.DecodeString(b64pb)
assert.Nil(t, err)
assert.Nil(t, ioutil.WriteFile(tmpfile.Name(), b, os.ModeTemporary))
defer os.Remove(tmpfile.Name())

source, err := grpcurl.DescriptorSourceFromProtoSets(tmpfile.Name())
assert.Nil(t, err)
methods, err := GetMethods(source)
assert.Nil(t, err)
assert.EqualValues(t, []string{"hello.Hello/Ping"}, methods)
}
@@ -1,4 +1,4 @@
package gateway
package internal

import (
"fmt"
Expand All @@ -11,7 +11,8 @@ const (
metadataPrefix = "gateway-"
)

func buildHeaders(header http.Header) []string {
// BuildHeaders builds the headers for the gateway from HTTP headers.
func BuildHeaders(header http.Header) []string {
var headers []string

for k, v := range header {
Expand Down
@@ -1,4 +1,4 @@
package gateway
package internal

import (
"net/http/httptest"
Expand All @@ -10,12 +10,12 @@ import (
func TestBuildHeadersNoValue(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Add("a", "b")
assert.Nil(t, buildHeaders(req.Header))
assert.Nil(t, BuildHeaders(req.Header))
}

func TestBuildHeadersWithValues(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Add("grpc-metadata-a", "b")
req.Header.Add("grpc-metadata-b", "b")
assert.EqualValues(t, []string{"gateway-A:b", "gateway-B:b"}, buildHeaders(req.Header))
assert.EqualValues(t, []string{"gateway-A:b", "gateway-B:b"}, BuildHeaders(req.Header))
}
25 changes: 13 additions & 12 deletions gateway/requestparser.go → gateway/internal/requestparser.go
@@ -1,4 +1,4 @@
package gateway
package internal

import (
"bytes"
Expand All @@ -11,17 +11,8 @@ import (
"github.com/zeromicro/go-zero/rest/pathvar"
)

func buildJsonRequestParser(m map[string]interface{}, resolver jsonpb.AnyResolver) (
grpcurl.RequestParser, error) {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(m); err != nil {
return nil, err
}

return grpcurl.NewJSONRequestParser(&buf, resolver), nil
}

func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) {
// NewRequestParser creates a new request parser from the given http.Request and resolver.
func NewRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) {
vars := pathvar.Vars(r)
params, err := httpx.GetFormValues(r)
if err != nil {
Expand Down Expand Up @@ -50,3 +41,13 @@ func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.Req

return buildJsonRequestParser(m, resolver)
}

func buildJsonRequestParser(m map[string]interface{}, resolver jsonpb.AnyResolver) (
grpcurl.RequestParser, error) {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(m); err != nil {
return nil, err
}

return grpcurl.NewJSONRequestParser(&buf, resolver), nil
}
@@ -1,4 +1,4 @@
package gateway
package internal

import (
"net/http/httptest"
Expand All @@ -11,45 +11,45 @@ import (

func TestNewRequestParserNoVar(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
parser, err := newRequestParser(req, nil)
parser, err := NewRequestParser(req, nil)
assert.Nil(t, err)
assert.NotNil(t, parser)
}

func TestNewRequestParserWithVars(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req = pathvar.WithVars(req, map[string]string{"a": "b"})
parser, err := newRequestParser(req, nil)
parser, err := NewRequestParser(req, nil)
assert.Nil(t, err)
assert.NotNil(t, parser)
}

func TestNewRequestParserNoVarWithBody(t *testing.T) {
req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"}`))
parser, err := newRequestParser(req, nil)
parser, err := NewRequestParser(req, nil)
assert.Nil(t, err)
assert.NotNil(t, parser)
}

func TestNewRequestParserWithVarsWithBody(t *testing.T) {
req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"}`))
req = pathvar.WithVars(req, map[string]string{"c": "d"})
parser, err := newRequestParser(req, nil)
parser, err := NewRequestParser(req, nil)
assert.Nil(t, err)
assert.NotNil(t, parser)
}

func TestNewRequestParserWithVarsWithWrongBody(t *testing.T) {
req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"`))
req = pathvar.WithVars(req, map[string]string{"c": "d"})
parser, err := newRequestParser(req, nil)
parser, err := NewRequestParser(req, nil)
assert.NotNil(t, err)
assert.Nil(t, parser)
}

func TestNewRequestParserWithForm(t *testing.T) {
req := httptest.NewRequest("GET", "/val?a=b", nil)
parser, err := newRequestParser(req, nil)
parser, err := NewRequestParser(req, nil)
assert.Nil(t, err)
assert.NotNil(t, parser)
}
19 changes: 19 additions & 0 deletions gateway/internal/timeout.go
@@ -0,0 +1,19 @@
package internal

import (
"net/http"
"time"
)

const grpcTimeoutHeader = "Grpc-Timeout"

// GetTimeout returns the timeout from the header, if not set, returns the default timeout.
func GetTimeout(header http.Header, defaultTimeout time.Duration) time.Duration {
if timeout := header.Get(grpcTimeoutHeader); len(timeout) > 0 {
if t, err := time.ParseDuration(timeout); err == nil {
return t
}
}

return defaultTimeout
}
22 changes: 22 additions & 0 deletions gateway/internal/timeout_test.go
@@ -0,0 +1,22 @@
package internal

import (
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestGetTimeout(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set(grpcTimeoutHeader, "1s")
timeout := GetTimeout(req.Header, time.Second*5)
assert.Equal(t, time.Second, timeout)
}

func TestGetTimeoutDefault(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
timeout := GetTimeout(req.Header, time.Second*5)
assert.Equal(t, time.Second*5, timeout)
}
4 changes: 2 additions & 2 deletions gateway/readme.md
Expand Up @@ -35,15 +35,15 @@ Upstreams:
Mapping:
- Method: get
Path: /pingHello/:ping
Rpc: hello.Hello/Ping
RpcPath: hello.Hello/Ping
- Grpc:
Endpoints:
- localhost:8081
# reflection mode, no ProtoSet settings
Mapping:
- Method: post
Path: /pingWorld
Rpc: world.World/Ping
RpcPath: world.World/Ping
```

## Generate ProtoSet files
Expand Down
24 changes: 21 additions & 3 deletions gateway/server.go
Expand Up @@ -2,6 +2,7 @@ package gateway

import (
"context"
"fmt"
"net/http"
"strings"
"time"
Expand All @@ -11,6 +12,7 @@ import (
"github.com/jhump/protoreflect/grpcreflect"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mr"
"github.com/zeromicro/go-zero/gateway/internal"
"github.com/zeromicro/go-zero/rest"
"github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/zrpc"
Expand Down Expand Up @@ -58,8 +60,23 @@ func (s *Server) build() error {
return
}

methods, err := internal.GetMethods(source)
if err != nil {
cancel(err)
return
}

methodSet := make(map[string]struct{})
for _, m := range methods {
methodSet[m] = struct{}{}
}
resolver := grpcurl.AnyResolverFromDescriptorSource(source)
for _, m := range up.Mapping {
if _, ok := methodSet[m.RpcPath]; !ok {
cancel(fmt.Errorf("rpc method %s not found", m.RpcPath))
return
}

writer.Write(rest.Route{
Method: strings.ToUpper(m.Method),
Path: m.Path,
Expand All @@ -82,15 +99,16 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
Formatter: grpcurl.NewJSONFormatter(true,
grpcurl.AnyResolverFromDescriptorSource(source)),
}
parser, err := newRequestParser(r, resolver)
parser, err := internal.NewRequestParser(r, resolver)
if err != nil {
httpx.Error(w, err)
return
}

ctx, can := context.WithTimeout(r.Context(), s.timeout)
timeout := internal.GetTimeout(r.Header, s.timeout)
ctx, can := context.WithTimeout(r.Context(), timeout)
defer can()
if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.Rpc, buildHeaders(r.Header),
if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.RpcPath, internal.BuildHeaders(r.Header),
handler, parser.Next); err != nil {
httpx.Error(w, err)
}
Expand Down