Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to Elasticsearch filter example to utilize the new SDK functionality #159

Closed
wants to merge 11 commits into from
923 changes: 910 additions & 13 deletions data_filter_elasticsearch/go.sum

Large diffs are not rendered by default.

46 changes: 35 additions & 11 deletions data_filter_elasticsearch/internal/api/api.go
Expand Up @@ -5,17 +5,19 @@
package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"

"github.com/gorilla/mux"
"github.com/olivere/elastic"
"github.com/open-policy-agent/contrib/data_filter_elasticsearch/internal/es"
"github.com/open-policy-agent/contrib/data_filter_elasticsearch/internal/opa"
"github.com/open-policy-agent/contrib/data_filter_elasticsearch/internal/mapper"
"github.com/open-policy-agent/opa/sdk"
)

const (
Expand All @@ -25,6 +27,8 @@ const (
apiCodeNotAuthorized = "not_authorized"
)

var opa *sdk.OPA

type apiError struct {
Error struct {
Code string `json:"code"`
Expand Down Expand Up @@ -57,6 +61,7 @@ func New(esClient *elastic.Client, index string) *ServerAPI {

// Run the server.
func (api *ServerAPI) Run(ctx context.Context) error {
opa = startOpa()
fmt.Println("Starting server 8080....")
return http.ListenAndServe(":8080", api.router)
}
Expand All @@ -73,7 +78,7 @@ func (api *ServerAPI) handlGetPosts(w http.ResponseWriter, r *http.Request) {
return
}

combinedQuery := combineQuery(es.GenerateMatchAllQuery(), result.Query)
combinedQuery := combineQuery(mapper.GenerateMatchAllQuery(), result.Query)
queryEs(r.Context(), api.es, api.index, combinedQuery, w)

}
Expand All @@ -91,11 +96,27 @@ func (api *ServerAPI) handleGetPost(w http.ResponseWriter, r *http.Request) {
}

vars := mux.Vars(r)
combinedQuery := combineQuery(es.GenerateTermQuery("id", vars["id"]), result.Query)
combinedQuery := combineQuery(mapper.GenerateTermQuery("id", vars["id"]), result.Query)
queryEs(r.Context(), api.es, api.index, combinedQuery, w)
}

func queryOPA(w http.ResponseWriter, r *http.Request) (opa.Result, error) {
func startOpa() *sdk.OPA {
config, err := os.ReadFile("opa-conf.yaml")
if err != nil {

panic(err)
}

opa, err := sdk.New(context.Background(), sdk.Options{
Config: bytes.NewReader(config),
})
if err != nil {
panic(err)
}
return opa
}

func queryOPA(w http.ResponseWriter, r *http.Request) (mapper.Result, error) {

user := r.Header.Get("Authorization")
path := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
Expand All @@ -106,20 +127,23 @@ func queryOPA(w http.ResponseWriter, r *http.Request) (opa.Result, error) {
"user": user,
}

// load policy
module, err := ioutil.ReadFile(opa.PolicyFileName)
decision, err := opa.Partial(r.Context(), sdk.PartialOptions{
Input: input,
Unknowns: []string{"data.elastic"},
Query: "data.example.allow == true",
Mapper: &mapper.ElasticMapper{},
})
if err != nil {
return opa.Result{}, fmt.Errorf("failed to read policy: %v", err)
return mapper.Result{}, err
}

return opa.Compile(r.Context(), input, module)
return decision.Result.(mapper.Result), nil
}

func combineQuery(queryFromHandler elastic.Query, queryFromOpa elastic.Query) elastic.Query {
var combinedQuery elastic.Query = queryFromHandler
if queryFromOpa != nil {
queries := []elastic.Query{queryFromOpa, queryFromHandler}
combinedQuery = es.GenerateBoolFilterQuery(queries)
combinedQuery = mapper.GenerateBoolFilterQuery(queries)
}
return combinedQuery
}
Expand Down
108 changes: 0 additions & 108 deletions data_filter_elasticsearch/internal/es/es.go
Expand Up @@ -7,7 +7,6 @@ package es
import (
"context"
"encoding/json"
"fmt"

"github.com/olivere/elastic"
)
Expand Down Expand Up @@ -180,90 +179,6 @@ func GetIndexMapping() string {
return mapping
}

// Elasticsearch queries

// GenerateTermQuery returns an ES Term Query.
func GenerateTermQuery(fieldName string, fieldValue interface{}) *elastic.TermQuery {
return elastic.NewTermQuery(fieldName, fieldValue).QueryName("TermQuery")

}

// GenerateNestedQuery returns an ES Nested Query.
func GenerateNestedQuery(path string, query elastic.Query) *elastic.NestedQuery {
return elastic.NewNestedQuery(path, query).QueryName("NestedQuery").IgnoreUnmapped(true)

}

// GenerateBoolFilterQuery returns an ES Filter Bool Query.
func GenerateBoolFilterQuery(filters []elastic.Query) *elastic.BoolQuery {
q := elastic.NewBoolQuery()
for _, filter := range filters {
q = q.Filter(filter)
}
q = q.QueryName("BoolFilterQuery")
return q

}

// GenerateBoolShouldQuery returns an ES Should Bool Query.
func GenerateBoolShouldQuery(queries []elastic.Query) *elastic.BoolQuery {
q := elastic.NewBoolQuery().QueryName("BoolShouldQuery")
for _, query := range queries {
q = q.Should(query)
}
return q
}

// GenerateBoolMustNotQuery returns an ES Must Not Bool Query.
func GenerateBoolMustNotQuery(fieldName string, fieldValue interface{}) *elastic.BoolQuery {
q := elastic.NewBoolQuery().QueryName("BoolMustNotQuery")
q = q.MustNot(elastic.NewTermQuery(fieldName, fieldValue))
return q
}

// GenerateMatchAllQuery returns an ES MatchAll Query.
func GenerateMatchAllQuery() *elastic.MatchAllQuery {
return elastic.NewMatchAllQuery().QueryName("MatchAllQuery")
}

// GenerateMatchQuery returns an ES Match Query.
func GenerateMatchQuery(fieldName string, fieldValue interface{}) *elastic.MatchQuery {
return elastic.NewMatchQuery(fieldName, fieldValue).QueryName("MatchQuery")
}

// GenerateQueryStringQuery returns an ES Query String Query.
func GenerateQueryStringQuery(fieldName string, fieldValue interface{}) *elastic.QueryStringQuery {
queryString := fmt.Sprintf("*%s*", fieldValue)
q := elastic.NewQueryStringQuery(queryString).QueryName("QueryStringQuery")
q = q.DefaultField(fieldName)
return q
}

// GenerateRegexpQuery returns an ES Regexp Query.
func GenerateRegexpQuery(fieldName string, fieldValue interface{}) *elastic.RegexpQuery {
return elastic.NewRegexpQuery(fieldName, fieldValue.(string))
}

// GenerateRangeQueryLt returns an ES Less Than Range Query.
func GenerateRangeQueryLt(fieldName string, val interface{}) *elastic.RangeQuery {
return elastic.NewRangeQuery(fieldName).Lt(val)
}

// GenerateRangeQueryLte returns an ES Less Than or Equal Range Query.
func GenerateRangeQueryLte(fieldName string, val interface{}) *elastic.RangeQuery {
return elastic.NewRangeQuery(fieldName).Lte(val)
}

// GenerateRangeQueryGt returns an ES Greater Than Range Query.
func GenerateRangeQueryGt(fieldName string, val interface{}) *elastic.RangeQuery {
return elastic.NewRangeQuery(fieldName).Gt(val)
}

// GenerateRangeQueryGte returns an ES Greater Than or Equal Range Query.
func GenerateRangeQueryGte(fieldName string, val interface{}) *elastic.RangeQuery {
return elastic.NewRangeQuery(fieldName).Gte(val)
}

// ExecuteEsSearch executes ES query.
func ExecuteEsSearch(ctx context.Context, client *elastic.Client, indexName string, query elastic.Query) (*elastic.SearchResult, error) {
searchResult, err := client.Search().
Expand All @@ -277,29 +192,6 @@ func ExecuteEsSearch(ctx context.Context, client *elastic.Client, indexName stri
return searchResult, nil
}

func analyzeSearchResult(searchResult *elastic.SearchResult) {

if searchResult.Hits.TotalHits > 0 {
fmt.Printf("Found a total of %d posts\n", searchResult.Hits.TotalHits)

// Iterate through results
for _, hit := range searchResult.Hits.Hits {
// Deserialize hit
var t Post
err := json.Unmarshal(*hit.Source, &t)
if err != nil {
panic(err)
}

// Print with post
fmt.Printf("\nPost ID: %s\nAuthor: %s\nMessage: %s\nDepartment: %s\nClearance: %d\n", t.ID, t.Author, t.Message, t.Department, t.Clearance)
}
} else {
// No hits
fmt.Print("Found no posts\n")
}
}

// GetPrettyESResult returns formatted ES results.
func GetPrettyESResult(searchResult *elastic.SearchResult) []Post {

Expand Down