Skip to content

Commit

Permalink
Add support for contextual validation
Browse files Browse the repository at this point in the history
By passing the Gin context to bindings, custom validators can take
advantage of the information in the context.
  • Loading branch information
kszafran committed Nov 2, 2021
1 parent 62285c8 commit ca15e39
Show file tree
Hide file tree
Showing 16 changed files with 362 additions and 71 deletions.
45 changes: 40 additions & 5 deletions binding/binding.go
Expand Up @@ -7,7 +7,10 @@

package binding

import "net/http"
import (
"context"
"net/http"
)

// Content-Type MIME of the most common data formats.
const (
Expand All @@ -32,20 +35,41 @@ type Binding interface {
Bind(*http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar with Bind,
// ContextBinding enables contextual validation by adding BindContext to Binding.
// Custom validators can take advantage of the information in the context.
type ContextBinding interface {
Binding
BindContext(context.Context, *http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar to Bind,
// but it reads the body from supplied bytes instead of req.Body.
type BindingBody interface {
Binding
BindBody([]byte, interface{}) error
}

// BindingUri adds BindUri method to Binding. BindUri is similar with Bind,
// but it read the Params.
// ContextBindingBody enables contextual validation by adding BindBodyContext to BindingBody.
// Custom validators can take advantage of the information in the context.
type ContextBindingBody interface {
BindingBody
BindContext(context.Context, *http.Request, interface{}) error
BindBodyContext(context.Context, []byte, interface{}) error
}

// BindingUri is similar to Bind, but it read the Params.
type BindingUri interface {
Name() string
BindUri(map[string][]string, interface{}) error
}

// ContextBindingUri enables contextual validation by adding BindUriContext to BindingUri.
// Custom validators can take advantage of the information in the context.
type ContextBindingUri interface {
BindingUri
BindUriContext(context.Context, map[string][]string, interface{}) error
}

// StructValidator is the minimal interface which needs to be implemented in
// order for it to be used as the validator engine for ensuring the correctness
// of the request. Gin provides a default implementation for this using
Expand All @@ -64,6 +88,14 @@ type StructValidator interface {
Engine() interface{}
}

// ContextStructValidator is an extension of StructValidator that requires implementing
// context-aware validation.
// Custom validators can take advantage of the information in the context.
type ContextStructValidator interface {
StructValidator
ValidateStructContext(context.Context, interface{}) error
}

// Validator is the default validator which implements the StructValidator
// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1
// under the hood.
Expand Down Expand Up @@ -110,9 +142,12 @@ func Default(method, contentType string) Binding {
}
}

func validate(obj interface{}) error {
func validateContext(ctx context.Context, obj interface{}) error {
if Validator == nil {
return nil
}
if v, ok := Validator.(ContextStructValidator); ok {
return v.ValidateStructContext(ctx, obj)
}
return Validator.ValidateStruct(obj)
}
15 changes: 13 additions & 2 deletions binding/binding_msgpack_test.go
Expand Up @@ -9,6 +9,7 @@ package binding

import (
"bytes"
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -35,7 +36,7 @@ func TestBindingMsgPack(t *testing.T) {
string(data), string(data[1:]))
}

func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
func testMsgPackBodyBinding(t *testing.T, b ContextBinding, name, path, badPath, body, badBody string) {
assert.Equal(t, name, b.Name())

obj := FooStruct{}
Expand All @@ -48,7 +49,17 @@ func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body,
obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEMSGPACK)
err = MsgPack.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)

obj2 := ConditionalFooStruct{}
req = requestWithBody("POST", path, body)
req.Header.Add("Content-Type", MIMEMSGPACK)
err = b.BindContext(context.Background(), req, &obj2)
assert.NoError(t, err)
assert.Equal(t, "bar", obj2.Foo)

err = b.BindContext(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint
assert.Error(t, err)
}

Expand Down
47 changes: 41 additions & 6 deletions binding/binding_nomsgpack.go
Expand Up @@ -7,7 +7,10 @@

package binding

import "net/http"
import (
"context"
"net/http"
)

// Content-Type MIME of the most common data formats.
const (
Expand All @@ -30,20 +33,41 @@ type Binding interface {
Bind(*http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar with Bind,
// ContextBinding enables contextual validation by adding BindContext to Binding.
// Custom validators can take advantage of the information in the context.
type ContextBinding interface {
Binding
BindContext(context.Context, *http.Request, interface{}) error
}

// BindingBody adds BindBody method to Binding. BindBody is similar to Bind,
// but it reads the body from supplied bytes instead of req.Body.
type BindingBody interface {
Binding
BindBody([]byte, interface{}) error
}

// BindingUri adds BindUri method to Binding. BindUri is similar with Bind,
// but it read the Params.
// ContextBindingBody enables contextual validation by adding BindBodyContext to BindingBody.
// Custom validators can take advantage of the information in the context.
type ContextBindingBody interface {
BindingBody
BindContext(context.Context, *http.Request, interface{}) error
BindBodyContext(context.Context, []byte, interface{}) error
}

// BindingUri is similar to Bind, but it read the Params.
type BindingUri interface {
Name() string
BindUri(map[string][]string, interface{}) error
}

// ContextBindingUri enables contextual validation by adding BindUriContext to BindingUri.
// Custom validators can take advantage of the information in the context.
type ContextBindingUri interface {
BindingUri
BindUriContext(context.Context, map[string][]string, interface{}) error
}

// StructValidator is the minimal interface which needs to be implemented in
// order for it to be used as the validator engine for ensuring the correctness
// of the request. Gin provides a default implementation for this using
Expand All @@ -62,6 +86,14 @@ type StructValidator interface {
Engine() interface{}
}

// ContextStructValidator is an extension of StructValidator that requires implementing
// context-aware validation.
// Custom validators can take advantage of the information in the context.
type ContextStructValidator interface {
StructValidator
ValidateStructContext(context.Context, interface{}) error
}

// Validator is the default validator which implements the StructValidator
// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1
// under the hood.
Expand All @@ -85,7 +117,7 @@ var (
// Default returns the appropriate Binding instance based on the HTTP method
// and the content type.
func Default(method, contentType string) Binding {
if method == "GET" {
if method == http.MethodGet {
return Form
}

Expand All @@ -105,9 +137,12 @@ func Default(method, contentType string) Binding {
}
}

func validate(obj interface{}) error {
func validateContext(ctx context.Context, obj interface{}) error {
if Validator == nil {
return nil
}
if v, ok := Validator.(ContextStructValidator); ok {
return v.ValidateStructContext(ctx, obj)
}
return Validator.ValidateStruct(obj)
}
75 changes: 66 additions & 9 deletions binding/binding_test.go
Expand Up @@ -6,6 +6,7 @@ package binding

import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
Expand All @@ -20,6 +21,7 @@ import (
"time"

"github.com/gin-gonic/gin/testdata/protoexample"
"github.com/go-playground/validator/v10"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
Expand All @@ -38,6 +40,10 @@ type FooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required,max=32"`
}

type ConditionalFooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required_if_condition,max=32"`
}

type FooBarStruct struct {
FooStruct
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" binding:"required"`
Expand Down Expand Up @@ -144,6 +150,16 @@ type FooStructForMapPtrType struct {
PtrBar *map[string]interface{} `form:"ptr_bar"`
}

func init() {
_ = Validator.Engine().(*validator.Validate).RegisterValidationCtx(
"required_if_condition", func(ctx context.Context, fl validator.FieldLevel) bool {
if ctx.Value("condition") == true {
return !fl.Field().IsZero()
}
return true
})
}

func TestBindingDefault(t *testing.T) {
assert.Equal(t, Form, Default("GET", ""))
assert.Equal(t, Form, Default("GET", MIMEJSON))
Expand Down Expand Up @@ -796,6 +812,38 @@ func TestUriBinding(t *testing.T) {
assert.Equal(t, map[string]interface{}(nil), not.Name)
}

func TestUriBindingWithContext(t *testing.T) {
b := Uri

type Tag struct {
Name string `uri:"name" binding:"required_if_condition"`
}

empty := make(map[string][]string)
assert.NoError(t, b.BindUriContext(context.Background(), empty, new(Tag)))
assert.Error(t, b.BindUriContext(context.WithValue(context.Background(), "condition", true), empty, new(Tag))) // nolint
}

func TestUriBindingWithNonCtxValidator(t *testing.T) {
prev := Validator
defer func() {
Validator = prev
}()
Validator = &nonCtxValidator{}

TestUriBinding(t)
}

type nonCtxValidator defaultValidator

func (v *nonCtxValidator) ValidateStruct(obj interface{}) error {
return (*defaultValidator)(v).ValidateStruct(obj)
}

func (v *nonCtxValidator) Engine() interface{} {
return (*defaultValidator)(v).Engine()
}

func TestUriInnerBinding(t *testing.T) {
type Tag struct {
Name string `uri:"name"`
Expand Down Expand Up @@ -1179,7 +1227,7 @@ func testQueryBindingBoolFail(t *testing.T, method, path, badPath, body, badBody
assert.Error(t, err)
}

func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
func testBodyBinding(t *testing.T, b ContextBinding, name, path, badPath, body, badBody string) {
assert.Equal(t, name, b.Name())

obj := FooStruct{}
Expand All @@ -1190,7 +1238,16 @@ func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody

obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)

obj2 := ConditionalFooStruct{}
req = requestWithBody("POST", path, body)
err = b.BindContext(context.Background(), req, &obj2)
assert.NoError(t, err)
assert.Equal(t, "bar", obj2.Foo)

err = b.BindContext(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint
assert.Error(t, err)
}

Expand All @@ -1204,7 +1261,7 @@ func testBodyBindingSlice(t *testing.T, b Binding, name, path, badPath, body, ba

var obj2 []FooStruct
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj2)
err = b.Bind(req, &obj2)
assert.Error(t, err)
}

Expand Down Expand Up @@ -1249,7 +1306,7 @@ func testBodyBindingUseNumber(t *testing.T, b Binding, name, path, badPath, body

obj = FooStructUseNumber{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand All @@ -1267,7 +1324,7 @@ func testBodyBindingUseNumber2(t *testing.T, b Binding, name, path, badPath, bod

obj = FooStructUseNumber{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand All @@ -1285,7 +1342,7 @@ func testBodyBindingDisallowUnknownFields(t *testing.T, b Binding, path, badPath

obj = FooStructDisallowUnknownFields{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
assert.Contains(t, err.Error(), "what")
}
Expand All @@ -1301,7 +1358,7 @@ func testBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, bad

obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand All @@ -1318,7 +1375,7 @@ func testProtoBodyBinding(t *testing.T, b Binding, name, path, badPath, body, ba
obj = protoexample.Test{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEPROTOBUF)
err = ProtoBuf.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand Down Expand Up @@ -1349,7 +1406,7 @@ func testProtoBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body
obj = protoexample.Test{}
req = requestWithBody("POST", badPath, badBody)
req.Header.Add("Content-Type", MIMEPROTOBUF)
err = ProtoBuf.Bind(req, &obj)
err = b.Bind(req, &obj)
assert.Error(t, err)
}

Expand Down

0 comments on commit ca15e39

Please sign in to comment.