Skip to content

Commit

Permalink
server: Remove unnecessary AST-to-JSON conversions.
Browse files Browse the repository at this point in the history
This time for v0QueryPath, v1DataGet, and v1DataPost.

Signed-off-by: Teemu Koponen <koponen@styra.com>
  • Loading branch information
koponen-styra authored and ashutosh-narkar committed Apr 3, 2024
1 parent 88eaaa9 commit 8fde826
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 54 deletions.
80 changes: 28 additions & 52 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1028,22 +1028,12 @@ func (s *Server) v0QueryPath(w http.ResponseWriter, r *http.Request, urlPath str
ctx := logging.WithDecisionID(r.Context(), decisionID)
annotateSpan(ctx, decisionID)

input, err := readInputV0(r)
input, goInput, err := readInputV0(r)
if err != nil {
writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, fmt.Errorf("unexpected parse error for input: %w", err))
return
}

var goInput *interface{}
if input != nil {
x, err := ast.JSON(input)
if err != nil {
writer.ErrorString(w, http.StatusInternalServerError, types.CodeInvalidParameter, fmt.Errorf("could not marshal input: %w", err))
return
}
goInput = &x
}

// Prepare for query.
txn, err := s.store.NewTransaction(ctx)
if err != nil {
Expand Down Expand Up @@ -1446,26 +1436,17 @@ func (s *Server) v1DataGet(w http.ResponseWriter, r *http.Request) {
inputs := r.URL.Query()[types.ParamInputV1]

var input ast.Value
var goInput *interface{}

if len(inputs) > 0 {
var err error
input, err = readInputGetV1(inputs[len(inputs)-1])
input, goInput, err = readInputGetV1(inputs[len(inputs)-1])
if err != nil {
writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err)
return
}
}

var goInput *interface{}
if input != nil {
x, err := ast.JSON(input)
if err != nil {
writer.ErrorString(w, http.StatusInternalServerError, types.CodeInvalidParameter, fmt.Errorf("could not marshal input: %w", err))
return
}
goInput = &x
}

m.Timer(metrics.RegoInputParse).Stop()

// Prepare for query.
Expand Down Expand Up @@ -1678,22 +1659,12 @@ func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) {

m.Timer(metrics.RegoInputParse).Start()

input, err := readInputPostV1(r)
input, goInput, err := readInputPostV1(r)
if err != nil {
writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err)
return
}

var goInput *interface{}
if input != nil {
x, err := ast.JSON(input)
if err != nil {
writer.ErrorString(w, http.StatusInternalServerError, types.CodeInvalidParameter, fmt.Errorf("could not marshal input: %w", err))
return
}
goInput = &x
}

m.Timer(metrics.RegoInputParse).Stop()

txn, err := s.store.NewTransaction(ctx, storage.TransactionParams{Context: storage.NewContext().WithMetrics(m)})
Expand Down Expand Up @@ -2752,67 +2723,71 @@ func getExplain(p []string, zero types.ExplainModeV1) types.ExplainModeV1 {
return zero
}

func readInputV0(r *http.Request) (ast.Value, error) {
func readInputV0(r *http.Request) (ast.Value, *interface{}, error) {

parsed, ok := authorizer.GetBodyOnContext(r.Context())
if ok {
return ast.InterfaceToValue(parsed)
v, err := ast.InterfaceToValue(parsed)
return v, &parsed, err
}

// decompress the input if sent as zip
body, err := readPlainBody(r)
if err != nil {
return nil, fmt.Errorf("could not decompress the body: %w", err)
return nil, nil, fmt.Errorf("could not decompress the body: %w", err)
}

var x interface{}

if strings.Contains(r.Header.Get("Content-Type"), "yaml") {
bs, err := io.ReadAll(body)
if err != nil {
return nil, err
return nil, nil, err
}
if len(bs) > 0 {
if err = util.Unmarshal(bs, &x); err != nil {
return nil, fmt.Errorf("body contains malformed input document: %w", err)
return nil, nil, fmt.Errorf("body contains malformed input document: %w", err)
}
}
} else {
dec := util.NewJSONDecoder(body)
if err := dec.Decode(&x); err != nil && err != io.EOF {
return nil, fmt.Errorf("body contains malformed input document: %w", err)
return nil, nil, fmt.Errorf("body contains malformed input document: %w", err)
}
}

return ast.InterfaceToValue(x)
v, err := ast.InterfaceToValue(x)
return v, &x, err
}

func readInputGetV1(str string) (ast.Value, error) {
func readInputGetV1(str string) (ast.Value, *interface{}, error) {
var input interface{}
if err := util.UnmarshalJSON([]byte(str), &input); err != nil {
return nil, fmt.Errorf("parameter contains malformed input document: %w", err)
return nil, nil, fmt.Errorf("parameter contains malformed input document: %w", err)
}
return ast.InterfaceToValue(input)
v, err := ast.InterfaceToValue(input)
return v, &input, err
}

func readInputPostV1(r *http.Request) (ast.Value, error) {
func readInputPostV1(r *http.Request) (ast.Value, *interface{}, error) {

parsed, ok := authorizer.GetBodyOnContext(r.Context())
if ok {
if obj, ok := parsed.(map[string]interface{}); ok {
if input, ok := obj["input"]; ok {
return ast.InterfaceToValue(input)
v, err := ast.InterfaceToValue(input)
return v, &input, err
}
}
return nil, nil
return nil, nil, nil
}

var request types.DataRequestV1

// decompress the input if sent as zip
body, err := readPlainBody(r)
if err != nil {
return nil, fmt.Errorf("could not decompress the body: %w", err)
return nil, nil, fmt.Errorf("could not decompress the body: %w", err)
}

ct := r.Header.Get("Content-Type")
Expand All @@ -2821,25 +2796,26 @@ func readInputPostV1(r *http.Request) (ast.Value, error) {
if strings.Contains(ct, "yaml") {
bs, err := io.ReadAll(body)
if err != nil {
return nil, err
return nil, nil, err
}
if len(bs) > 0 {
if err = util.Unmarshal(bs, &request); err != nil {
return nil, fmt.Errorf("body contains malformed input document: %w", err)
return nil, nil, fmt.Errorf("body contains malformed input document: %w", err)
}
}
} else {
dec := util.NewJSONDecoder(body)
if err := dec.Decode(&request); err != nil && err != io.EOF {
return nil, fmt.Errorf("body contains malformed input document: %w", err)
return nil, nil, fmt.Errorf("body contains malformed input document: %w", err)
}
}

if request.Input == nil {
return nil, nil
return nil, nil, nil
}

return ast.InterfaceToValue(*request.Input)
v, err := ast.InterfaceToValue(*request.Input)
return v, request.Input, err
}

type compileRequest struct {
Expand Down
12 changes: 10 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4119,7 +4119,7 @@ func TestServerUsesAuthorizerParsedBody(t *testing.T) {
})

// Check that v1 reader function behaves correctly.
inp, err := readInputPostV1(req.WithContext(ctx))
inp, goInp, err := readInputPostV1(req.WithContext(ctx))
if err != nil {
t.Fatal(err)
}
Expand All @@ -4130,19 +4130,27 @@ func TestServerUsesAuthorizerParsedBody(t *testing.T) {
t.Fatalf("expected %v but got %v", exp, inp)
}

if exp.Value.Compare(ast.MustInterfaceToValue(*goInp)) != 0 {
t.Fatalf("expected %v but got %v", exp, *goInp)
}

// Check that v0 reader function behaves correctly.
ctx = authorizer.SetBodyOnContext(req.Context(), map[string]interface{}{
"foo": "good",
})

inp, err = readInputV0(req.WithContext(ctx))
inp, goInp, err = readInputV0(req.WithContext(ctx))
if err != nil {
t.Fatal(err)
}

if exp.Value.Compare(inp) != 0 {
t.Fatalf("expected %v but got %v", exp, inp)
}

if exp.Value.Compare(ast.MustInterfaceToValue(*goInp)) != 0 {
t.Fatalf("expected %v but got %v", exp, *goInp)
}
}

func TestServerReloadTrigger(t *testing.T) {
Expand Down

0 comments on commit 8fde826

Please sign in to comment.