Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

topdown+wasm: Verifying host based on allow_net allowlist in built-in functions #4152

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions ast/compile.go
Expand Up @@ -336,6 +336,11 @@ func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler {
return c
}

// Capabilities returns the capabilities enabled during compilation.
func (c *Compiler) Capabilities() *Capabilities {
return c.capabilities
}

// WithDebug sets where debug messages are written to. Passing `nil` has no
// effect.
func (c *Compiler) WithDebug(sink io.Writer) *Compiler {
Expand Down
2 changes: 2 additions & 0 deletions docs/content/deployments.md
Expand Up @@ -445,6 +445,8 @@ Not providing a capabilities file, or providing a file without an `allow_net` ke

Note that the metaschemas http://json-schema.org/draft-04/schema, http://json-schema.org/draft-06/schema, and http://json-schema.org/draft-07/schema, are always available, even without network access.

Similarly, the `allow_net` capability restricts what hosts the `http.send` built-in function may send requests to, and what hosts the `net.lookup_ip_addr` built-in function may resolve IP addresses for.

### Future keywords

The availability of future keywords in an OPA version can also be controlled using the capabilities file:
Expand Down
1 change: 1 addition & 0 deletions features/wasm/wasm.go
Expand Up @@ -65,6 +65,7 @@ func (o *OPA) Eval(ctx context.Context, opts opa.EvalOpts) (*opa.Result, error)
Seed: opts.Seed,
InterQueryBuiltinCache: opts.InterQueryBuiltinCache,
PrintHook: opts.PrintHook,
Capabilities: opts.Capabilities,
}

res, err := o.opa.Eval(ctx, evalOptions)
Expand Down
2 changes: 2 additions & 0 deletions internal/rego/opa/options.go
Expand Up @@ -4,6 +4,7 @@ import (
"io"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
Expand All @@ -23,4 +24,5 @@ type EvalOpts struct {
Seed io.Reader
InterQueryBuiltinCache cache.InterQueryCache
PrintHook print.Hook
Capabilities *ast.Capabilities
}
8 changes: 7 additions & 1 deletion internal/wasm/sdk/internal/wasm/bindings.go
Expand Up @@ -81,7 +81,12 @@ func (d *builtinDispatcher) SetMap(m map[int32]topdown.BuiltinFunc) {
}

// Reset is called in Eval before using the builtinDispatcher.
func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.Time, iqbCache cache.InterQueryCache, ph print.Hook) {
func (d *builtinDispatcher) Reset(ctx context.Context,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache,
ph print.Hook,
capabilities *ast.Capabilities) {
if ns.IsZero() {
ns = time.Now()
}
Expand All @@ -103,6 +108,7 @@ func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.T
ParentID: 0,
InterQueryBuiltinCache: iqbCache,
PrintHook: ph,
Capabilities: capabilities,
}

}
Expand Down
2 changes: 1 addition & 1 deletion internal/wasm/sdk/internal/wasm/pool_test.go
Expand Up @@ -176,7 +176,7 @@ func ensurePoolResults(t *testing.T, ctx context.Context, testPool *wasm.Pool, p
toRelease = append(toRelease, vm)

cfg, _ := cache.ParseCachingConfig(nil)
result, err := vm.Eval(ctx, 0, input, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), nil)
result, err := vm.Eval(ctx, 0, input, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), nil, nil)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
Expand Down
12 changes: 7 additions & 5 deletions internal/wasm/sdk/internal/wasm/vm.go
Expand Up @@ -275,9 +275,10 @@ func (i *VM) Eval(ctx context.Context,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache,
ph print.Hook) ([]byte, error) {
ph print.Hook,
capabilities *ast.Capabilities) ([]byte, error) {
if i.abiMinorVersion < int32(2) {
return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache, ph)
return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache, ph, capabilities)
}

metrics.Timer("wasm_vm_eval").Start()
Expand Down Expand Up @@ -328,7 +329,7 @@ func (i *VM) Eval(ctx context.Context,
// make use of it (e.g. `http.send`); and it will spawn a go routine
// cancelling the builtins that use topdown.Cancel, when the context is
// cancelled.
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph)
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph, capabilities)

metrics.Timer("wasm_vm_eval_call").Start()
resultAddr, err := i.evalOneOff(ctx, int32(entrypoint), i.dataAddr, inputAddr, inputLen, heapPtr)
Expand Down Expand Up @@ -357,7 +358,8 @@ func (i *VM) evalCompat(ctx context.Context,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache,
ph print.Hook) ([]byte, error) {
ph print.Hook,
capabilities *ast.Capabilities) ([]byte, error) {
metrics.Timer("wasm_vm_eval").Start()
defer metrics.Timer("wasm_vm_eval").Stop()

Expand All @@ -367,7 +369,7 @@ func (i *VM) evalCompat(ctx context.Context,
// make use of it (e.g. `http.send`); and it will spawn a go routine
// cancelling the builtins that use topdown.Cancel, when the context is
// cancelled.
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph)
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph, capabilities)

err := i.setHeapState(ctx, i.evalHeapPtr)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion internal/wasm/sdk/opa/opa.go
Expand Up @@ -12,6 +12,7 @@ import (
"sync"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm"
"github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors"
sdk_errors "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors"
Expand Down Expand Up @@ -165,6 +166,7 @@ type EvalOpts struct {
Seed io.Reader
InterQueryBuiltinCache cache.InterQueryCache
PrintHook print.Hook
Capabilities *ast.Capabilities
}

// Eval evaluates the policy with the given input, returning the
Expand All @@ -188,7 +190,8 @@ func (o *OPA) Eval(ctx context.Context, opts EvalOpts) (*Result, error) {

defer o.pool.Release(instance, m)

result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache, opts.PrintHook)
result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache,
opts.PrintHook, opts.Capabilities)
if err != nil {
return nil, err
}
Expand Down
4 changes: 4 additions & 0 deletions rego/rego.go
Expand Up @@ -118,6 +118,7 @@ type EvalContext struct {
resolvers []refResolver
sortSets bool
printHook print.Hook
capabilities *ast.Capabilities
}

// EvalOption defines a function to set an option on an EvalConfig
Expand Down Expand Up @@ -311,6 +312,7 @@ func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption
earlyExit: true,
resolvers: pq.r.resolvers,
printHook: pq.r.printHook,
capabilities: pq.r.capabilities,
}

for _, o := range options {
Expand Down Expand Up @@ -1969,6 +1971,7 @@ func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, erro
Seed: ectx.seed,
InterQueryBuiltinCache: ectx.interQueryBuiltinCache,
PrintHook: ectx.printHook,
Capabilities: ectx.capabilities,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -2075,6 +2078,7 @@ func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialR
instrumentation: r.instrumentation,
indexing: true,
resolvers: r.resolvers,
capabilities: r.capabilities,
}

disableInlining := r.disableInlining
Expand Down
52 changes: 52 additions & 0 deletions rego/rego_wasmtarget_test.go
Expand Up @@ -12,6 +12,8 @@ import (
"math/rand"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -338,3 +340,53 @@ func TestEvalWasmWithInterQueryCache(t *testing.T) {
t.Fatal("Expected server to be called only once")
}
}

func TestEvalWasmWithHTTPAllowNet(t *testing.T) {
anderseknert marked this conversation as resolved.
Show resolved Hide resolved
var requests []*http.Request
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests = append(requests, r)

w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"x": 1}`))
}))
defer ts.Close()

serverUrl, err := url.Parse(ts.URL)
if err != nil {
t.Fatal(err)
}
serverHost := strings.Split(serverUrl.Host, ":")[0]

query := fmt.Sprintf(`http.send({"method": "get", "url": "%s", "force_json_decode": true, "cache": true})`, ts.URL)
capabilities := ast.CapabilitiesForThisVersion()
capabilities.AllowNet = []string{"example.com"}

// add an inter-query cache
config, _ := cache.ParseCachingConfig(nil)
interQueryCache := cache.NewInterQueryCache(config)

ctx := context.Background()
// StrictBuiltinErrors(true) has no effect when target is 'wasm'
// this request should be rejected by the allow_net allowlist
_, err = New(Target("wasm"), Query(query), InterQueryBuiltinCache(interQueryCache), Capabilities(capabilities)).Eval(ctx)
if err != nil {
t.Fatal(err)
}

if len(requests) != 0 {
t.Fatal("Expected server to not be called")
}

capabilities.AllowNet = []string{serverHost}

// eval again with same query
// this request should not be rejected by the allow_net allowlist
_, err = New(Target("wasm"), Query(query), InterQueryBuiltinCache(interQueryCache), Capabilities(capabilities)).Eval(ctx)
if err != nil {
t.Fatal(err)
}

if len(requests) != 1 {
t.Fatal("Expected server to never be called")
}
}
1 change: 1 addition & 0 deletions topdown/builtins.go
Expand Up @@ -52,6 +52,7 @@ type (
PrintHook print.Hook // provides callback function to use for printing
DistributedTracingOpts tracing.Options // options to be used by distributed tracing.
rand *rand.Rand // randomization source for non-security-sensitive operations
Capabilities *ast.Capabilities
}

// BuiltinFunc defines an interface for implementing built-in functions.
Expand Down
6 changes: 6 additions & 0 deletions topdown/eval.go
Expand Up @@ -701,6 +701,11 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
parentID = e.parent.queryID
}

var capabilities *ast.Capabilities
if e.compiler != nil {
capabilities = e.compiler.Capabilities()
}

bctx := BuiltinContext{
Context: e.ctx,
Metrics: e.metrics,
Expand All @@ -717,6 +722,7 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
ParentID: parentID,
PrintHook: e.printHook,
DistributedTracingOpts: e.tracingOpts,
Capabilities: capabilities,
}

eval := evalBuiltin{
Expand Down
60 changes: 47 additions & 13 deletions topdown/http.go
Expand Up @@ -264,6 +264,36 @@ func useSocket(rawURL string, tlsConfig *tls.Config) (bool, string, *http.Transp
return true, rawURL, tr
}

func verifyHost(bctx BuiltinContext, host string) error {
if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil {
return nil
}

for _, allowed := range bctx.Capabilities.AllowNet {
if allowed == host {
return nil
}
}

return fmt.Errorf("unallowed host: %s", host)
}

func verifyURLHost(bctx BuiltinContext, unverifiedURL string) error {
// Eager return to avoid unnecessary URL parsing
if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil {
return nil
}

parsedURL, err := url.Parse(unverifiedURL)
if err != nil {
return err
}

host := strings.Split(parsedURL.Host, ":")[0]

return verifyHost(bctx, host)
}

func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *http.Client, error) {
var url string
var method string
Expand Down Expand Up @@ -305,7 +335,7 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
var strVal string

if s, ok := obj.Get(val).Value.(ast.String); ok {
strVal = string(s)
strVal = strings.Trim(string(s), "\"")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's OK given the existing code, but something here doesn't seem correct to me. Why are there extra "?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I don't get this either; but I didn't want to change existing functionality.

} else {
// Most parameters are strings, so consolidate the type checking.
switch key {
Expand All @@ -328,9 +358,13 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt

switch key {
case "method":
method = strings.ToUpper(strings.Trim(strVal, "\""))
method = strings.ToUpper(strVal)
case "url":
url = strings.Trim(strVal, "\"")
err := verifyURLHost(bctx, strVal)
if err != nil {
return nil, nil, err
}
url = strVal
case "enable_redirect":
enableRedirect, err = strconv.ParseBool(obj.Get(val).String())
if err != nil {
Expand All @@ -357,25 +391,25 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
}
tlsUseSystemCerts = &tempTLSUseSystemCerts
case "tls_ca_cert":
tlsCaCert = bytes.Trim([]byte(strVal), "\"")
tlsCaCert = []byte(strVal)
case "tls_ca_cert_file":
tlsCaCertFile = strings.Trim(strVal, "\"")
tlsCaCertFile = strVal
case "tls_ca_cert_env_variable":
tlsCaCertEnvVar = strings.Trim(strVal, "\"")
tlsCaCertEnvVar = strVal
case "tls_client_cert":
tlsClientCert = bytes.Trim([]byte(strVal), "\"")
tlsClientCert = []byte(strVal)
case "tls_client_cert_file":
tlsClientCertFile = strings.Trim(strVal, "\"")
tlsClientCertFile = strVal
case "tls_client_cert_env_variable":
tlsClientCertEnvVar = strings.Trim(strVal, "\"")
tlsClientCertEnvVar = strVal
case "tls_client_key":
tlsClientKey = bytes.Trim([]byte(strVal), "\"")
tlsClientKey = []byte(strVal)
case "tls_client_key_file":
tlsClientKeyFile = strings.Trim(strVal, "\"")
tlsClientKeyFile = strVal
case "tls_client_key_env_variable":
tlsClientKeyEnvVar = strings.Trim(strVal, "\"")
tlsClientKeyEnvVar = strVal
case "tls_server_name":
tlsServerName = strings.Trim(strVal, "\"")
tlsServerName = strVal
case "headers":
headersVal := obj.Get(val).Value
headersValInterface, err := ast.JSON(headersVal)
Expand Down