Skip to content

Commit

Permalink
rebased/feat: add hook extract response data (#41)
Browse files Browse the repository at this point in the history
* checkpoint: add hook stacked functions

* feat: add authz hook to extract response data

* checkpoint

* fix: revert roundtripper

* fix: lint issues

* fix: read hooks config

* fix: add sample rule with hook

* fix: add sample rule with hook with different source
  • Loading branch information
krtkvrm committed Dec 10, 2021
1 parent dc8ead6 commit 9537cec
Show file tree
Hide file tree
Showing 19 changed files with 414 additions and 45 deletions.
20 changes: 14 additions & 6 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
"syscall"
"time"

"github.com/odpf/salt/log"
"github.com/odpf/salt/server"
"github.com/odpf/shield/api/handler"
v1 "github.com/odpf/shield/api/handler/v1"
"github.com/odpf/shield/config"
"github.com/odpf/shield/hook"
authz_hook "github.com/odpf/shield/hook/authz"
"github.com/odpf/shield/internal/group"
"github.com/odpf/shield/internal/org"
"github.com/odpf/shield/internal/project"
Expand All @@ -24,9 +24,13 @@ import (
"github.com/odpf/shield/proxy"
blobstore "github.com/odpf/shield/store/blob"
"github.com/odpf/shield/store/postgres"

"github.com/odpf/salt/log"
"github.com/odpf/salt/server"
"github.com/pkg/errors"
"github.com/pkg/profile"
cli "github.com/spf13/cobra"

"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
Expand Down Expand Up @@ -131,15 +135,14 @@ func startProxy(logger log.Logger, appConfig *config.Shield, ctx context.Context
return nil, nil, err
}

// TODO: option to use default http round tripper for http1.1 backends
h2cProxy := proxy.NewH2c(proxy.NewH2cRoundTripper(logger), proxy.NewDirector())
h2cProxy := proxy.NewH2c(proxy.NewH2cRoundTripper(logger, buildHookPipeline(logger)), proxy.NewDirector())

ruleRepo := blobstore.NewRuleRepository(logger, blobFS)
if err := ruleRepo.InitCache(ctx, ruleCacheRefreshDelay); err != nil {
return nil, nil, err
}
cleanUpFunc = append(cleanUpFunc, ruleRepo.Close)
pipeline := buildPipeline(logger, h2cProxy, ruleRepo, appConfig.App.IdentityProxyHeader)
middlewarePipeline := buildMiddlewarePipeline(logger, h2cProxy, ruleRepo, appConfig.App.IdentityProxyHeader)
go func(thisService config.Service, handler http.Handler) {
proxyURL := fmt.Sprintf("%s:%d", thisService.Host, thisService.Port)
logger.Info("starting h2c proxy", "url", proxyURL)
Expand All @@ -162,13 +165,18 @@ func startProxy(logger log.Logger, appConfig *config.Shield, ctx context.Context
logger.Fatal("failed to serve", "err", err)
}
cleanUpProxies = append(cleanUpProxies, proxySrv.Shutdown)
}(service, pipeline)
}(service, middlewarePipeline)
}
time.Sleep(100 * time.Millisecond)
logger.Info("[shield] proxy is up")
return cleanUpFunc, cleanUpProxies, nil
}

func buildHookPipeline(log log.Logger) hook.Service {
rootHook := hook.New()
return authz_hook.New(log, rootHook, rootHook)
}

func waitForTermSignal(ctx context.Context) {
for {
select {
Expand Down
8 changes: 5 additions & 3 deletions cmd/serve_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,24 @@ import (

"github.com/odpf/shield/middleware/authz"
"github.com/odpf/shield/middleware/basic_auth"

"github.com/odpf/salt/log"
"github.com/odpf/shield/middleware/prefix"
"github.com/odpf/shield/middleware/rulematch"

"github.com/odpf/salt/log"
"github.com/odpf/shield/store"
"github.com/pkg/errors"

"gocloud.dev/blob"
"gocloud.dev/blob/fileblob"
"gocloud.dev/blob/gcsblob"
"gocloud.dev/blob/memblob"
"gocloud.dev/gcp"

"golang.org/x/oauth2/google"
)

// buildPipeline builds middleware sequence
func buildPipeline(logger log.Logger, proxy http.Handler, ruleRepo store.RuleRepository, identityProxyHeader string) http.Handler {
func buildMiddlewarePipeline(logger log.Logger, proxy http.Handler, ruleRepo store.RuleRepository, identityProxyHeader string) http.Handler {
// Note: execution order is bottom up
prefixWare := prefix.New(logger, proxy)
casbinAuthz := authz.New(logger, identityProxyHeader, prefixWare)
Expand Down
144 changes: 144 additions & 0 deletions hook/authz/authz.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package authz

import (
"fmt"
"net/http"
"strings"

"github.com/odpf/shield/hook"
"github.com/odpf/shield/middleware"
"github.com/odpf/shield/pkg/body_extractor"

"github.com/mitchellh/mapstructure"
"github.com/odpf/salt/log"
)

type Authz struct {
log log.Logger

// To go to next hook
next hook.Service

// To skip all the next hooks and just respond back
escape hook.Service
}

func New(log log.Logger, next, escape hook.Service) Authz {
return Authz{
log: log,
next: next,
escape: escape,
}
}

type Config struct {
Action string `yaml:"action" mapstructure:"action"`
Attributes map[string]hook.Attribute `yaml:"attributes" mapstructure:"attributes"`
}

func (a Authz) Info() hook.Info {
return hook.Info{
Name: "authz",
Description: "hook to modify permissions for the resource",
}
}

func (a Authz) ServeHook(res *http.Response, err error) (*http.Response, error) {
if err != nil {
return a.escape.ServeHook(res, err)
}

hookSpec, ok := hook.ExtractHook(res.Request, a.Info().Name)
if !ok {
return a.next.ServeHook(res, nil)
}

config := Config{}
if err := mapstructure.Decode(hookSpec.Config, &config); err != nil {
return a.next.ServeHook(res, nil)
}

attributes := map[string]string{}
for id, attr := range config.Attributes {
bdy, _ := middleware.ExtractRequestBody(res.Request)
bodySource := &res.Body
if attr.Source == string(hook.SourceRequest) {
bodySource = &bdy
}

headerSource := &res.Header
if attr.Source == string(hook.SourceRequest) {
headerSource = &res.Request.Header
}

switch attr.Type {
case hook.AttributeTypeGRPCPayload:
if !strings.HasPrefix(res.Header.Get("Content-Type"), "application/grpc") {
a.log.Error("middleware: not a grpc request", "attr", attr)
return a.escape.ServeHook(res, fmt.Errorf("invalid header for http request: %s", res.Header.Get("Content-Type")))
}

payloadField, err := body_extractor.GRPCPayloadHandler{}.Extract(bodySource, attr.Index)
if err != nil {
a.log.Error("middleware: failed to parse grpc payload", "err", err)
return a.escape.ServeHook(res, fmt.Errorf("unable to parse grpc payload"))
}
attributes[id] = payloadField

a.log.Info("middleware: extracted", "field", payloadField, "attr", attr)
case hook.AttributeTypeJSONPayload:
if attr.Key == "" {
a.log.Error("middleware: payload key field empty")
return a.escape.ServeHook(res, fmt.Errorf("payload key field empty"))
}
payloadField, err := body_extractor.JSONPayloadHandler{}.Extract(bodySource, attr.Key)
if err != nil {
a.log.Error("middleware: failed to parse json payload", "err", err)
return a.escape.ServeHook(res, fmt.Errorf("failed to parse json payload"))
}
attributes[id] = payloadField

a.log.Info("middleware: extracted", "field", payloadField, "attr", attr)
case hook.AttributeTypeHeader:
if attr.Key == "" {
a.log.Error("middleware: header key field empty", "err", err)
return a.escape.ServeHook(res, fmt.Errorf("failed to parse json payload"))
}
headerAttr := headerSource.Get(attr.Key)
if headerAttr == "" {
a.log.Error(fmt.Sprintf("middleware: header %s is empty", attr.Key))
return a.escape.ServeHook(res, fmt.Errorf("failed to parse json payload"))
}

attributes[id] = headerAttr
a.log.Info("middleware: extracted", "field", headerAttr, "attr", attr)

case hook.AttributeTypeQuery:
if attr.Key == "" {
a.log.Error("middleware: query key field empty")
return a.escape.ServeHook(res, fmt.Errorf("failed to parse json payload"))
}
queryAttr := res.Request.URL.Query().Get(attr.Key)
if queryAttr == "" {
a.log.Error(fmt.Sprintf("middleware: query %s is empty", attr.Key))
return a.escape.ServeHook(res, fmt.Errorf("failed to parse json payload"))
}

attributes[id] = queryAttr
a.log.Info("middleware: extracted", "field", queryAttr, "attr", attr)
default:
a.log.Error("middleware: unknown attribute type", "attr", attr)
return a.escape.ServeHook(res, fmt.Errorf("unknown attribute type: %v", attr))
}
}

//Change after merging PR#32
//paramMap, _ := middleware.ExtractPathParams(req)
//for key, value := range paramMap {
// permissionAttributes[key] = value
//}

// use attributes to modify authz

return a.next.ServeHook(res, nil)
}
64 changes: 64 additions & 0 deletions hook/hooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package hook

import (
"net/http"

"github.com/odpf/shield/middleware"
"github.com/odpf/shield/structs"
)

type Service interface {
Info() Info
ServeHook(res *http.Response, err error) (*http.Response, error)
}

type Info struct {
Name string
Description string
}

const (
AttributeTypeJSONPayload AttributeType = "json_payload"
AttributeTypeGRPCPayload AttributeType = "grpc_payload"
AttributeTypeQuery AttributeType = "query"
AttributeTypeHeader AttributeType = "header"

SourceRequest AttributeType = "request"
SourceResponse AttributeType = "response"
)

type AttributeType string

type Attribute struct {
Key string `yaml:"key" mapstructure:"key"`
Type AttributeType `yaml:"type" mapstructure:"type"`
Index int `yaml:"index" mapstructure:"index"` // proto index
Source string `yaml:"source" mapstructure:"source"`
}

func ExtractHook(r *http.Request, name string) (structs.HookSpec, bool) {
rl, ok := middleware.ExtractRule(r)
if !ok {
return structs.HookSpec{}, false
}
return rl.Hooks.Get(name)
}

type Hook struct{}

func New() Hook {
return Hook{}
}

func (h Hook) Info() Info {
return Info{}
}

func (h Hook) ServeHook(res *http.Response, err error) (*http.Response, error) {
if err != nil {
res.StatusCode = http.StatusInternalServerError
// TODO: clear or add error body as well
}

return res, nil
}
17 changes: 17 additions & 0 deletions integration/fixtures/ruleset.grpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ rules:
users:
- user: user
password: $apr1$RfxoV6GP$.GsGgD580H5FOuUfTzKZh0
hooks:
- name: authz
config:
action: some_action
attributes:
project_resp:
index: 1
type: grpc_payload

- name: basic_auth_2
path: "/helloworld.Greeter/SayHello"
method: "POST"
Expand All @@ -29,3 +38,11 @@ rules:
name:
type: grpc_payload
index: 1
hooks:
- name: authz
config:
action: some_action
attributes:
project_resp:
index: 1
type: grpc_payload
12 changes: 10 additions & 2 deletions integration/fixtures/ruleset.rest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ rules:
- name: prefix
config:
strip: "/basic-authz"

- name: basic_auth
config:
users:
Expand All @@ -52,4 +51,13 @@ rules:
attributes:
project:
type: json_payload
key: project
key: project
hooks:
- name: authz
config:
action: authz_action
attributes:
project:
key: project
type: json_payload
source: request
17 changes: 10 additions & 7 deletions integration/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/odpf/shield/hook"
"github.com/odpf/shield/integration/fixtures/helloworld"
"google.golang.org/grpc"

"github.com/odpf/salt/log"
"github.com/odpf/shield/proxy"
blobstore "github.com/odpf/shield/store/blob"

"github.com/odpf/salt/log"
"github.com/stretchr/testify/assert"

"gocloud.dev/blob/fileblob"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"

"google.golang.org/grpc"
)

const (
Expand All @@ -39,7 +41,8 @@ func TestGRPCProxyHelloWorld(t *testing.T) {
t.Fatal(err)
}

h2cProxy := proxy.NewH2c(proxy.NewH2cRoundTripper(log.NewNoop()), proxy.NewDirector())
responseHooks := hookPipeline(log.NewNoop())
h2cProxy := proxy.NewH2c(proxy.NewH2cRoundTripper(log.NewNoop(), responseHooks), proxy.NewDirector())
ruleRepo := blobstore.NewRuleRepository(log.NewNoop(), blobFS)
if err := ruleRepo.InitCache(baseCtx, time.Minute); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -137,7 +140,7 @@ func BenchmarkGRPCProxyHelloWorld(b *testing.B) {
b.Fatal(err)
}

h2cProxy := proxy.NewH2c(proxy.NewH2cRoundTripper(log.NewNoop()), proxy.NewDirector())
h2cProxy := proxy.NewH2c(proxy.NewH2cRoundTripper(log.NewNoop(), hook.New()), proxy.NewDirector())
ruleRepo := blobstore.NewRuleRepository(log.NewNoop(), blobFS)
if err := ruleRepo.InitCache(baseCtx, time.Minute); err != nil {
b.Fatal(err)
Expand Down

0 comments on commit 9537cec

Please sign in to comment.