diff --git a/go.mod b/go.mod index 50aba584b..8d93cf754 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/go-openapi/jsonpointer v0.19.5 github.com/gorilla/mux v1.8.0 github.com/invopop/yaml v0.1.0 + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/stretchr/testify v1.5.1 gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index a123aaff6..074de2aa8 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e h1:hB2xlXdHp/pmPZq0y3QnmWAArdw9PqbmotexnWx/FU8= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/openapi3/schema.go b/openapi3/schema.go index 9b36e604a..295180fdd 100644 --- a/openapi3/schema.go +++ b/openapi3/schema.go @@ -13,6 +13,7 @@ import ( "unicode/utf16" "github.com/go-openapi/jsonpointer" + "github.com/mohae/deepcopy" "github.com/getkin/kin-openapi/jsoninfo" ) @@ -915,9 +916,17 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val } } - ok := 0 - validationErrors := []error{} - for _, item := range v { + var ( + ok = 0 + validationErrors = []error{} + matchedOneOfIdx = 0 + tempValue = value + ) + // make a deep copy to protect origin value from being injected default value that defined in mismatched oneOf schema + if settings.asreq || settings.asrep { + tempValue = deepcopy.Copy(value) + } + for idx, item := range v { v := item.Value if v == nil { return foundUnresolvedRef(item.Ref) @@ -927,11 +936,12 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val continue } - if err := v.visitJSON(settings, value); err != nil { + if err := v.visitJSON(settings, tempValue); err != nil { validationErrors = append(validationErrors, err) continue } + matchedOneOfIdx = idx ok++ } @@ -962,17 +972,30 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return e } + + if settings.asreq || settings.asrep { + _ = v[matchedOneOfIdx].Value.visitJSON(settings, value) + } } if v := schema.AnyOf; len(v) > 0 { - ok := false - for _, item := range v { + var ( + ok = false + matchedAnyOfIdx = 0 + tempValue = value + ) + // make a deep copy to protect origin value from being injected default value that defined in mismatched anyOf schema + if settings.asreq || settings.asrep { + tempValue = deepcopy.Copy(value) + } + for idx, item := range v { v := item.Value if v == nil { return foundUnresolvedRef(item.Ref) } - if err := v.visitJSON(settings, value); err == nil { + if err := v.visitJSON(settings, tempValue); err == nil { ok = true + matchedAnyOfIdx = idx break } } @@ -986,6 +1009,8 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val SchemaField: "anyOf", } } + + _ = v[matchedAnyOfIdx].Value.visitJSON(settings, value) } for _, item := range schema.AllOf { diff --git a/openapi3filter/validate_set_default_test.go b/openapi3filter/validate_set_default_test.go index 40714051a..4550b51b2 100644 --- a/openapi3filter/validate_set_default_test.go +++ b/openapi3filter/validate_set_default_test.go @@ -245,6 +245,78 @@ func TestValidateRequestBodyAndSetDefault(t *testing.T) { } } } + }, + "social_network": { + "oneOf": [ + { + "type": "object", + "required": ["platform"], + "properties": { + "platform": { + "type": "string", + "enum": [ + "twitter" + ] + }, + "tw_link": { + "type": "string", + "default": "www.twitter.com" + } + } + }, + { + "type": "object", + "required": ["platform"], + "properties": { + "platform": { + "type": "string", + "enum": [ + "facebook" + ] + }, + "fb_link": { + "type": "string", + "default": "www.facebook.com" + } + } + } + ] + }, + "social_network_2": { + "anyOf": [ + { + "type": "object", + "required": ["platform"], + "properties": { + "platform": { + "type": "string", + "enum": [ + "twitter" + ] + }, + "tw_link": { + "type": "string", + "default": "www.twitter.com" + } + } + }, + { + "type": "object", + "required": ["platform"], + "properties": { + "platform": { + "type": "string", + "enum": [ + "facebook" + ] + }, + "fb_link": { + "type": "string", + "default": "www.facebook.com" + } + } + } + ] } } } @@ -281,13 +353,20 @@ func TestValidateRequestBodyAndSetDefault(t *testing.T) { OP string `json:"op,omitempty"` Value int `json:"value,omitempty"` } + type socialNetwork struct { + Platform string `json:"platform,omitempty"` + FBLink string `json:"fb_link,omitempty"` + TWLink string `json:"tw_link,omitempty"` + } type body struct { - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Code int `json:"code,omitempty"` - All bool `json:"all,omitempty"` - Page *page `json:"page,omitempty"` - Filters []filter `json:"filters,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Code int `json:"code,omitempty"` + All bool `json:"all,omitempty"` + Page *page `json:"page,omitempty"` + Filters []filter `json:"filters,omitempty"` + SocialNetwork *socialNetwork `json:"social_network,omitempty"` + SocialNetwork2 *socialNetwork `json:"social_network_2,omitempty"` } testCases := []struct { @@ -531,6 +610,52 @@ func TestValidateRequestBodyAndSetDefault(t *testing.T) { "value": 456 } ] +} + `, body) + }, + }, + { + name: "social_network(oneOf)", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + SocialNetwork: &socialNetwork{ + Platform: "facebook", + }, + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, ` +{ + "id": "bt6kdc3d0cvp6u8u3ft0", + "name": "default", + "code": 123, + "all": false, + "social_network": { + "platform": "facebook", + "fb_link": "www.facebook.com" + } +} + `, body) + }, + }, + { + name: "social_network_2(anyOf)", + body: body{ + ID: "bt6kdc3d0cvp6u8u3ft0", + SocialNetwork2: &socialNetwork{ + Platform: "facebook", + }, + }, + bodyAssertion: func(t *testing.T, body string) { + require.JSONEq(t, ` +{ + "id": "bt6kdc3d0cvp6u8u3ft0", + "name": "default", + "code": 123, + "all": false, + "social_network_2": { + "platform": "facebook", + "fb_link": "www.facebook.com" + } } `, body) },