Skip to content

Commit

Permalink
changes to utilize the added sdk functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
kroekle committed Dec 23, 2021
1 parent b07ce50 commit d830203
Show file tree
Hide file tree
Showing 6 changed files with 1,191 additions and 738 deletions.
923 changes: 910 additions & 13 deletions data_filter_elasticsearch/go.sum

Large diffs are not rendered by default.

46 changes: 35 additions & 11 deletions data_filter_elasticsearch/internal/api/api.go
Expand Up @@ -5,17 +5,19 @@
package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"

"github.com/gorilla/mux"
"github.com/olivere/elastic"
"github.com/open-policy-agent/contrib/data_filter_elasticsearch/internal/es"
"github.com/open-policy-agent/contrib/data_filter_elasticsearch/internal/opa"
"github.com/open-policy-agent/contrib/data_filter_elasticsearch/internal/resolvers"
"github.com/open-policy-agent/opa/sdk"
)

const (
Expand All @@ -25,6 +27,8 @@ const (
apiCodeNotAuthorized = "not_authorized"
)

var opa *sdk.OPA

type apiError struct {
Error struct {
Code string `json:"code"`
Expand Down Expand Up @@ -57,6 +61,7 @@ func New(esClient *elastic.Client, index string) *ServerAPI {

// Run the server.
func (api *ServerAPI) Run(ctx context.Context) error {
opa = startOpa()
fmt.Println("Starting server 8080....")
return http.ListenAndServe(":8080", api.router)
}
Expand All @@ -73,7 +78,7 @@ func (api *ServerAPI) handlGetPosts(w http.ResponseWriter, r *http.Request) {
return
}

combinedQuery := combineQuery(es.GenerateMatchAllQuery(), result.Query)
combinedQuery := combineQuery(resolvers.GenerateMatchAllQuery(), result.Query)
queryEs(r.Context(), api.es, api.index, combinedQuery, w)

}
Expand All @@ -91,11 +96,26 @@ func (api *ServerAPI) handleGetPost(w http.ResponseWriter, r *http.Request) {
}

vars := mux.Vars(r)
combinedQuery := combineQuery(es.GenerateTermQuery("id", vars["id"]), result.Query)
combinedQuery := combineQuery(resolvers.GenerateTermQuery("id", vars["id"]), result.Query)
queryEs(r.Context(), api.es, api.index, combinedQuery, w)
}

func queryOPA(w http.ResponseWriter, r *http.Request) (opa.Result, error) {
func startOpa() *sdk.OPA {
config, err := os.ReadFile("opa-conf.yaml")
if err != nil {

panic(err)
}
opa, err := sdk.New(context.Background(), sdk.Options{
Config: bytes.NewReader(config),
})
if err != nil {
panic(err)
}
return opa
}

func queryOPA(w http.ResponseWriter, r *http.Request) (resolvers.Result, error) {

user := r.Header.Get("Authorization")
path := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
Expand All @@ -106,20 +126,24 @@ func queryOPA(w http.ResponseWriter, r *http.Request) (opa.Result, error) {
"user": user,
}

// load policy
module, err := ioutil.ReadFile(opa.PolicyFileName)
decision, err := opa.Partial(r.Context(), sdk.PartialOptions{
Input: input,
Unknowns: []string{"data.elastic"},
Path: "example/allow",
Query: "data.example.allow == true",
Resolver: &resolvers.ElasticResolver{},
})
if err != nil {
return opa.Result{}, fmt.Errorf("failed to read policy: %v", err)
return resolvers.Result{}, err
}

return opa.Compile(r.Context(), input, module)
return decision.Result.(resolvers.Result), nil
}

func combineQuery(queryFromHandler elastic.Query, queryFromOpa elastic.Query) elastic.Query {
var combinedQuery elastic.Query = queryFromHandler
if queryFromOpa != nil {
queries := []elastic.Query{queryFromOpa, queryFromHandler}
combinedQuery = es.GenerateBoolFilterQuery(queries)
combinedQuery = resolvers.GenerateBoolFilterQuery(queries)
}
return combinedQuery
}
Expand Down
108 changes: 0 additions & 108 deletions data_filter_elasticsearch/internal/es/es.go
Expand Up @@ -7,7 +7,6 @@ package es
import (
"context"
"encoding/json"
"fmt"

"github.com/olivere/elastic"
)
Expand Down Expand Up @@ -180,90 +179,6 @@ func GetIndexMapping() string {
return mapping
}

// Elasticsearch queries

// GenerateTermQuery returns an ES Term Query.
func GenerateTermQuery(fieldName string, fieldValue interface{}) *elastic.TermQuery {
return elastic.NewTermQuery(fieldName, fieldValue).QueryName("TermQuery")

}

// GenerateNestedQuery returns an ES Nested Query.
func GenerateNestedQuery(path string, query elastic.Query) *elastic.NestedQuery {
return elastic.NewNestedQuery(path, query).QueryName("NestedQuery").IgnoreUnmapped(true)

}

// GenerateBoolFilterQuery returns an ES Filter Bool Query.
func GenerateBoolFilterQuery(filters []elastic.Query) *elastic.BoolQuery {
q := elastic.NewBoolQuery()
for _, filter := range filters {
q = q.Filter(filter)
}
q = q.QueryName("BoolFilterQuery")
return q

}

// GenerateBoolShouldQuery returns an ES Should Bool Query.
func GenerateBoolShouldQuery(queries []elastic.Query) *elastic.BoolQuery {
q := elastic.NewBoolQuery().QueryName("BoolShouldQuery")
for _, query := range queries {
q = q.Should(query)
}
return q
}

// GenerateBoolMustNotQuery returns an ES Must Not Bool Query.
func GenerateBoolMustNotQuery(fieldName string, fieldValue interface{}) *elastic.BoolQuery {
q := elastic.NewBoolQuery().QueryName("BoolMustNotQuery")
q = q.MustNot(elastic.NewTermQuery(fieldName, fieldValue))
return q
}

// GenerateMatchAllQuery returns an ES MatchAll Query.
func GenerateMatchAllQuery() *elastic.MatchAllQuery {
return elastic.NewMatchAllQuery().QueryName("MatchAllQuery")
}

// GenerateMatchQuery returns an ES Match Query.
func GenerateMatchQuery(fieldName string, fieldValue interface{}) *elastic.MatchQuery {
return elastic.NewMatchQuery(fieldName, fieldValue).QueryName("MatchQuery")
}

// GenerateQueryStringQuery returns an ES Query String Query.
func GenerateQueryStringQuery(fieldName string, fieldValue interface{}) *elastic.QueryStringQuery {
queryString := fmt.Sprintf("*%s*", fieldValue)
q := elastic.NewQueryStringQuery(queryString).QueryName("QueryStringQuery")
q = q.DefaultField(fieldName)
return q
}

// GenerateRegexpQuery returns an ES Regexp Query.
func GenerateRegexpQuery(fieldName string, fieldValue interface{}) *elastic.RegexpQuery {
return elastic.NewRegexpQuery(fieldName, fieldValue.(string))
}

// GenerateRangeQueryLt returns an ES Less Than Range Query.
func GenerateRangeQueryLt(fieldName string, val interface{}) *elastic.RangeQuery {
return elastic.NewRangeQuery(fieldName).Lt(val)
}

// GenerateRangeQueryLte returns an ES Less Than or Equal Range Query.
func GenerateRangeQueryLte(fieldName string, val interface{}) *elastic.RangeQuery {
return elastic.NewRangeQuery(fieldName).Lte(val)
}

// GenerateRangeQueryGt returns an ES Greater Than Range Query.
func GenerateRangeQueryGt(fieldName string, val interface{}) *elastic.RangeQuery {
return elastic.NewRangeQuery(fieldName).Gt(val)
}

// GenerateRangeQueryGte returns an ES Greater Than or Equal Range Query.
func GenerateRangeQueryGte(fieldName string, val interface{}) *elastic.RangeQuery {
return elastic.NewRangeQuery(fieldName).Gte(val)
}

// ExecuteEsSearch executes ES query.
func ExecuteEsSearch(ctx context.Context, client *elastic.Client, indexName string, query elastic.Query) (*elastic.SearchResult, error) {
searchResult, err := client.Search().
Expand All @@ -277,29 +192,6 @@ func ExecuteEsSearch(ctx context.Context, client *elastic.Client, indexName stri
return searchResult, nil
}

func analyzeSearchResult(searchResult *elastic.SearchResult) {

if searchResult.Hits.TotalHits > 0 {
fmt.Printf("Found a total of %d posts\n", searchResult.Hits.TotalHits)

// Iterate through results
for _, hit := range searchResult.Hits.Hits {
// Deserialize hit
var t Post
err := json.Unmarshal(*hit.Source, &t)
if err != nil {
panic(err)
}

// Print with post
fmt.Printf("\nPost ID: %s\nAuthor: %s\nMessage: %s\nDepartment: %s\nClearance: %d\n", t.ID, t.Author, t.Message, t.Department, t.Clearance)
}
} else {
// No hits
fmt.Print("Found no posts\n")
}
}

// GetPrettyESResult returns formatted ES results.
func GetPrettyESResult(searchResult *elastic.SearchResult) []Post {

Expand Down

0 comments on commit d830203

Please sign in to comment.