From 14b4debd27e0f8bc81511c1272ce5d84e86adfec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 22 Sep 2021 11:17:30 +0200 Subject: [PATCH] Add support for contextual validation By passing the Gin context to bindings, custom validators can take advantage of the information in the context. --- binding/binding.go | 45 +++++++++++++++++++++++++++++---- binding/binding_msgpack_test.go | 15 +++++++++-- binding/binding_test.go | 43 ++++++++++++++++++++++++------- binding/form.go | 26 ++++++++++++++----- binding/header.go | 9 ++++--- binding/json.go | 21 ++++++++++----- binding/msgpack.go | 21 ++++++++++----- binding/query.go | 13 +++++++--- binding/uri.go | 10 ++++++-- binding/validate_test.go | 5 ++++ binding/xml.go | 22 +++++++++++----- binding/yaml.go | 21 ++++++++++----- context.go | 2 +- 13 files changed, 197 insertions(+), 56 deletions(-) diff --git a/binding/binding.go b/binding/binding.go index 7042101d5c..81fea59651 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -7,7 +7,10 @@ package binding -import "net/http" +import ( + "context" + "net/http" +) // Content-Type MIME of the most common data formats. const ( @@ -32,20 +35,41 @@ type Binding interface { Bind(*http.Request, interface{}) error } -// BindingBody adds BindBody method to Binding. BindBody is similar with Bind, +// CtxBinding enables contextual validation by adding CtxBind to Binding. +// Custom validators can take advantage of the information in the context. +type CtxBinding interface { + Binding + CtxBind(context.Context, *http.Request, interface{}) error +} + +// BindingBody adds BindBody method to Binding. BindBody is similar to Bind, // but it reads the body from supplied bytes instead of req.Body. type BindingBody interface { Binding BindBody([]byte, interface{}) error } -// BindingUri adds BindUri method to Binding. BindUri is similar with Bind, -// but it read the Params. +// CtxBindingBody enables contextual validation by adding CtxBindBody to BindingBody. +// Custom validators can take advantage of the information in the context. +type CtxBindingBody interface { + BindingBody + CtxBind(context.Context, *http.Request, interface{}) error + CtxBindBody(context.Context, []byte, interface{}) error +} + +// BindingUri is similar to Bind, but it read the Params. type BindingUri interface { Name() string BindUri(map[string][]string, interface{}) error } +// CtxBindingUri enables contextual validation by adding CtxBindUri to BindingUri. +// Custom validators can take advantage of the information in the context. +type CtxBindingUri interface { + BindingUri + CtxBindUri(context.Context, map[string][]string, interface{}) error +} + // StructValidator is the minimal interface which needs to be implemented in // order for it to be used as the validator engine for ensuring the correctness // of the request. Gin provides a default implementation for this using @@ -64,6 +88,14 @@ type StructValidator interface { Engine() interface{} } +// CtxStructValidator is an extension of StructValidator that requires implementing +// context-aware validation. +// Custom validators can take advantage of the information in the context. +type CtxStructValidator interface { + StructValidator + ValidateStructCtx(context.Context, interface{}) error +} + // Validator is the default validator which implements the StructValidator // interface. It uses https://github.com/go-playground/validator/tree/v10.6.1 // under the hood. @@ -110,9 +142,12 @@ func Default(method, contentType string) Binding { } } -func validate(obj interface{}) error { +func validateCtx(ctx context.Context, obj interface{}) error { if Validator == nil { return nil } + if v, ok := Validator.(CtxStructValidator); ok { + return v.ValidateStructCtx(ctx, obj) + } return Validator.ValidateStruct(obj) } diff --git a/binding/binding_msgpack_test.go b/binding/binding_msgpack_test.go index 04d9407971..ea5958973f 100644 --- a/binding/binding_msgpack_test.go +++ b/binding/binding_msgpack_test.go @@ -9,6 +9,7 @@ package binding import ( "bytes" + "context" "testing" "github.com/stretchr/testify/assert" @@ -35,7 +36,7 @@ func TestBindingMsgPack(t *testing.T) { string(data), string(data[1:])) } -func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { +func testMsgPackBodyBinding(t *testing.T, b CtxBinding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := FooStruct{} @@ -48,7 +49,17 @@ func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, obj = FooStruct{} req = requestWithBody("POST", badPath, badBody) req.Header.Add("Content-Type", MIMEMSGPACK) - err = MsgPack.Bind(req, &obj) + err = b.Bind(req, &obj) + assert.Error(t, err) + + obj2 := ConditionalFooStruct{} + req = requestWithBody("POST", path, body) + req.Header.Add("Content-Type", MIMEMSGPACK) + err = b.CtxBind(context.Background(), req, &obj2) + assert.NoError(t, err) + assert.Equal(t, "bar", obj2.Foo) + + err = b.CtxBind(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint assert.Error(t, err) } diff --git a/binding/binding_test.go b/binding/binding_test.go index 5b0ce39d3e..74b6ce4b24 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "encoding/json" "errors" "io" @@ -20,6 +21,7 @@ import ( "time" "github.com/gin-gonic/gin/testdata/protoexample" + "github.com/go-playground/validator/v10" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" ) @@ -38,6 +40,10 @@ type FooStruct struct { Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required,max=32"` } +type ConditionalFooStruct struct { + Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required_if_condition,max=32"` +} + type FooBarStruct struct { FooStruct Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" binding:"required"` @@ -144,6 +150,16 @@ type FooStructForMapPtrType struct { PtrBar *map[string]interface{} `form:"ptr_bar"` } +func init() { + _ = Validator.Engine().(*validator.Validate).RegisterValidationCtx( + "required_if_condition", func(ctx context.Context, fl validator.FieldLevel) bool { + if ctx.Value("condition") == true { + return !fl.Field().IsZero() + } + return true + }) +} + func TestBindingDefault(t *testing.T) { assert.Equal(t, Form, Default("GET", "")) assert.Equal(t, Form, Default("GET", MIMEJSON)) @@ -1179,7 +1195,7 @@ func testQueryBindingBoolFail(t *testing.T, method, path, badPath, body, badBody assert.Error(t, err) } -func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { +func testBodyBinding(t *testing.T, b CtxBinding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := FooStruct{} @@ -1190,7 +1206,16 @@ func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody obj = FooStruct{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) + assert.Error(t, err) + + obj2 := ConditionalFooStruct{} + req = requestWithBody("POST", path, body) + err = b.CtxBind(context.Background(), req, &obj2) + assert.NoError(t, err) + assert.Equal(t, "bar", obj2.Foo) + + err = b.CtxBind(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint assert.Error(t, err) } @@ -1204,7 +1229,7 @@ func testBodyBindingSlice(t *testing.T, b Binding, name, path, badPath, body, ba var obj2 []FooStruct req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj2) + err = b.Bind(req, &obj2) assert.Error(t, err) } @@ -1249,7 +1274,7 @@ func testBodyBindingUseNumber(t *testing.T, b Binding, name, path, badPath, body obj = FooStructUseNumber{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1267,7 +1292,7 @@ func testBodyBindingUseNumber2(t *testing.T, b Binding, name, path, badPath, bod obj = FooStructUseNumber{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1285,7 +1310,7 @@ func testBodyBindingDisallowUnknownFields(t *testing.T, b Binding, path, badPath obj = FooStructDisallowUnknownFields{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) assert.Contains(t, err.Error(), "what") } @@ -1301,7 +1326,7 @@ func testBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, bad obj = FooStruct{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1318,7 +1343,7 @@ func testProtoBodyBinding(t *testing.T, b Binding, name, path, badPath, body, ba obj = protoexample.Test{} req = requestWithBody("POST", badPath, badBody) req.Header.Add("Content-Type", MIMEPROTOBUF) - err = ProtoBuf.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1349,7 +1374,7 @@ func testProtoBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body obj = protoexample.Test{} req = requestWithBody("POST", badPath, badBody) req.Header.Add("Content-Type", MIMEPROTOBUF) - err = ProtoBuf.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } diff --git a/binding/form.go b/binding/form.go index fa2a6540a0..e1cb60af1e 100644 --- a/binding/form.go +++ b/binding/form.go @@ -5,6 +5,7 @@ package binding import ( + "context" "errors" "net/http" ) @@ -19,7 +20,11 @@ func (formBinding) Name() string { return "form" } -func (formBinding) Bind(req *http.Request, obj interface{}) error { +func (b formBinding) Bind(req *http.Request, obj interface{}) error { + return b.CtxBind(context.Background(), req, obj) +} + +func (formBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error { if err := req.ParseForm(); err != nil { return err } @@ -29,34 +34,41 @@ func (formBinding) Bind(req *http.Request, obj interface{}) error { if err := mapForm(obj, req.Form); err != nil { return err } - return validate(obj) + return validateCtx(ctx, obj) } func (formPostBinding) Name() string { return "form-urlencoded" } -func (formPostBinding) Bind(req *http.Request, obj interface{}) error { +func (b formPostBinding) Bind(req *http.Request, obj interface{}) error { + return b.CtxBind(context.Background(), req, obj) +} + +func (formPostBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error { if err := req.ParseForm(); err != nil { return err } if err := mapForm(obj, req.PostForm); err != nil { return err } - return validate(obj) + return validateCtx(ctx, obj) } func (formMultipartBinding) Name() string { return "multipart/form-data" } -func (formMultipartBinding) Bind(req *http.Request, obj interface{}) error { +func (b formMultipartBinding) Bind(req *http.Request, obj interface{}) error { + return b.CtxBind(context.Background(), req, obj) +} + +func (formMultipartBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error { if err := req.ParseMultipartForm(defaultMemory); err != nil { return err } if err := mappingByPtr(obj, (*multipartRequest)(req), "form"); err != nil { return err } - - return validate(obj) + return validateCtx(ctx, obj) } diff --git a/binding/header.go b/binding/header.go index b99302af82..1270faef6b 100644 --- a/binding/header.go +++ b/binding/header.go @@ -1,6 +1,7 @@ package binding import ( + "context" "net/http" "net/textproto" "reflect" @@ -12,13 +13,15 @@ func (headerBinding) Name() string { return "header" } -func (headerBinding) Bind(req *http.Request, obj interface{}) error { +func (b headerBinding) Bind(req *http.Request, obj interface{}) error { + return b.CtxBind(context.Background(), req, obj) +} +func (headerBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error { if err := mapHeader(obj, req.Header); err != nil { return err } - - return validate(obj) + return validateCtx(ctx, obj) } func mapHeader(ptr interface{}, h map[string][]string) error { diff --git a/binding/json.go b/binding/json.go index 45aaa49487..d3e96be7dc 100644 --- a/binding/json.go +++ b/binding/json.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "errors" "io" "net/http" @@ -30,18 +31,26 @@ func (jsonBinding) Name() string { return "json" } -func (jsonBinding) Bind(req *http.Request, obj interface{}) error { +func (b jsonBinding) Bind(req *http.Request, obj interface{}) error { + return b.CtxBind(context.Background(), req, obj) +} + +func (jsonBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error { if req == nil || req.Body == nil { return errors.New("invalid request") } - return decodeJSON(req.Body, obj) + return decodeJSON(ctx, req.Body, obj) +} + +func (b jsonBinding) BindBody(body []byte, obj interface{}) error { + return b.CtxBindBody(context.Background(), body, obj) } -func (jsonBinding) BindBody(body []byte, obj interface{}) error { - return decodeJSON(bytes.NewReader(body), obj) +func (jsonBinding) CtxBindBody(ctx context.Context, body []byte, obj interface{}) error { + return decodeJSON(ctx, bytes.NewReader(body), obj) } -func decodeJSON(r io.Reader, obj interface{}) error { +func decodeJSON(ctx context.Context, r io.Reader, obj interface{}) error { decoder := json.NewDecoder(r) if EnableDecoderUseNumber { decoder.UseNumber() @@ -52,5 +61,5 @@ func decodeJSON(r io.Reader, obj interface{}) error { if err := decoder.Decode(obj); err != nil { return err } - return validate(obj) + return validateCtx(ctx, obj) } diff --git a/binding/msgpack.go b/binding/msgpack.go index 2a442996a6..ce1682cc4a 100644 --- a/binding/msgpack.go +++ b/binding/msgpack.go @@ -9,6 +9,7 @@ package binding import ( "bytes" + "context" "io" "net/http" @@ -21,18 +22,26 @@ func (msgpackBinding) Name() string { return "msgpack" } -func (msgpackBinding) Bind(req *http.Request, obj interface{}) error { - return decodeMsgPack(req.Body, obj) +func (b msgpackBinding) Bind(req *http.Request, obj interface{}) error { + return b.CtxBind(context.Background(), req, obj) } -func (msgpackBinding) BindBody(body []byte, obj interface{}) error { - return decodeMsgPack(bytes.NewReader(body), obj) +func (msgpackBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error { + return decodeMsgPack(ctx, req.Body, obj) } -func decodeMsgPack(r io.Reader, obj interface{}) error { +func (b msgpackBinding) BindBody(body []byte, obj interface{}) error { + return b.CtxBindBody(context.Background(), body, obj) +} + +func (msgpackBinding) CtxBindBody(ctx context.Context, body []byte, obj interface{}) error { + return decodeMsgPack(ctx, bytes.NewReader(body), obj) +} + +func decodeMsgPack(ctx context.Context, r io.Reader, obj interface{}) error { cdc := new(codec.MsgpackHandle) if err := codec.NewDecoder(r, cdc).Decode(&obj); err != nil { return err } - return validate(obj) + return validateCtx(ctx, obj) } diff --git a/binding/query.go b/binding/query.go index 219743f2a9..d53f8609f7 100644 --- a/binding/query.go +++ b/binding/query.go @@ -4,7 +4,10 @@ package binding -import "net/http" +import ( + "context" + "net/http" +) type queryBinding struct{} @@ -12,10 +15,14 @@ func (queryBinding) Name() string { return "query" } -func (queryBinding) Bind(req *http.Request, obj interface{}) error { +func (b queryBinding) Bind(req *http.Request, obj interface{}) error { + return b.CtxBind(context.Background(), req, obj) +} + +func (queryBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error { values := req.URL.Query() if err := mapForm(obj, values); err != nil { return err } - return validate(obj) + return validateCtx(ctx, obj) } diff --git a/binding/uri.go b/binding/uri.go index a3c0df515c..ad07f1faf9 100644 --- a/binding/uri.go +++ b/binding/uri.go @@ -4,15 +4,21 @@ package binding +import "context" + type uriBinding struct{} func (uriBinding) Name() string { return "uri" } -func (uriBinding) BindUri(m map[string][]string, obj interface{}) error { +func (b uriBinding) BindUri(m map[string][]string, obj interface{}) error { + return b.CtxBindUri(context.Background(), m, obj) +} + +func (uriBinding) CtxBindUri(ctx context.Context, m map[string][]string, obj interface{}) error { if err := mapURI(obj, m); err != nil { return err } - return validate(obj) + return validateCtx(ctx, obj) } diff --git a/binding/validate_test.go b/binding/validate_test.go index 5299fbf602..6d3cfdf1c1 100644 --- a/binding/validate_test.go +++ b/binding/validate_test.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "testing" "time" @@ -226,3 +227,7 @@ func TestValidatorEngine(t *testing.T) { // Check that the error matches expectation assert.Error(t, errs, "", "", "notone") } + +func validate(obj interface{}) error { + return validateCtx(context.Background(), obj) +} diff --git a/binding/xml.go b/binding/xml.go index 4e90114962..80276d85ec 100644 --- a/binding/xml.go +++ b/binding/xml.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "encoding/xml" "io" "net/http" @@ -17,17 +18,26 @@ func (xmlBinding) Name() string { return "xml" } -func (xmlBinding) Bind(req *http.Request, obj interface{}) error { - return decodeXML(req.Body, obj) +func (b xmlBinding) Bind(req *http.Request, obj interface{}) error { + return b.CtxBind(context.Background(), req, obj) } -func (xmlBinding) BindBody(body []byte, obj interface{}) error { - return decodeXML(bytes.NewReader(body), obj) +func (xmlBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error { + return decodeXML(ctx, req.Body, obj) } -func decodeXML(r io.Reader, obj interface{}) error { + +func (b xmlBinding) BindBody(body []byte, obj interface{}) error { + return b.CtxBindBody(context.Background(), body, obj) +} + +func (xmlBinding) CtxBindBody(ctx context.Context, body []byte, obj interface{}) error { + return decodeXML(ctx, bytes.NewReader(body), obj) +} + +func decodeXML(ctx context.Context, r io.Reader, obj interface{}) error { decoder := xml.NewDecoder(r) if err := decoder.Decode(obj); err != nil { return err } - return validate(obj) + return validateCtx(ctx, obj) } diff --git a/binding/yaml.go b/binding/yaml.go index a2d36d6a54..a17d9d98a5 100644 --- a/binding/yaml.go +++ b/binding/yaml.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "io" "net/http" @@ -18,18 +19,26 @@ func (yamlBinding) Name() string { return "yaml" } -func (yamlBinding) Bind(req *http.Request, obj interface{}) error { - return decodeYAML(req.Body, obj) +func (b yamlBinding) Bind(req *http.Request, obj interface{}) error { + return b.CtxBind(context.Background(), req, obj) } -func (yamlBinding) BindBody(body []byte, obj interface{}) error { - return decodeYAML(bytes.NewReader(body), obj) +func (yamlBinding) CtxBind(ctx context.Context, req *http.Request, obj interface{}) error { + return decodeYAML(ctx, req.Body, obj) } -func decodeYAML(r io.Reader, obj interface{}) error { +func (b yamlBinding) BindBody(body []byte, obj interface{}) error { + return b.CtxBindBody(context.Background(), body, obj) +} + +func (yamlBinding) CtxBindBody(ctx context.Context, body []byte, obj interface{}) error { + return decodeYAML(ctx, bytes.NewReader(body), obj) +} + +func decodeYAML(ctx context.Context, r io.Reader, obj interface{}) error { decoder := yaml.NewDecoder(r) if err := decoder.Decode(obj); err != nil { return err } - return validate(obj) + return validateCtx(ctx, obj) } diff --git a/context.go b/context.go index 8a2f46d1cb..5959b36bfc 100644 --- a/context.go +++ b/context.go @@ -702,7 +702,7 @@ func (c *Context) ShouldBindUri(obj interface{}) error { for _, v := range c.Params { m[v.Key] = []string{v.Value} } - return binding.Uri.BindUri(m, obj) + return binding.Uri.CtxBindUri(c, m, obj) } // ShouldBindWith binds the passed struct pointer using the specified binding engine.