From cb9b68b1b576f0231b5383841c40d32d5d4ee8c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Tue, 21 Sep 2021 21:54:03 +0200 Subject: [PATCH 1/4] Always return ValidationErrors, expose failing index This way callers can always expect ValidationErrors, and in case of slice validation, they can also get indexes of the failing elements. --- binding/default_validator.go | 58 ++++++++++----------- binding/default_validator_benchmark_test.go | 20 ------- binding/default_validator_test.go | 54 +++++++------------ 3 files changed, 48 insertions(+), 84 deletions(-) delete mode 100644 binding/default_validator_benchmark_test.go diff --git a/binding/default_validator.go b/binding/default_validator.go index 87fc4c665b..1622530210 100644 --- a/binding/default_validator.go +++ b/binding/default_validator.go @@ -7,7 +7,6 @@ package binding import ( "fmt" "reflect" - "strings" "sync" "github.com/go-playground/validator/v10" @@ -18,34 +17,33 @@ type defaultValidator struct { validate *validator.Validate } -type sliceValidateError []error +// SliceFieldError is returned for invalid slice or array elements. +// It extends validator.FieldError with the index of the failing element. +type SliceFieldError interface { + validator.FieldError + Index() int +} -// Error concatenates all error elements in sliceValidateError into a single string separated by \n. -func (err sliceValidateError) Error() string { - n := len(err) - switch n { - case 0: - return "" - default: - var b strings.Builder - if err[0] != nil { - fmt.Fprintf(&b, "[%d]: %s", 0, err[0].Error()) - } - if n > 1 { - for i := 1; i < n; i++ { - if err[i] != nil { - b.WriteString("\n") - fmt.Fprintf(&b, "[%d]: %s", i, err[i].Error()) - } - } - } - return b.String() - } +type sliceFieldError struct { + validator.FieldError + index int +} + +func (fe sliceFieldError) Index() int { + return fe.index +} + +func (fe sliceFieldError) Error() string { + return fmt.Sprintf("[%d]: %s", fe.index, fe.FieldError.Error()) +} + +func (fe sliceFieldError) Unwrap() error { + return fe.FieldError } var _ StructValidator = &defaultValidator{} -// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. +// ValidateStruct receives any kind of type, but validates only structs, pointers, slices, and arrays. func (v *defaultValidator) ValidateStruct(obj interface{}) error { if obj == nil { return nil @@ -59,16 +57,18 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { return v.validateStruct(obj) case reflect.Slice, reflect.Array: count := value.Len() - validateRet := make(sliceValidateError, 0) + var errs validator.ValidationErrors for i := 0; i < count; i++ { if err := v.ValidateStruct(value.Index(i).Interface()); err != nil { - validateRet = append(validateRet, err) + for _, fieldError := range err.(validator.ValidationErrors) { // nolint: errorlint + errs = append(errs, sliceFieldError{fieldError, i}) + } } } - if len(validateRet) == 0 { - return nil + if len(errs) > 0 { + return errs } - return validateRet + return nil default: return nil } diff --git a/binding/default_validator_benchmark_test.go b/binding/default_validator_benchmark_test.go deleted file mode 100644 index 839cf710b5..0000000000 --- a/binding/default_validator_benchmark_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package binding - -import ( - "errors" - "strconv" - "testing" -) - -func BenchmarkSliceValidateError(b *testing.B) { - const size int = 100 - for i := 0; i < b.N; i++ { - e := make(sliceValidateError, size) - for j := 0; j < size; j++ { - e[j] = errors.New(strconv.Itoa(j)) - } - if len(e.Error()) == 0 { - b.Errorf("error") - } - } -} diff --git a/binding/default_validator_test.go b/binding/default_validator_test.go index e9debe59bf..fec34ef672 100644 --- a/binding/default_validator_test.go +++ b/binding/default_validator_test.go @@ -7,43 +7,27 @@ package binding import ( "errors" "testing" + + "github.com/go-playground/validator/v10" + "github.com/stretchr/testify/assert" ) -func TestSliceValidateError(t *testing.T) { - tests := []struct { - name string - err sliceValidateError - want string - }{ - {"has nil elements", sliceValidateError{errors.New("test error"), nil}, "[0]: test error"}, - {"has zero elements", sliceValidateError{}, ""}, - {"has one element", sliceValidateError{errors.New("test one error")}, "[0]: test one error"}, - {"has two elements", - sliceValidateError{ - errors.New("first error"), - errors.New("second error"), - }, - "[0]: first error\n[1]: second error", - }, - {"has many elements", - sliceValidateError{ - errors.New("first error"), - errors.New("second error"), - nil, - nil, - nil, - errors.New("last error"), - }, - "[0]: first error\n[1]: second error\n[5]: last error", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.err.Error(); got != tt.want { - t.Errorf("sliceValidateError.Error() = %v, want %v", got, tt.want) - } - }) - } +func TestSliceFieldError(t *testing.T) { + var fe validator.FieldError = dummyFieldError{msg: "test error"} + + var err SliceFieldError = sliceFieldError{fe, 10} + assert.Equal(t, 10, err.Index()) + assert.Equal(t, "[10]: test error", err.Error()) + assert.Equal(t, fe, errors.Unwrap(err)) +} + +type dummyFieldError struct { + validator.FieldError + msg string +} + +func (fe dummyFieldError) Error() string { + return fe.msg } func TestDefaultValidator(t *testing.T) { From 04ccf172a6f6401c09a8912a6ca6e6cd061205d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Fri, 5 Nov 2021 17:48:10 +0100 Subject: [PATCH 2/4] Add support for validating map types --- binding/binding.go | 6 ++--- binding/binding_nomsgpack.go | 5 ++-- binding/default_validator.go | 41 +++++++++++++++++++++++++++++-- binding/default_validator_test.go | 26 ++++++++++++++++++++ 4 files changed, 71 insertions(+), 7 deletions(-) diff --git a/binding/binding.go b/binding/binding.go index deb71661b8..7042101d5c 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -52,9 +52,9 @@ type BindingUri interface { // https://github.com/go-playground/validator/tree/v10.6.1. type StructValidator interface { // ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right. - // If the received type is a slice|array, the validation should be performed travel on every element. - // If the received type is not a struct or slice|array, any validation should be skipped and nil must be returned. - // If the received type is a struct or pointer to a struct, the validation should be performed. + // If the received type is a slice/array/map, the validation should be performed on every element. + // If the received type is not a struct or slice/array/map, any validation should be skipped and nil must be returned. + // If the received type is a pointer to a struct/slice/array/map, the validation should be performed. // If the struct is not valid or the validation itself fails, a descriptive error should be returned. // Otherwise nil must be returned. ValidateStruct(interface{}) error diff --git a/binding/binding_nomsgpack.go b/binding/binding_nomsgpack.go index 2342447070..00d6303619 100644 --- a/binding/binding_nomsgpack.go +++ b/binding/binding_nomsgpack.go @@ -50,8 +50,9 @@ type BindingUri interface { // https://github.com/go-playground/validator/tree/v10.6.1. type StructValidator interface { // ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right. - // If the received type is not a struct, any validation should be skipped and nil must be returned. - // If the received type is a struct or pointer to a struct, the validation should be performed. + // If the received type is a slice/array/map, the validation should be performed on every element. + // If the received type is not a struct or slice/array/map, any validation should be skipped and nil must be returned. + // If the received type is a pointer to a struct/slice/array/map, the validation should be performed. // If the struct is not valid or the validation itself fails, a descriptive error should be returned. // Otherwise nil must be returned. ValidateStruct(interface{}) error diff --git a/binding/default_validator.go b/binding/default_validator.go index 1622530210..10e9bb122b 100644 --- a/binding/default_validator.go +++ b/binding/default_validator.go @@ -41,9 +41,33 @@ func (fe sliceFieldError) Unwrap() error { return fe.FieldError } +// MapFieldError is returned for invalid map values. +// It extends validator.FieldError with the key of the failing value. +type MapFieldError interface { + validator.FieldError + Key() interface{} +} + +type mapFieldError struct { + validator.FieldError + key interface{} +} + +func (fe mapFieldError) Key() interface{} { + return fe.key +} + +func (fe mapFieldError) Error() string { + return fmt.Sprintf("[%v]: %s", fe.key, fe.FieldError.Error()) +} + +func (fe mapFieldError) Unwrap() error { + return fe.FieldError +} + var _ StructValidator = &defaultValidator{} -// ValidateStruct receives any kind of type, but validates only structs, pointers, slices, and arrays. +// ValidateStruct receives any kind of type, but validates only structs, pointers, slices, arrays, and maps. func (v *defaultValidator) ValidateStruct(obj interface{}) error { if obj == nil { return nil @@ -56,8 +80,8 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { case reflect.Struct: return v.validateStruct(obj) case reflect.Slice, reflect.Array: - count := value.Len() var errs validator.ValidationErrors + count := value.Len() for i := 0; i < count; i++ { if err := v.ValidateStruct(value.Index(i).Interface()); err != nil { for _, fieldError := range err.(validator.ValidationErrors) { // nolint: errorlint @@ -69,6 +93,19 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { return errs } return nil + case reflect.Map: + var errs validator.ValidationErrors + for _, key := range value.MapKeys() { + if err := v.ValidateStruct(value.MapIndex(key).Interface()); err != nil { + for _, fieldError := range err.(validator.ValidationErrors) { // nolint: errorlint + errs = append(errs, mapFieldError{fieldError, key.Interface()}) + } + } + } + if len(errs) > 0 { + return errs + } + return nil default: return nil } diff --git a/binding/default_validator_test.go b/binding/default_validator_test.go index fec34ef672..51139fe140 100644 --- a/binding/default_validator_test.go +++ b/binding/default_validator_test.go @@ -21,6 +21,20 @@ func TestSliceFieldError(t *testing.T) { assert.Equal(t, fe, errors.Unwrap(err)) } +func TestMapFieldError(t *testing.T) { + var fe validator.FieldError = dummyFieldError{msg: "test error"} + + var err MapFieldError = mapFieldError{fe, "test key"} + assert.Equal(t, "test key", err.Key()) + assert.Equal(t, "[test key]: test error", err.Error()) + assert.Equal(t, fe, errors.Unwrap(err)) + + err = mapFieldError{fe, 123} + assert.Equal(t, 123, err.Key()) + assert.Equal(t, "[123]: test error", err.Error()) + assert.Equal(t, fe, errors.Unwrap(err)) +} + type dummyFieldError struct { validator.FieldError msg string @@ -61,6 +75,18 @@ func TestDefaultValidator(t *testing.T) { {"validate *[]*struct failed-1", &defaultValidator{}, &[]*exampleStruct{{A: "123456789", B: 1}}, true}, {"validate *[]*struct failed-2", &defaultValidator{}, &[]*exampleStruct{{A: "12345678", B: 0}}, true}, {"validate *[]*struct passed", &defaultValidator{}, &[]*exampleStruct{{A: "12345678", B: 1}}, false}, + {"validate map[string]struct failed-1", &defaultValidator{}, map[string]exampleStruct{"x": {A: "123456789", B: 1}}, true}, + {"validate map[string]struct failed-2", &defaultValidator{}, map[string]exampleStruct{"x": {A: "12345678", B: 0}}, true}, + {"validate map[string]struct passed", &defaultValidator{}, map[string]exampleStruct{"x": {A: "12345678", B: 1}}, false}, + {"validate map[string]*struct failed-1", &defaultValidator{}, map[string]*exampleStruct{"x": {A: "123456789", B: 1}}, true}, + {"validate map[string]*struct failed-2", &defaultValidator{}, map[string]*exampleStruct{"x": {A: "12345678", B: 0}}, true}, + {"validate map[string]*struct passed", &defaultValidator{}, map[string]*exampleStruct{"x": {A: "12345678", B: 1}}, false}, + {"validate *map[string]struct failed-1", &defaultValidator{}, &map[string]exampleStruct{"x": {A: "123456789", B: 1}}, true}, + {"validate *map[string]struct failed-2", &defaultValidator{}, &map[string]exampleStruct{"x": {A: "12345678", B: 0}}, true}, + {"validate *map[string]struct passed", &defaultValidator{}, &map[string]exampleStruct{"x": {A: "12345678", B: 1}}, false}, + {"validate *map[string]*struct failed-1", &defaultValidator{}, &map[string]*exampleStruct{"x": {A: "123456789", B: 1}}, true}, + {"validate *map[string]*struct failed-2", &defaultValidator{}, &map[string]*exampleStruct{"x": {A: "12345678", B: 0}}, true}, + {"validate *map[string]*struct passed", &defaultValidator{}, &map[string]*exampleStruct{"x": {A: "12345678", B: 1}}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From f26790bbda2f510e9ca4b149f3c0dbe79eab4ebd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Fri, 5 Nov 2021 18:28:36 +0100 Subject: [PATCH 3/4] Implement registering validator tags for custom map and slice types --- README.md | 41 +++++++++++ binding/default_validator.go | 48 +++++++++++++ binding/default_validator_test.go | 113 ++++++++++++++++++++++++++++++ 3 files changed, 202 insertions(+) diff --git a/README.md b/README.md index cad746d62c..3c4fc60625 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ Gin is a web framework written in Go (Golang). It features a martini-like API wi - [Controlling Log output coloring](#controlling-log-output-coloring) - [Model binding and validation](#model-binding-and-validation) - [Custom Validators](#custom-validators) + - [Custom Map and Slice Validator Tags](#custom-map-and-slice-validator-tags) - [Only Bind Query String](#only-bind-query-string) - [Bind Query String or Post Data](#bind-query-string-or-post-data) - [Bind Uri](#bind-uri) @@ -838,6 +839,46 @@ $ curl "localhost:8085/bookable?check_in=2000-03-09&check_out=2000-03-10" [Struct level validations](https://github.com/go-playground/validator/releases/tag/v8.7) can also be registered this way. See the [struct-lvl-validation example](https://github.com/gin-gonic/examples/tree/master/struct-lvl-validations) to learn more. +### Custom Map and Slice Validator Tags + +It is possible to register validator tags for custom map and slice types. + +```go +package main + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" +) + +type Person struct { + FirstName string `json:"firstName" binding:"required,lte=64"` + LastName string `json:"lastName" binding:"required,lte=64"` +} + +type Managers map[string]Person + +func main() { + route := gin.Default() + + binding.RegisterValidatorTag("dive,keys,oneof=accounting finance operations,endkeys", Managers{}) + + route.POST("/managers", configureManagers) + route.Run(":8085") +} + +func configureManagers(c *gin.Context) { + var m Managers + if err := c.ShouldBindJSON(&m); err == nil { + c.JSON(http.StatusOK, gin.H{"message": "Manager configuration is valid!"}) + } else { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + } +} +``` + ### Only Bind Query String `ShouldBindQuery` function only binds the query params and not the post data. See the [detail information](https://github.com/gin-gonic/gin/issues/742#issuecomment-315953017). diff --git a/binding/default_validator.go b/binding/default_validator.go index 10e9bb122b..b60e3cf685 100644 --- a/binding/default_validator.go +++ b/binding/default_validator.go @@ -12,6 +12,33 @@ import ( "github.com/go-playground/validator/v10" ) +var validatorTags = make(map[reflect.Type]string) + +// RegisterValidatorTag registers a validator tag against a number of types. +// This allows defining validation for custom slice, array, and map types. For example: +// type CustomMap map[int]string +// ... +// binding.RegisterValidatorTag("gt=0", CustomMap{}) +// +// Do not use the "dive" tag (unless in conjunction with "keys"/"endkeys"). +// Slice/array/map elements are validated independently. +// +// This function will not have any effect is binding.Validator has been replaced. +// +// NOTE: This function is not thread-safe. It is intended that these all be registered prior to any validation. +func RegisterValidatorTag(tag string, types ...interface{}) { + for _, typ := range types { + t := reflect.TypeOf(typ) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Slice && t.Kind() != reflect.Array && t.Kind() != reflect.Map { + panic("validator tags can be registered only for slices, arrays, and maps") + } + validatorTags[t] = tag + } +} + type defaultValidator struct { once sync.Once validate *validator.Validate @@ -81,6 +108,13 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { return v.validateStruct(obj) case reflect.Slice, reflect.Array: var errs validator.ValidationErrors + + if tag, ok := validatorTags[value.Type()]; ok { + if err := v.validateVar(obj, tag); err != nil { + errs = append(errs, err.(validator.ValidationErrors)...) // nolint: errorlint + } + } + count := value.Len() for i := 0; i < count; i++ { if err := v.ValidateStruct(value.Index(i).Interface()); err != nil { @@ -89,12 +123,20 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { } } } + if len(errs) > 0 { return errs } return nil case reflect.Map: var errs validator.ValidationErrors + + if tag, ok := validatorTags[value.Type()]; ok { + if err := v.validateVar(obj, tag); err != nil { + errs = append(errs, err.(validator.ValidationErrors)...) // nolint: errorlint + } + } + for _, key := range value.MapKeys() { if err := v.ValidateStruct(value.MapIndex(key).Interface()); err != nil { for _, fieldError := range err.(validator.ValidationErrors) { // nolint: errorlint @@ -117,6 +159,12 @@ func (v *defaultValidator) validateStruct(obj interface{}) error { return v.validate.Struct(obj) } +// validateStruct receives slice, array, and map types +func (v *defaultValidator) validateVar(obj interface{}, tag string) error { + v.lazyinit() + return v.validate.Var(obj, tag) +} + // Engine returns the underlying validator engine which powers the default // Validator instance. This is useful if you want to register custom validations // or struct level validations. See validator GoDoc for more info - diff --git a/binding/default_validator_test.go b/binding/default_validator_test.go index 51139fe140..cd502d9e11 100644 --- a/binding/default_validator_test.go +++ b/binding/default_validator_test.go @@ -96,3 +96,116 @@ func TestDefaultValidator(t *testing.T) { }) } } + +func TestRegisterValidatorTag(t *testing.T) { + type CustomSlice []struct { + A string + } + type CustomArray [10]struct { + A string + } + type CustomMap map[string]struct { + A string + } + type CustomStruct struct { + A string + } + type CustomInt int + + // only slice, array, and map types are accepted + RegisterValidatorTag("gt=0", CustomSlice{}) + RegisterValidatorTag("gt=0", &CustomSlice{}) + RegisterValidatorTag("gt=0", CustomArray{}) + RegisterValidatorTag("gt=0", &CustomArray{}) + RegisterValidatorTag("gt=0", CustomMap{}) + RegisterValidatorTag("gt=0", &CustomMap{}) + assert.Panics(t, func() { RegisterValidatorTag("gt=0", CustomStruct{}) }) + assert.Panics(t, func() { RegisterValidatorTag("gt=0", &CustomStruct{}) }) + assert.Panics(t, func() { var i CustomInt; RegisterValidatorTag("gt=0", i) }) + assert.Panics(t, func() { var i CustomInt; RegisterValidatorTag("gt=0", &i) }) +} + +func TestValidatorTagsSlice(t *testing.T) { + type CustomSlice []struct { + A string `binding:"max=8"` + } + + var ( + invalidSlice = CustomSlice{{"12345678"}} + invalidVal = CustomSlice{{"123456789"}, {"abcdefgh"}} + validSlice = CustomSlice{{"12345678"}, {"abcdefgh"}} + invalidSliceVal = CustomSlice{{"123456789"}} + ) + + v := &defaultValidator{} + + // no tags registered for the slice itself yet, so only elements are validated + assert.NoError(t, v.ValidateStruct(invalidSlice)) + assert.Error(t, v.ValidateStruct(invalidVal)) + assert.NoError(t, v.ValidateStruct(validSlice)) + assert.NoError(t, v.ValidateStruct(&invalidSlice)) + assert.Error(t, v.ValidateStruct(&invalidVal)) + assert.NoError(t, v.ValidateStruct(&validSlice)) + + err := v.ValidateStruct(invalidSliceVal) + assert.Error(t, err) + assert.Len(t, err, 1) // only value error + + RegisterValidatorTag("gt=1", CustomSlice{}) + + assert.Error(t, v.ValidateStruct(invalidSlice)) + assert.Error(t, v.ValidateStruct(invalidVal)) + assert.NoError(t, v.ValidateStruct(validSlice)) + assert.Error(t, v.ValidateStruct(&invalidSlice)) + assert.Error(t, v.ValidateStruct(&invalidVal)) + assert.NoError(t, v.ValidateStruct(&validSlice)) + + err = v.ValidateStruct(invalidSliceVal) + assert.Error(t, err) + assert.Len(t, err, 2) // both slice length and value error +} + +func TestValidatorTagsMap(t *testing.T) { + type CustomMap map[string]struct { + B int `binding:"gt=0"` + } + + var ( + invalidMap = CustomMap{"12345678": {1}} + invalidKey = CustomMap{"123456789": {1}, "abcdefgh": {2}} + invalidVal = CustomMap{"12345678": {0}, "abcdefgh": {2}} + invalidMapVal = CustomMap{"12345678": {0}} + validMap = CustomMap{"12345678": {1}, "abcdefgh": {2}} + ) + + v := &defaultValidator{} + + // no tags registered for the map itself yet, so only values are validated + assert.NoError(t, v.ValidateStruct(invalidMap)) + assert.NoError(t, v.ValidateStruct(invalidKey)) + assert.Error(t, v.ValidateStruct(invalidVal)) + assert.NoError(t, v.ValidateStruct(validMap)) + assert.NoError(t, v.ValidateStruct(&invalidMap)) + assert.NoError(t, v.ValidateStruct(&invalidKey)) + assert.Error(t, v.ValidateStruct(&invalidVal)) + assert.NoError(t, v.ValidateStruct(&validMap)) + + err := v.ValidateStruct(invalidMapVal) + assert.Error(t, err) + assert.Len(t, err, 1) // only value error + + RegisterValidatorTag("gt=1,dive,keys,max=8,endkeys", CustomMap{}) + + assert.Error(t, v.ValidateStruct(invalidMap)) + assert.Error(t, v.ValidateStruct(invalidKey)) + assert.Error(t, v.ValidateStruct(invalidVal)) + assert.NoError(t, v.ValidateStruct(validMap)) + assert.Error(t, v.ValidateStruct(&invalidMap)) + assert.Error(t, v.ValidateStruct(&invalidKey)) + assert.Error(t, v.ValidateStruct(&invalidVal)) + assert.NoError(t, v.ValidateStruct(&validMap)) + + err = v.ValidateStruct(invalidMapVal) + assert.Error(t, err) + assert.Len(t, err, 2) // both map size and value errors +} From b626f639064cf3a72e4698c7804ecef2b06b7ec4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Fri, 5 Nov 2021 18:16:37 +0100 Subject: [PATCH 4/4] Add support for contextual validation By passing the Gin context to bindings, custom validators can take advantage of the information in the context. --- binding/binding.go | 45 ++++++++++-- binding/binding_msgpack_test.go | 15 +++- binding/binding_nomsgpack.go | 47 +++++++++++-- binding/binding_test.go | 75 +++++++++++++++++--- binding/default_validator.go | 27 +++++--- binding/form.go | 26 +++++-- binding/header.go | 9 ++- binding/json.go | 21 ++++-- binding/msgpack.go | 21 ++++-- binding/query.go | 13 +++- binding/uri.go | 10 ++- binding/validate_test.go | 5 ++ binding/xml.go | 22 ++++-- binding/yaml.go | 21 ++++-- context.go | 8 ++- context_test.go | 118 ++++++++++++++++++++++++++++++++ 16 files changed, 410 insertions(+), 73 deletions(-) diff --git a/binding/binding.go b/binding/binding.go index 7042101d5c..b400e6c5c0 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -7,7 +7,10 @@ package binding -import "net/http" +import ( + "context" + "net/http" +) // Content-Type MIME of the most common data formats. const ( @@ -32,20 +35,41 @@ type Binding interface { Bind(*http.Request, interface{}) error } -// BindingBody adds BindBody method to Binding. BindBody is similar with Bind, +// ContextBinding enables contextual validation by adding BindContext to Binding. +// Custom validators can take advantage of the information in the context. +type ContextBinding interface { + Binding + BindContext(context.Context, *http.Request, interface{}) error +} + +// BindingBody adds BindBody method to Binding. BindBody is similar to Bind, // but it reads the body from supplied bytes instead of req.Body. type BindingBody interface { Binding BindBody([]byte, interface{}) error } -// BindingUri adds BindUri method to Binding. BindUri is similar with Bind, -// but it read the Params. +// ContextBindingBody enables contextual validation by adding BindBodyContext to BindingBody. +// Custom validators can take advantage of the information in the context. +type ContextBindingBody interface { + BindingBody + BindContext(context.Context, *http.Request, interface{}) error + BindBodyContext(context.Context, []byte, interface{}) error +} + +// BindingUri is similar to Bind, but it read the Params. type BindingUri interface { Name() string BindUri(map[string][]string, interface{}) error } +// ContextBindingUri enables contextual validation by adding BindUriContext to BindingUri. +// Custom validators can take advantage of the information in the context. +type ContextBindingUri interface { + BindingUri + BindUriContext(context.Context, map[string][]string, interface{}) error +} + // StructValidator is the minimal interface which needs to be implemented in // order for it to be used as the validator engine for ensuring the correctness // of the request. Gin provides a default implementation for this using @@ -64,6 +88,14 @@ type StructValidator interface { Engine() interface{} } +// ContextStructValidator is an extension of StructValidator that requires implementing +// context-aware validation. +// Custom validators can take advantage of the information in the context. +type ContextStructValidator interface { + StructValidator + ValidateStructContext(context.Context, interface{}) error +} + // Validator is the default validator which implements the StructValidator // interface. It uses https://github.com/go-playground/validator/tree/v10.6.1 // under the hood. @@ -110,9 +142,12 @@ func Default(method, contentType string) Binding { } } -func validate(obj interface{}) error { +func validateContext(ctx context.Context, obj interface{}) error { if Validator == nil { return nil } + if v, ok := Validator.(ContextStructValidator); ok { + return v.ValidateStructContext(ctx, obj) + } return Validator.ValidateStruct(obj) } diff --git a/binding/binding_msgpack_test.go b/binding/binding_msgpack_test.go index 04d9407971..7bc6d47dc7 100644 --- a/binding/binding_msgpack_test.go +++ b/binding/binding_msgpack_test.go @@ -9,6 +9,7 @@ package binding import ( "bytes" + "context" "testing" "github.com/stretchr/testify/assert" @@ -35,7 +36,7 @@ func TestBindingMsgPack(t *testing.T) { string(data), string(data[1:])) } -func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { +func testMsgPackBodyBinding(t *testing.T, b ContextBinding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := FooStruct{} @@ -48,7 +49,17 @@ func testMsgPackBodyBinding(t *testing.T, b Binding, name, path, badPath, body, obj = FooStruct{} req = requestWithBody("POST", badPath, badBody) req.Header.Add("Content-Type", MIMEMSGPACK) - err = MsgPack.Bind(req, &obj) + err = b.Bind(req, &obj) + assert.Error(t, err) + + obj2 := ConditionalFooStruct{} + req = requestWithBody("POST", path, body) + req.Header.Add("Content-Type", MIMEMSGPACK) + err = b.BindContext(context.Background(), req, &obj2) + assert.NoError(t, err) + assert.Equal(t, "bar", obj2.Foo) + + err = b.BindContext(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint assert.Error(t, err) } diff --git a/binding/binding_nomsgpack.go b/binding/binding_nomsgpack.go index 00d6303619..f9b6f5111b 100644 --- a/binding/binding_nomsgpack.go +++ b/binding/binding_nomsgpack.go @@ -7,7 +7,10 @@ package binding -import "net/http" +import ( + "context" + "net/http" +) // Content-Type MIME of the most common data formats. const ( @@ -30,20 +33,41 @@ type Binding interface { Bind(*http.Request, interface{}) error } -// BindingBody adds BindBody method to Binding. BindBody is similar with Bind, +// ContextBinding enables contextual validation by adding BindContext to Binding. +// Custom validators can take advantage of the information in the context. +type ContextBinding interface { + Binding + BindContext(context.Context, *http.Request, interface{}) error +} + +// BindingBody adds BindBody method to Binding. BindBody is similar to Bind, // but it reads the body from supplied bytes instead of req.Body. type BindingBody interface { Binding BindBody([]byte, interface{}) error } -// BindingUri adds BindUri method to Binding. BindUri is similar with Bind, -// but it read the Params. +// ContextBindingBody enables contextual validation by adding BindBodyContext to BindingBody. +// Custom validators can take advantage of the information in the context. +type ContextBindingBody interface { + BindingBody + BindContext(context.Context, *http.Request, interface{}) error + BindBodyContext(context.Context, []byte, interface{}) error +} + +// BindingUri is similar to Bind, but it read the Params. type BindingUri interface { Name() string BindUri(map[string][]string, interface{}) error } +// ContextBindingUri enables contextual validation by adding BindUriContext to BindingUri. +// Custom validators can take advantage of the information in the context. +type ContextBindingUri interface { + BindingUri + BindUriContext(context.Context, map[string][]string, interface{}) error +} + // StructValidator is the minimal interface which needs to be implemented in // order for it to be used as the validator engine for ensuring the correctness // of the request. Gin provides a default implementation for this using @@ -62,6 +86,14 @@ type StructValidator interface { Engine() interface{} } +// ContextStructValidator is an extension of StructValidator that requires implementing +// context-aware validation. +// Custom validators can take advantage of the information in the context. +type ContextStructValidator interface { + StructValidator + ValidateStructContext(context.Context, interface{}) error +} + // Validator is the default validator which implements the StructValidator // interface. It uses https://github.com/go-playground/validator/tree/v10.6.1 // under the hood. @@ -85,7 +117,7 @@ var ( // Default returns the appropriate Binding instance based on the HTTP method // and the content type. func Default(method, contentType string) Binding { - if method == "GET" { + if method == http.MethodGet { return Form } @@ -105,9 +137,12 @@ func Default(method, contentType string) Binding { } } -func validate(obj interface{}) error { +func validateContext(ctx context.Context, obj interface{}) error { if Validator == nil { return nil } + if v, ok := Validator.(ContextStructValidator); ok { + return v.ValidateStructContext(ctx, obj) + } return Validator.ValidateStruct(obj) } diff --git a/binding/binding_test.go b/binding/binding_test.go index 5b0ce39d3e..c1d449a0c3 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "encoding/json" "errors" "io" @@ -20,6 +21,7 @@ import ( "time" "github.com/gin-gonic/gin/testdata/protoexample" + "github.com/go-playground/validator/v10" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" ) @@ -38,6 +40,10 @@ type FooStruct struct { Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required,max=32"` } +type ConditionalFooStruct struct { + Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required_if_condition,max=32"` +} + type FooBarStruct struct { FooStruct Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" binding:"required"` @@ -144,6 +150,16 @@ type FooStructForMapPtrType struct { PtrBar *map[string]interface{} `form:"ptr_bar"` } +func init() { + _ = Validator.Engine().(*validator.Validate).RegisterValidationCtx( + "required_if_condition", func(ctx context.Context, fl validator.FieldLevel) bool { + if ctx.Value("condition") == true { + return !fl.Field().IsZero() + } + return true + }) +} + func TestBindingDefault(t *testing.T) { assert.Equal(t, Form, Default("GET", "")) assert.Equal(t, Form, Default("GET", MIMEJSON)) @@ -796,6 +812,38 @@ func TestUriBinding(t *testing.T) { assert.Equal(t, map[string]interface{}(nil), not.Name) } +func TestUriBindingWithContext(t *testing.T) { + b := Uri + + type Tag struct { + Name string `uri:"name" binding:"required_if_condition"` + } + + empty := make(map[string][]string) + assert.NoError(t, b.BindUriContext(context.Background(), empty, new(Tag))) + assert.Error(t, b.BindUriContext(context.WithValue(context.Background(), "condition", true), empty, new(Tag))) // nolint +} + +func TestUriBindingWithNotContextValidator(t *testing.T) { + prev := Validator + defer func() { + Validator = prev + }() + Validator = ¬ContextValidator{} + + TestUriBinding(t) +} + +type notContextValidator defaultValidator + +func (v *notContextValidator) ValidateStruct(obj interface{}) error { + return (*defaultValidator)(v).ValidateStruct(obj) +} + +func (v *notContextValidator) Engine() interface{} { + return (*defaultValidator)(v).Engine() +} + func TestUriInnerBinding(t *testing.T) { type Tag struct { Name string `uri:"name"` @@ -1179,7 +1227,7 @@ func testQueryBindingBoolFail(t *testing.T, method, path, badPath, body, badBody assert.Error(t, err) } -func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { +func testBodyBinding(t *testing.T, b ContextBinding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) obj := FooStruct{} @@ -1190,7 +1238,16 @@ func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody obj = FooStruct{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) + assert.Error(t, err) + + obj2 := ConditionalFooStruct{} + req = requestWithBody("POST", path, body) + err = b.BindContext(context.Background(), req, &obj2) + assert.NoError(t, err) + assert.Equal(t, "bar", obj2.Foo) + + err = b.BindContext(context.WithValue(context.Background(), "condition", true), req, &obj2) // nolint assert.Error(t, err) } @@ -1204,7 +1261,7 @@ func testBodyBindingSlice(t *testing.T, b Binding, name, path, badPath, body, ba var obj2 []FooStruct req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj2) + err = b.Bind(req, &obj2) assert.Error(t, err) } @@ -1249,7 +1306,7 @@ func testBodyBindingUseNumber(t *testing.T, b Binding, name, path, badPath, body obj = FooStructUseNumber{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1267,7 +1324,7 @@ func testBodyBindingUseNumber2(t *testing.T, b Binding, name, path, badPath, bod obj = FooStructUseNumber{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1285,7 +1342,7 @@ func testBodyBindingDisallowUnknownFields(t *testing.T, b Binding, path, badPath obj = FooStructDisallowUnknownFields{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) assert.Contains(t, err.Error(), "what") } @@ -1301,7 +1358,7 @@ func testBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, bad obj = FooStruct{} req = requestWithBody("POST", badPath, badBody) - err = JSON.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1318,7 +1375,7 @@ func testProtoBodyBinding(t *testing.T, b Binding, name, path, badPath, body, ba obj = protoexample.Test{} req = requestWithBody("POST", badPath, badBody) req.Header.Add("Content-Type", MIMEPROTOBUF) - err = ProtoBuf.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } @@ -1349,7 +1406,7 @@ func testProtoBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body obj = protoexample.Test{} req = requestWithBody("POST", badPath, badBody) req.Header.Add("Content-Type", MIMEPROTOBUF) - err = ProtoBuf.Bind(req, &obj) + err = b.Bind(req, &obj) assert.Error(t, err) } diff --git a/binding/default_validator.go b/binding/default_validator.go index b60e3cf685..90a458d331 100644 --- a/binding/default_validator.go +++ b/binding/default_validator.go @@ -5,6 +5,7 @@ package binding import ( + "context" "fmt" "reflect" "sync" @@ -92,10 +93,14 @@ func (fe mapFieldError) Unwrap() error { return fe.FieldError } -var _ StructValidator = &defaultValidator{} +var _ ContextStructValidator = &defaultValidator{} // ValidateStruct receives any kind of type, but validates only structs, pointers, slices, arrays, and maps. func (v *defaultValidator) ValidateStruct(obj interface{}) error { + return v.ValidateStructContext(context.Background(), obj) +} + +func (v *defaultValidator) ValidateStructContext(ctx context.Context, obj interface{}) error { if obj == nil { return nil } @@ -103,21 +108,21 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { value := reflect.ValueOf(obj) switch value.Kind() { case reflect.Ptr: - return v.ValidateStruct(value.Elem().Interface()) + return v.ValidateStructContext(ctx, value.Elem().Interface()) case reflect.Struct: - return v.validateStruct(obj) + return v.validateStruct(ctx, obj) case reflect.Slice, reflect.Array: var errs validator.ValidationErrors if tag, ok := validatorTags[value.Type()]; ok { - if err := v.validateVar(obj, tag); err != nil { + if err := v.validateVar(ctx, obj, tag); err != nil { errs = append(errs, err.(validator.ValidationErrors)...) // nolint: errorlint } } count := value.Len() for i := 0; i < count; i++ { - if err := v.ValidateStruct(value.Index(i).Interface()); err != nil { + if err := v.ValidateStructContext(ctx, value.Index(i).Interface()); err != nil { for _, fieldError := range err.(validator.ValidationErrors) { // nolint: errorlint errs = append(errs, sliceFieldError{fieldError, i}) } @@ -132,13 +137,13 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { var errs validator.ValidationErrors if tag, ok := validatorTags[value.Type()]; ok { - if err := v.validateVar(obj, tag); err != nil { + if err := v.validateVar(ctx, obj, tag); err != nil { errs = append(errs, err.(validator.ValidationErrors)...) // nolint: errorlint } } for _, key := range value.MapKeys() { - if err := v.ValidateStruct(value.MapIndex(key).Interface()); err != nil { + if err := v.ValidateStructContext(ctx, value.MapIndex(key).Interface()); err != nil { for _, fieldError := range err.(validator.ValidationErrors) { // nolint: errorlint errs = append(errs, mapFieldError{fieldError, key.Interface()}) } @@ -154,15 +159,15 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { } // validateStruct receives struct type -func (v *defaultValidator) validateStruct(obj interface{}) error { +func (v *defaultValidator) validateStruct(ctx context.Context, obj interface{}) error { v.lazyinit() - return v.validate.Struct(obj) + return v.validate.StructCtx(ctx, obj) } // validateStruct receives slice, array, and map types -func (v *defaultValidator) validateVar(obj interface{}, tag string) error { +func (v *defaultValidator) validateVar(ctx context.Context, obj interface{}, tag string) error { v.lazyinit() - return v.validate.Var(obj, tag) + return v.validate.VarCtx(ctx, obj, tag) } // Engine returns the underlying validator engine which powers the default diff --git a/binding/form.go b/binding/form.go index fa2a6540a0..5020bfc272 100644 --- a/binding/form.go +++ b/binding/form.go @@ -5,6 +5,7 @@ package binding import ( + "context" "errors" "net/http" ) @@ -19,7 +20,11 @@ func (formBinding) Name() string { return "form" } -func (formBinding) Bind(req *http.Request, obj interface{}) error { +func (b formBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} + +func (formBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { if err := req.ParseForm(); err != nil { return err } @@ -29,34 +34,41 @@ func (formBinding) Bind(req *http.Request, obj interface{}) error { if err := mapForm(obj, req.Form); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } func (formPostBinding) Name() string { return "form-urlencoded" } -func (formPostBinding) Bind(req *http.Request, obj interface{}) error { +func (b formPostBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} + +func (formPostBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { if err := req.ParseForm(); err != nil { return err } if err := mapForm(obj, req.PostForm); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } func (formMultipartBinding) Name() string { return "multipart/form-data" } -func (formMultipartBinding) Bind(req *http.Request, obj interface{}) error { +func (b formMultipartBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} + +func (formMultipartBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { if err := req.ParseMultipartForm(defaultMemory); err != nil { return err } if err := mappingByPtr(obj, (*multipartRequest)(req), "form"); err != nil { return err } - - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/header.go b/binding/header.go index b99302af82..55b90dfbb2 100644 --- a/binding/header.go +++ b/binding/header.go @@ -1,6 +1,7 @@ package binding import ( + "context" "net/http" "net/textproto" "reflect" @@ -12,13 +13,15 @@ func (headerBinding) Name() string { return "header" } -func (headerBinding) Bind(req *http.Request, obj interface{}) error { +func (b headerBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} +func (headerBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { if err := mapHeader(obj, req.Header); err != nil { return err } - - return validate(obj) + return validateContext(ctx, obj) } func mapHeader(ptr interface{}, h map[string][]string) error { diff --git a/binding/json.go b/binding/json.go index 45aaa49487..3f28c7ad25 100644 --- a/binding/json.go +++ b/binding/json.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "errors" "io" "net/http" @@ -30,18 +31,26 @@ func (jsonBinding) Name() string { return "json" } -func (jsonBinding) Bind(req *http.Request, obj interface{}) error { +func (b jsonBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} + +func (jsonBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { if req == nil || req.Body == nil { return errors.New("invalid request") } - return decodeJSON(req.Body, obj) + return decodeJSON(ctx, req.Body, obj) +} + +func (b jsonBinding) BindBody(body []byte, obj interface{}) error { + return b.BindBodyContext(context.Background(), body, obj) } -func (jsonBinding) BindBody(body []byte, obj interface{}) error { - return decodeJSON(bytes.NewReader(body), obj) +func (jsonBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error { + return decodeJSON(ctx, bytes.NewReader(body), obj) } -func decodeJSON(r io.Reader, obj interface{}) error { +func decodeJSON(ctx context.Context, r io.Reader, obj interface{}) error { decoder := json.NewDecoder(r) if EnableDecoderUseNumber { decoder.UseNumber() @@ -52,5 +61,5 @@ func decodeJSON(r io.Reader, obj interface{}) error { if err := decoder.Decode(obj); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/msgpack.go b/binding/msgpack.go index 2a442996a6..0d93bb3b6d 100644 --- a/binding/msgpack.go +++ b/binding/msgpack.go @@ -9,6 +9,7 @@ package binding import ( "bytes" + "context" "io" "net/http" @@ -21,18 +22,26 @@ func (msgpackBinding) Name() string { return "msgpack" } -func (msgpackBinding) Bind(req *http.Request, obj interface{}) error { - return decodeMsgPack(req.Body, obj) +func (b msgpackBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) } -func (msgpackBinding) BindBody(body []byte, obj interface{}) error { - return decodeMsgPack(bytes.NewReader(body), obj) +func (msgpackBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { + return decodeMsgPack(ctx, req.Body, obj) } -func decodeMsgPack(r io.Reader, obj interface{}) error { +func (b msgpackBinding) BindBody(body []byte, obj interface{}) error { + return b.BindBodyContext(context.Background(), body, obj) +} + +func (msgpackBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error { + return decodeMsgPack(ctx, bytes.NewReader(body), obj) +} + +func decodeMsgPack(ctx context.Context, r io.Reader, obj interface{}) error { cdc := new(codec.MsgpackHandle) if err := codec.NewDecoder(r, cdc).Decode(&obj); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/query.go b/binding/query.go index 219743f2a9..9fe1136895 100644 --- a/binding/query.go +++ b/binding/query.go @@ -4,7 +4,10 @@ package binding -import "net/http" +import ( + "context" + "net/http" +) type queryBinding struct{} @@ -12,10 +15,14 @@ func (queryBinding) Name() string { return "query" } -func (queryBinding) Bind(req *http.Request, obj interface{}) error { +func (b queryBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) +} + +func (queryBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { values := req.URL.Query() if err := mapForm(obj, values); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/uri.go b/binding/uri.go index a3c0df515c..dd6cf655cf 100644 --- a/binding/uri.go +++ b/binding/uri.go @@ -4,15 +4,21 @@ package binding +import "context" + type uriBinding struct{} func (uriBinding) Name() string { return "uri" } -func (uriBinding) BindUri(m map[string][]string, obj interface{}) error { +func (b uriBinding) BindUri(m map[string][]string, obj interface{}) error { + return b.BindUriContext(context.Background(), m, obj) +} + +func (uriBinding) BindUriContext(ctx context.Context, m map[string][]string, obj interface{}) error { if err := mapURI(obj, m); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/validate_test.go b/binding/validate_test.go index 5299fbf602..c05bbdc896 100644 --- a/binding/validate_test.go +++ b/binding/validate_test.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "testing" "time" @@ -226,3 +227,7 @@ func TestValidatorEngine(t *testing.T) { // Check that the error matches expectation assert.Error(t, errs, "", "", "notone") } + +func validate(obj interface{}) error { + return validateContext(context.Background(), obj) +} diff --git a/binding/xml.go b/binding/xml.go index 4e90114962..51d2f11053 100644 --- a/binding/xml.go +++ b/binding/xml.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "encoding/xml" "io" "net/http" @@ -17,17 +18,26 @@ func (xmlBinding) Name() string { return "xml" } -func (xmlBinding) Bind(req *http.Request, obj interface{}) error { - return decodeXML(req.Body, obj) +func (b xmlBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) } -func (xmlBinding) BindBody(body []byte, obj interface{}) error { - return decodeXML(bytes.NewReader(body), obj) +func (xmlBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { + return decodeXML(ctx, req.Body, obj) } -func decodeXML(r io.Reader, obj interface{}) error { + +func (b xmlBinding) BindBody(body []byte, obj interface{}) error { + return b.BindBodyContext(context.Background(), body, obj) +} + +func (xmlBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error { + return decodeXML(ctx, bytes.NewReader(body), obj) +} + +func decodeXML(ctx context.Context, r io.Reader, obj interface{}) error { decoder := xml.NewDecoder(r) if err := decoder.Decode(obj); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/binding/yaml.go b/binding/yaml.go index a2d36d6a54..816f67e5b6 100644 --- a/binding/yaml.go +++ b/binding/yaml.go @@ -6,6 +6,7 @@ package binding import ( "bytes" + "context" "io" "net/http" @@ -18,18 +19,26 @@ func (yamlBinding) Name() string { return "yaml" } -func (yamlBinding) Bind(req *http.Request, obj interface{}) error { - return decodeYAML(req.Body, obj) +func (b yamlBinding) Bind(req *http.Request, obj interface{}) error { + return b.BindContext(context.Background(), req, obj) } -func (yamlBinding) BindBody(body []byte, obj interface{}) error { - return decodeYAML(bytes.NewReader(body), obj) +func (yamlBinding) BindContext(ctx context.Context, req *http.Request, obj interface{}) error { + return decodeYAML(ctx, req.Body, obj) } -func decodeYAML(r io.Reader, obj interface{}) error { +func (b yamlBinding) BindBody(body []byte, obj interface{}) error { + return b.BindBodyContext(context.Background(), body, obj) +} + +func (yamlBinding) BindBodyContext(ctx context.Context, body []byte, obj interface{}) error { + return decodeYAML(ctx, bytes.NewReader(body), obj) +} + +func decodeYAML(ctx context.Context, r io.Reader, obj interface{}) error { decoder := yaml.NewDecoder(r) if err := decoder.Decode(obj); err != nil { return err } - return validate(obj) + return validateContext(ctx, obj) } diff --git a/context.go b/context.go index 58f38c88cb..e98fee72b2 100644 --- a/context.go +++ b/context.go @@ -704,12 +704,15 @@ func (c *Context) ShouldBindUri(obj interface{}) error { for _, v := range c.Params { m[v.Key] = []string{v.Value} } - return binding.Uri.BindUri(m, obj) + return binding.Uri.BindUriContext(c, m, obj) } // ShouldBindWith binds the passed struct pointer using the specified binding engine. // See the binding package. func (c *Context) ShouldBindWith(obj interface{}, b binding.Binding) error { + if b, ok := b.(binding.ContextBinding); ok { + return b.BindContext(c, c.Request, obj) + } return b.Bind(c.Request, obj) } @@ -732,6 +735,9 @@ func (c *Context) ShouldBindBodyWith(obj interface{}, bb binding.BindingBody) (e } c.Set(BodyBytesKey, body) } + if bb, ok := bb.(binding.ContextBindingBody); ok { + return bb.BindBodyContext(c, body, obj) + } return bb.BindBody(body, obj) } diff --git a/context_test.go b/context_test.go index c286c0f4cb..1e30af29a6 100644 --- a/context_test.go +++ b/context_test.go @@ -24,6 +24,7 @@ import ( "github.com/gin-contrib/sse" "github.com/gin-gonic/gin/binding" testdata "github.com/gin-gonic/gin/testdata/protoexample" + "github.com/go-playground/validator/v10" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" ) @@ -36,6 +37,16 @@ var _ context.Context = &Context{} // BAD case: func (c *Context) Render(code int, render render.Render, obj ...interface{}) { // test that information is not leaked when reusing Contexts (using the Pool) +func init() { + _ = binding.Validator.Engine().(*validator.Validate).RegisterValidationCtx( + "required_if_condition", func(ctx context.Context, fl validator.FieldLevel) bool { + if ctx.Value("condition") == true { + return !fl.Field().IsZero() + } + return true + }) +} + func createMultipartRequest() *http.Request { boundary := "--testboundary" body := new(bytes.Buffer) @@ -1543,6 +1554,27 @@ func TestContextBindWithJSON(t *testing.T) { assert.Equal(t, 0, w.Body.Len()) } +func TestContextBindWithJSONContextual(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"bar\":\"foo\"}")) + c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + + var obj struct { + Foo string `json:"foo" binding:"required_if_condition"` + Bar string `json:"bar"` + } + c.Set("condition", true) + assert.Error(t, c.BindJSON(&obj)) + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + assert.NoError(t, c.BindJSON(&obj)) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) +} + func TestContextBindWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) @@ -1672,6 +1704,92 @@ func TestContextShouldBindWithJSON(t *testing.T) { assert.Equal(t, 0, w.Body.Len()) } +func TestContextShouldBindWithJSONContextual(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"bar\":\"foo\"}")) + c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + + var obj struct { + Foo string `json:"foo" binding:"required_if_condition"` + Bar string `json:"bar"` + } + c.Set("condition", true) + assert.Error(t, c.ShouldBindJSON(&obj)) + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + assert.NoError(t, c.ShouldBindJSON(&obj)) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) +} + +func TestContextShouldBindBodyWithJSONContextual(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + var obj struct { + Foo string `json:"foo" binding:"required_if_condition"` + Bar string `json:"bar"` + } + c.Set("condition", true) + c.Set(BodyBytesKey, []byte("{\"bar\":\"foo\"}")) + assert.Error(t, c.ShouldBindBodyWith(&obj, binding.JSON)) + + c.Set(BodyBytesKey, []byte("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + assert.NoError(t, c.ShouldBindBodyWith(&obj, binding.JSON)) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) +} + +func TestContextShouldBindWithNotContextBinding(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + c.Request.Header.Add("Content-Type", MIMEXML) // set fake content-type + + var obj struct { + Foo string `json:"foo" binding:"required_if_condition"` + Bar string `json:"bar"` + } + assert.NoError(t, c.ShouldBindWith(&obj, notContextBinding{})) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) +} + +func TestContextShouldBindBodyWithNotContextBinding(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + var obj struct { + Foo string `json:"foo"` + Bar string `json:"bar"` + } + c.Set(BodyBytesKey, []byte("{\"foo\":\"bar\", \"bar\":\"foo\"}")) + assert.NoError(t, c.ShouldBindBodyWith(&obj, notContextBinding{})) + assert.Equal(t, "foo", obj.Bar) + assert.Equal(t, "bar", obj.Foo) + assert.Equal(t, 0, w.Body.Len()) +} + +type notContextBinding struct{} + +func (notContextBinding) Name() string { + return binding.JSON.Name() +} + +func (b notContextBinding) Bind(req *http.Request, obj interface{}) error { + return binding.JSON.Bind(req, obj) +} + +func (b notContextBinding) BindBody(body []byte, obj interface{}) error { + return binding.JSON.BindBody(body, obj) +} + func TestContextShouldBindWithXML(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w)