Skip to content

Commit

Permalink
Fixed changes requested by @srenatus (squash before merge)
Browse files Browse the repository at this point in the history
Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Dec 17, 2021
1 parent de8875c commit 9884270
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 61 deletions.
29 changes: 16 additions & 13 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,6 @@ type Compiler struct {
// with the key being the generated name and value being the original.
RewrittenVars map[Var]Var

// Capabilities is the user-supplied capabilities or features allowed for OPA.
Capabilities *Capabilities

localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
Expand All @@ -102,9 +99,10 @@ type Compiler struct {
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use Capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use Capabilities)
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
Expand Down Expand Up @@ -327,17 +325,22 @@ func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler {
return c
}

// WithCapabilities sets Capabilities to enable during compilation. Capabilities allow the caller
// to specify the set of built-in functions available to the policy. In the future, Capabilities
// WithCapabilities sets capabilities to enable during compilation. Capabilities allow the caller
// to specify the set of built-in functions available to the policy. In the future, capabilities
// may be able to restrict access to other language features. Capabilities allow callers to check
// if policies are compatible with a particular version of OPA. If policies are a compiled for a
// specific version of OPA, there is no guarantee that _this_ version of OPA can evaluate them
// successfully.
func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler {
c.Capabilities = capabilities
c.capabilities = capabilities
return c
}

// Capabilities returns the capabilities enabled during compilation.
func (c *Compiler) Capabilities() *Capabilities {
return c.capabilities
}

// WithDebug sets where debug messages are written to. Passing `nil` has no
// effect.
func (c *Compiler) WithDebug(sink io.Writer) *Compiler {
Expand Down Expand Up @@ -1205,13 +1208,13 @@ func (c *Compiler) init() {
return
}

if c.Capabilities == nil {
c.Capabilities = CapabilitiesForThisVersion()
if c.capabilities == nil {
c.capabilities = CapabilitiesForThisVersion()
}

c.builtins = make(map[string]*Builtin, len(c.Capabilities.Builtins)+len(c.customBuiltins))
c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins))

for _, bi := range c.Capabilities.Builtins {
for _, bi := range c.capabilities.Builtins {
c.builtins[bi.Name] = bi
}

Expand All @@ -1222,7 +1225,7 @@ func (c *Compiler) init() {
// Load the global input schema if one was provided.
if c.schemaSet != nil {
if schema := c.schemaSet.Get(SchemaRootRef); schema != nil {
tpe, err := loadSchema(schema, c.Capabilities.AllowNet)
tpe, err := loadSchema(schema, c.capabilities.AllowNet)
if err != nil {
c.err(NewError(TypeErr, nil, err.Error()))
} else {
Expand Down
2 changes: 1 addition & 1 deletion rego/rego_wasmtarget_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func TestEvalWasmWithHTTPAllowNet(t *testing.T) {

serverUrl, err := url.Parse(ts.URL)
if err != nil {
panic(err)
t.Fatal(err)
}
serverHost := strings.Split(serverUrl.Host, ":")[0]

Expand Down
2 changes: 1 addition & 1 deletion topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {

var capabilities *ast.Capabilities
if e.compiler != nil {
capabilities = e.compiler.Capabilities
capabilities = e.compiler.Capabilities()
}

bctx := BuiltinContext{
Expand Down
57 changes: 12 additions & 45 deletions topdown/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@ func TestHTTPGetRequest(t *testing.T) {
"test-header": []interface{}{"test-value"},
}

resultObj, err := ast.InterfaceToValue(expectedResult)
if err != nil {
panic(err)
}
resultObj := ast.MustInterfaceToValue(expectedResult)

// run the test
tests := []struct {
Expand Down Expand Up @@ -130,10 +127,7 @@ func TestHTTPGetRequestTlsInsecureSkipVerify(t *testing.T) {
"content-type": []interface{}{"text/plain; charset=utf-8"},
}

resultObj, err := ast.InterfaceToValue(expectedResult)
if err != nil {
panic(err)
}
resultObj := ast.MustInterfaceToValue(expectedResult)

// run the test
tests := []struct {
Expand Down Expand Up @@ -181,10 +175,7 @@ func TestHTTPEnableJSONDecode(t *testing.T) {
"content-type": []interface{}{"text/plain; charset=utf-8"},
}

resultObj, err := ast.InterfaceToValue(expectedResult)
if err != nil {
panic(err)
}
resultObj := ast.MustInterfaceToValue(expectedResult)

// run the test
tests := []struct {
Expand Down Expand Up @@ -462,10 +453,7 @@ func TestHTTPDeleteRequest(t *testing.T) {
"content-type": []interface{}{"application/json"},
}

resultObj, err := ast.InterfaceToValue(expectedResult)
if err != nil {
panic(err)
}
resultObj := ast.MustInterfaceToValue(expectedResult)

// delete a new person
personToDelete := Person{ID: "2", Firstname: "Joe"}
Expand Down Expand Up @@ -622,10 +610,7 @@ func TestHTTPRedirectDisable(t *testing.T) {
"location": []interface{}{"/test"},
}

resultObj, err := ast.InterfaceToValue(expectedResult)
if err != nil {
panic(err)
}
resultObj := ast.MustInterfaceToValue(expectedResult)

data := loadSmallTestData()
rules := append(
Expand Down Expand Up @@ -655,10 +640,7 @@ func TestHTTPRedirectEnable(t *testing.T) {
"content-length": []interface{}{"0"},
}

resultObj, err := ast.InterfaceToValue(expectedResult)
if err != nil {
panic(err)
}
resultObj := ast.MustInterfaceToValue(expectedResult)

data := loadSmallTestData()
rules := append(
Expand All @@ -680,28 +662,19 @@ func TestHTTPSendRaiseError(t *testing.T) {
networkErrObj["code"] = HTTPSendNetworkErr
networkErrObj["message"] = "Get \"foo://foo.com\": unsupported protocol scheme \"foo\""

networkErr, err := ast.InterfaceToValue(networkErrObj)
if err != nil {
panic(err)
}
networkErr := ast.MustInterfaceToValue(networkErrObj)

internalErrObj := make(map[string]interface{})
internalErrObj["code"] = HTTPSendInternalErr
internalErrObj["message"] = fmt.Sprintf(`http.send({"method": "get", "url": "%s", "force_json_decode": true, "raise_error": false, "force_cache": true}): eval_builtin_error: http.send: 'force_cache' set but 'force_cache_duration_seconds' parameter is missing`, baseURL)

internalErr, err := ast.InterfaceToValue(internalErrObj)
if err != nil {
panic(err)
}
internalErr := ast.MustInterfaceToValue(internalErrObj)

responseObj := make(map[string]interface{})
responseObj["status_code"] = 0
responseObj["error"] = internalErrObj

response, err := ast.InterfaceToValue(responseObj)
if err != nil {
panic(err)
}
response := ast.MustInterfaceToValue(responseObj)

tests := []struct {
note string
Expand Down Expand Up @@ -2601,10 +2574,7 @@ func TestSocketHTTPGetRequest(t *testing.T) {
"test-header": []interface{}{"test-value"},
}

resultObj, err := ast.InterfaceToValue(expectedResult)
if err != nil {
panic(err)
}
resultObj := ast.MustInterfaceToValue(expectedResult)

// run the test
tests := []struct {
Expand Down Expand Up @@ -2698,7 +2668,7 @@ func TestHTTPGetRequestAllowNet(t *testing.T) {
// host
serverURL, err := url.Parse(ts.URL)
if err != nil {
panic(err)
t.Fatal(err)
}
serverHost := strings.Split(serverURL.Host, ":")[0]

Expand All @@ -2710,10 +2680,7 @@ func TestHTTPGetRequestAllowNet(t *testing.T) {
expectedResult["body"] = body
expectedResult["raw_body"] = "{\"ok\":true}\n"

resultObj, err := ast.InterfaceToValue(expectedResult)
if err != nil {
panic(err)
}
resultObj := ast.MustInterfaceToValue(expectedResult)

expectedError := &Error{Code: "eval_builtin_error", Message: fmt.Sprintf("http.send: unallowed host: %s", serverHost)}

Expand Down
2 changes: 1 addition & 1 deletion topdown/topdown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ func setTime(t time.Time) func(*Query) *Query {

func setAllowNet(a []string) func(*Query) *Query {
return func(q *Query) *Query {
c := q.compiler.Capabilities
c := q.compiler.Capabilities()
c.AllowNet = a
return q.WithCompiler(q.compiler.WithCapabilities(c))
}
Expand Down

0 comments on commit 9884270

Please sign in to comment.