Skip to content

Commit

Permalink
Add support for contextual validation
Browse files Browse the repository at this point in the history
By passing the Gin context to bindings, custom validators can take
advantage of the information in the context.
  • Loading branch information
kszafran committed Sep 22, 2021
1 parent e3337ff commit 14b4deb
Show file tree
Hide file tree
Showing 13 changed files with 197 additions and 56 deletions.
45 changes: 40 additions & 5 deletions binding/binding.go
Expand Up @@ -7,7 +7,10 @@

package binding

import "net/http"
import (
"context"
"net/http"
)

// Content-Type MIME of the most common data formats.
const (
Expand All @@ -32,20 +35,41 @@ type Binding interface {
Bind(*http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar with Bind,
// CtxBinding enables contextual validation by adding CtxBind to Binding.
// Custom validators can take advantage of the information in the context.
type CtxBinding interface {
Binding
CtxBind(context.Context, *http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar to Bind,
// but it reads the body from supplied bytes instead of req.Body.
type BindingBody interface {
Binding
BindBody([]byte, interface{}) error
}

// BindingUri adds BindUri method to Binding. BindUri is similar with Bind,
// but it read the Params.
// CtxBindingBody enables contextual validation by adding CtxBindBody to BindingBody.
// Custom validators can take advantage of the information in the context.
type CtxBindingBody interface {
BindingBody
CtxBind(context.Context, *http.Request, interface{}) error
CtxBindBody(context.Context, []byte, interface{}) error
}

// BindingUri is similar to Bind, but it read the Params.
type BindingUri interface {
Name() string
BindUri(map[string][]string, interface{}) error
}

// CtxBindingUri enables contextual validation by adding CtxBindUri to BindingUri.
// Custom validators can take advantage of the information in the context.
type CtxBindingUri interface {
BindingUri
CtxBindUri(context.Context, map[string][]string, interface{}) error
}

// StructValidator is the minimal interface which needs to be implemented in
// order for it to be used as the validator engine for ensuring the correctness
// of the request. Gin provides a default implementation for this using
Expand All @@ -64,6 +88,14 @@ type StructValidator interface {
Engine() interface{}
}

// CtxStructValidator is an extension of StructValidator that requires implementing
// context-aware validation.
// Custom validators can take advantage of the information in the context.
type CtxStructValidator interface {
StructValidator
ValidateStructCtx(context.Context, interface{}) error
}

// Validator is the default validator which implements the StructValidator
// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1
// under the hood.
Expand Down Expand Up @@ -110,9 +142,12 @@ func Default(method, contentType string) Binding {
}
}

func validate(obj interface{}) error {
func validateCtx(ctx context.Context, obj interface{}) error {
if Validator == nil {
return nil
}
if v, ok := Validator.(CtxStructValidator); ok {
return v.ValidateStructCtx(ctx, obj)
}
return Validator.ValidateStruct(obj)
}
15 changes: 13 additions & 2 deletions binding/binding_msgpack_test.go
Expand Up @@ -9,6 +9,7 @@ package binding

import (
"bytes"
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -35,7 +36,7 @@ func TestBindingMsgPack(t *testing.T) {
string(data), string(data[1:]))
}

func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
func testMsgPackBodyBinding(t *testing.T, b CtxBinding, name, path, badPath, body, badBody string) {
assert.Equal(t, name, b.Name())

obj := FooStruct{}
Expand All @@ -48,7 +49,17 @@ func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body,
obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEMSGPACK)
err = MsgPack.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)

obj2 := ConditionalFooStruct{}
req = requestWithBody("POST", path, body)
req.Header.Add("Content-Type", MIMEMSGPACK)
err = b.CtxBind(context.Background(), req, &obj2)
assert.NoError(t, err)
assert.Equal(t, "bar", obj2.Foo)

err = b.CtxBind(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint
assert.Error(t, err)
}

Expand Down
43 changes: 34 additions & 9 deletions binding/binding_test.go
Expand Up @@ -6,6 +6,7 @@ package binding

import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
Expand All @@ -20,6 +21,7 @@ import (
"time"

"github.com/gin-gonic/gin/testdata/protoexample"
"github.com/go-playground/validator/v10"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
Expand All @@ -38,6 +40,10 @@ type FooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required,max=32"`
}

type ConditionalFooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required_if_condition,max=32"`
}

type FooBarStruct struct {
FooStruct
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" binding:"required"`
Expand Down Expand Up @@ -144,6 +150,16 @@ type FooStructForMapPtrType struct {
PtrBar *map[string]interface{} `form:"ptr_bar"`
}

func init() {
_ = Validator.Engine().(*validator.Validate).RegisterValidationCtx(
"required_if_condition", func(ctx context.Context, fl validator.FieldLevel) bool {
if ctx.Value("condition") == true {
return !fl.Field().IsZero()
}
return true
})
}

func TestBindingDefault(t *testing.T) {
assert.Equal(t, Form, Default("GET", ""))
assert.Equal(t, Form, Default("GET", MIMEJSON))
Expand Down Expand Up @@ -1179,7 +1195,7 @@ func testQueryBindingBoolFail(t *testing.T, method, path, badPath, body, badBody
assert.Error(t, err)
}

func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
func testBodyBinding(t *testing.T, b CtxBinding, name, path, badPath, body, badBody string) {
assert.Equal(t, name, b.Name())

obj := FooStruct{}
Expand All @@ -1190,7 +1206,16 @@ func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody

obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)

obj2 := ConditionalFooStruct{}
req = requestWithBody("POST", path, body)
err = b.CtxBind(context.Background(), req, &obj2)
assert.NoError(t, err)
assert.Equal(t, "bar", obj2.Foo)

err = b.CtxBind(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint
assert.Error(t, err)
}

Expand All @@ -1204,7 +1229,7 @@ func testBodyBindingSlice(t *testing.T, b Binding, name, path, badPath, body, ba

var obj2 []FooStruct
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj2)
err = b.Bind(req, &obj2)
assert.Error(t, err)
}

Expand Down Expand Up @@ -1249,7 +1274,7 @@ func testBodyBindingUseNumber(t *testing.T, b Binding, name, path, badPath, body

obj = FooStructUseNumber{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand All @@ -1267,7 +1292,7 @@ func testBodyBindingUseNumber2(t *testing.T, b Binding, name, path, badPath, bod

obj = FooStructUseNumber{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand All @@ -1285,7 +1310,7 @@ func testBodyBindingDisallowUnknownFields(t *testing.T, b Binding, path, badPath

obj = FooStructDisallowUnknownFields{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
assert.Contains(t, err.Error(), "what")
}
Expand All @@ -1301,7 +1326,7 @@ func testBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, bad

obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand All @@ -1318,7 +1343,7 @@ func testProtoBodyBinding(t *testing.T, b Binding, name, path, badPath, body, ba
obj = protoexample.Test{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEPROTOBUF)
err = ProtoBuf.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand Down Expand Up @@ -1349,7 +1374,7 @@ func testProtoBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body
obj = protoexample.Test{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEPROTOBUF)
err = ProtoBuf.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand Down
26 changes: 19 additions & 7 deletions binding/form.go
Expand Up @@ -5,6 +5,7 @@
package binding

import (
"context"
"errors"
"net/http"
)
Expand All @@ -19,7 +20,11 @@ func (formBinding) Name() string {
return "form"
}

func (formBinding) Bind(req *http.Request, obj interface{}) error {
func (b formBinding) Bind(req *http.Request, obj interface{}) error {
return b.CtxBind(context.Background(), req, obj)
}

func (formBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return err
}
Expand All @@ -29,34 +34,41 @@ func (formBinding) Bind(req *http.Request, obj interface{}) error {
if err := mapForm(obj, req.Form); err != nil {
return err
}
return validate(obj)
return validateCtx(ctx, obj)
}

func (formPostBinding) Name() string {
return "form-urlencoded"
}

func (formPostBinding) Bind(req *http.Request, obj interface{}) error {
func (b formPostBinding) Bind(req *http.Request, obj interface{}) error {
return b.CtxBind(context.Background(), req, obj)
}

func (formPostBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return err
}
if err := mapForm(obj, req.PostForm); err != nil {
return err
}
return validate(obj)
return validateCtx(ctx, obj)
}

func (formMultipartBinding) Name() string {
return "multipart/form-data"
}

func (formMultipartBinding) Bind(req *http.Request, obj interface{}) error {
func (b formMultipartBinding) Bind(req *http.Request, obj interface{}) error {
return b.CtxBind(context.Background(), req, obj)
}

func (formMultipartBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error {
if err := req.ParseMultipartForm(defaultMemory); err != nil {
return err
}
if err := mappingByPtr(obj, (*multipartRequest)(req), "form"); err != nil {
return err
}

return validate(obj)
return validateCtx(ctx, obj)
}
9 changes: 6 additions & 3 deletions binding/header.go
@@ -1,6 +1,7 @@
package binding

import (
"context"
"net/http"
"net/textproto"
"reflect"
Expand All @@ -12,13 +13,15 @@ func (headerBinding) Name() string {
return "header"
}

func (headerBinding) Bind(req *http.Request, obj interface{}) error {
func (b headerBinding) Bind(req *http.Request, obj interface{}) error {
return b.CtxBind(context.Background(), req, obj)
}

func (headerBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error {
if err := mapHeader(obj, req.Header); err != nil {
return err
}

return validate(obj)
return validateCtx(ctx, obj)
}

func mapHeader(ptr interface{}, h map[string][]string) error {
Expand Down
21 changes: 15 additions & 6 deletions binding/json.go
Expand Up @@ -6,6 +6,7 @@ package binding

import (
"bytes"
"context"
"errors"
"io"
"net/http"
Expand All @@ -30,18 +31,26 @@ func (jsonBinding) Name() string {
return "json"
}

func (jsonBinding) Bind(req *http.Request, obj interface{}) error {
func (b jsonBinding) Bind(req *http.Request, obj interface{}) error {
return b.CtxBind(context.Background(), req, obj)
}

func (jsonBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error {
if req == nil || req.Body == nil {
return errors.New("invalid request")
}
return decodeJSON(req.Body, obj)
return decodeJSON(ctx, req.Body, obj)
}

func (b jsonBinding) BindBody(body []byte, obj interface{}) error {
return b.CtxBindBody(context.Background(), body, obj)
}

func (jsonBinding) BindBody(body []byte, obj interface{}) error {
return decodeJSON(bytes.NewReader(body), obj)
func (jsonBinding) CtxBindBody(ctx context.Context, body []byte, obj interface{}) error {
return decodeJSON(ctx, bytes.NewReader(body), obj)
}

func decodeJSON(r io.Reader, obj interface{}) error {
func decodeJSON(ctx context.Context, r io.Reader, obj interface{}) error {
decoder := json.NewDecoder(r)
if EnableDecoderUseNumber {
decoder.UseNumber()
Expand All @@ -52,5 +61,5 @@ func decodeJSON(r io.Reader, obj interface{}) error {
if err := decoder.Decode(obj); err != nil {
return err
}
return validate(obj)
return validateCtx(ctx, obj)
}

0 comments on commit 14b4deb

Please sign in to comment.