diff --git a/openapi3/schema.go b/openapi3/schema.go index 1d77e816a..600cb9de2 100644 --- a/openapi3/schema.go +++ b/openapi3/schema.go @@ -861,10 +861,11 @@ func (schema *Schema) visitJSON(settings *schemaValidationSettings, value interf } } return &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "type", - Reason: fmt.Sprintf("unhandled value of type %T", value), + Value: value, + Schema: schema, + SchemaField: "type", + Reason: fmt.Sprintf("unhandled value of type %T", value), + customizeMessageError: settings.customizeMessageError, } } @@ -879,10 +880,11 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return errSchema } return &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "enum", - Reason: "value is not one of the allowed values", + Value: value, + Schema: schema, + SchemaField: "enum", + Reason: "value is not one of the allowed values", + customizeMessageError: settings.customizeMessageError, } } @@ -896,9 +898,10 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return errSchema } return &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "not", + Value: value, + Schema: schema, + SchemaField: "not", + customizeMessageError: settings.customizeMessageError, } } } @@ -961,9 +964,10 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return errSchema } e := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "oneOf", + Value: value, + Schema: schema, + SchemaField: "oneOf", + customizeMessageError: settings.customizeMessageError, } if ok > 1 { e.Origin = ErrOneOfConflict @@ -1005,9 +1009,10 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return errSchema } return &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "anyOf", + Value: value, + Schema: schema, + SchemaField: "anyOf", + customizeMessageError: settings.customizeMessageError, } } @@ -1024,10 +1029,11 @@ func (schema *Schema) visitSetOperations(settings *schemaValidationSettings, val return errSchema } return &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "allOf", - Origin: err, + Value: value, + Schema: schema, + SchemaField: "allOf", + Origin: err, + customizeMessageError: settings.customizeMessageError, } } } @@ -1042,10 +1048,11 @@ func (schema *Schema) visitJSONNull(settings *schemaValidationSettings) (err err return errSchema } return &SchemaError{ - Value: nil, - Schema: schema, - SchemaField: "nullable", - Reason: "Value is not nullable", + Value: nil, + Schema: schema, + SchemaField: "nullable", + Reason: "Value is not nullable", + customizeMessageError: settings.customizeMessageError, } } @@ -1075,10 +1082,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "type", - Reason: "Value must be an integer", + Value: value, + Schema: schema, + SchemaField: "type", + Reason: "Value must be an integer", + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1110,10 +1118,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "format", - Reason: fmt.Sprintf("number must be an %s", schema.Format), + Value: value, + Schema: schema, + SchemaField: "format", + Reason: fmt.Sprintf("number must be an %s", schema.Format), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1128,10 +1137,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "exclusiveMinimum", - Reason: fmt.Sprintf("number must be more than %g", *schema.Min), + Value: value, + Schema: schema, + SchemaField: "exclusiveMinimum", + Reason: fmt.Sprintf("number must be more than %g", *schema.Min), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1145,10 +1155,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "exclusiveMaximum", - Reason: fmt.Sprintf("number must be less than %g", *schema.Max), + Value: value, + Schema: schema, + SchemaField: "exclusiveMaximum", + Reason: fmt.Sprintf("number must be less than %g", *schema.Max), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1162,10 +1173,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "minimum", - Reason: fmt.Sprintf("number must be at least %g", *v), + Value: value, + Schema: schema, + SchemaField: "minimum", + Reason: fmt.Sprintf("number must be at least %g", *v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1179,10 +1191,11 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "maximum", - Reason: fmt.Sprintf("number must be at most %g", *v), + Value: value, + Schema: schema, + SchemaField: "maximum", + Reason: fmt.Sprintf("number must be at most %g", *v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1199,9 +1212,10 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "multipleOf", + Value: value, + Schema: schema, + SchemaField: "multipleOf", + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1247,10 +1261,11 @@ func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "minLength", - Reason: fmt.Sprintf("minimum string length is %d", minLength), + Value: value, + Schema: schema, + SchemaField: "minLength", + Reason: fmt.Sprintf("minimum string length is %d", minLength), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1262,10 +1277,11 @@ func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "maxLength", - Reason: fmt.Sprintf("maximum string length is %d", *maxLength), + Value: value, + Schema: schema, + SchemaField: "maxLength", + Reason: fmt.Sprintf("maximum string length is %d", *maxLength), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1286,10 +1302,11 @@ func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value } if cp := schema.compiledPattern; cp != nil && !cp.MatchString(value) { err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "pattern", - Reason: fmt.Sprintf(`string doesn't match the regular expression "%s"`, schema.Pattern), + Value: value, + Schema: schema, + SchemaField: "pattern", + Reason: fmt.Sprintf(`string doesn't match the regular expression "%s"`, schema.Pattern), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1318,11 +1335,12 @@ func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value } if formatStrErr != "" || formatErr != nil { err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "format", - Reason: formatStrErr, - Origin: formatErr, + Value: value, + Schema: schema, + SchemaField: "format", + Reason: formatStrErr, + Origin: formatErr, + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1358,10 +1376,11 @@ func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value [ return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "minItems", - Reason: fmt.Sprintf("minimum number of items is %d", v), + Value: value, + Schema: schema, + SchemaField: "minItems", + Reason: fmt.Sprintf("minimum number of items is %d", v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1375,10 +1394,11 @@ func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value [ return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "maxItems", - Reason: fmt.Sprintf("maximum number of items is %d", *v), + Value: value, + Schema: schema, + SchemaField: "maxItems", + Reason: fmt.Sprintf("maximum number of items is %d", *v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1395,10 +1415,11 @@ func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value [ return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "uniqueItems", - Reason: "duplicate items found", + Value: value, + Schema: schema, + SchemaField: "uniqueItems", + Reason: "duplicate items found", + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1486,10 +1507,11 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "minProperties", - Reason: fmt.Sprintf("there must be at least %d properties", v), + Value: value, + Schema: schema, + SchemaField: "minProperties", + Reason: fmt.Sprintf("there must be at least %d properties", v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1503,10 +1525,11 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "maxProperties", - Reason: fmt.Sprintf("there must be at most %d properties", *v), + Value: value, + Schema: schema, + SchemaField: "maxProperties", + Reason: fmt.Sprintf("there must be at most %d properties", *v), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1574,10 +1597,11 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value return errSchema } err := &SchemaError{ - Value: value, - Schema: schema, - SchemaField: "properties", - Reason: fmt.Sprintf("property %q is unsupported", k), + Value: value, + Schema: schema, + SchemaField: "properties", + Reason: fmt.Sprintf("property %q is unsupported", k), + customizeMessageError: settings.customizeMessageError, } if !settings.multiError { return err @@ -1598,10 +1622,11 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value return errSchema } err := markSchemaErrorKey(&SchemaError{ - Value: value, - Schema: schema, - SchemaField: "required", - Reason: fmt.Sprintf("property %q is missing", k), + Value: value, + Schema: schema, + SchemaField: "required", + Reason: fmt.Sprintf("property %q is missing", k), + customizeMessageError: settings.customizeMessageError, }, k) if !settings.multiError { return err @@ -1622,10 +1647,11 @@ func (schema *Schema) expectedType(settings *schemaValidationSettings, typ strin return errSchema } return &SchemaError{ - Value: typ, - Schema: schema, - SchemaField: "type", - Reason: "Field must be set to " + schema.Type + " or not be present", + Value: typ, + Schema: schema, + SchemaField: "type", + Reason: "Field must be set to " + schema.Type + " or not be present", + customizeMessageError: settings.customizeMessageError, } } @@ -1641,12 +1667,13 @@ func (schema *Schema) compilePattern() (err error) { } type SchemaError struct { - Value interface{} - reversePath []string - Schema *Schema - SchemaField string - Reason string - Origin error + Value interface{} + reversePath []string + Schema *Schema + SchemaField string + Reason string + Origin error + customizeMessageError func(err *SchemaError) string } var _ interface{ Unwrap() error } = SchemaError{} @@ -1689,6 +1716,12 @@ func (err *SchemaError) JSONPointer() []string { } func (err *SchemaError) Error() string { + if err.customizeMessageError != nil { + if msg := err.customizeMessageError(err); msg != "" { + return msg + } + } + if err.Origin != nil { return err.Origin.Error() } diff --git a/openapi3/schema_validation_settings.go b/openapi3/schema_validation_settings.go index 854ae8480..5a28c8d8d 100644 --- a/openapi3/schema_validation_settings.go +++ b/openapi3/schema_validation_settings.go @@ -16,6 +16,8 @@ type schemaValidationSettings struct { onceSettingDefaults sync.Once defaultsSet func() + + customizeMessageError func(err *SchemaError) string } // FailFast returns schema validation errors quicker. @@ -50,6 +52,12 @@ func DefaultsSet(f func()) SchemaValidationOption { return func(s *schemaValidationSettings) { s.defaultsSet = f } } +// SetSchemaErrorMessageCustomizer allows to override the schema error message. +// If the passed function returns an empty string, it returns to the previous Error() implementation. +func SetSchemaErrorMessageCustomizer(f func(err *SchemaError) string) SchemaValidationOption { + return func(s *schemaValidationSettings) { s.customizeMessageError = f } +} + func newSchemaValidationSettings(opts ...SchemaValidationOption) *schemaValidationSettings { settings := &schemaValidationSettings{} for _, opt := range opts { diff --git a/openapi3/schema_validation_settings_test.go b/openapi3/schema_validation_settings_test.go new file mode 100644 index 000000000..db52d3bdf --- /dev/null +++ b/openapi3/schema_validation_settings_test.go @@ -0,0 +1,36 @@ +package openapi3_test + +import ( + "fmt" + + "github.com/getkin/kin-openapi/openapi3" +) + +func ExampleSetSchemaErrorMessageCustomizer() { + loader := openapi3.NewLoader() + spc := ` +components: + schemas: + Something: + type: object + properties: + field: + title: Some field + type: string +`[1:] + + doc, err := loader.LoadFromData([]byte(spc)) + if err != nil { + panic(err) + } + + opt := openapi3.SetSchemaErrorMessageCustomizer(func(err *openapi3.SchemaError) string { + return fmt.Sprintf(`field "%s" should be string`, err.Schema.Title) + }) + + err = doc.Components.Schemas["Something"].Value.Properties["field"].Value.VisitJSON(123, opt) + + fmt.Println(err.Error()) + + // Output: field "Some field" should be string +} diff --git a/openapi3filter/options.go b/openapi3filter/options.go index 14843dd1b..14c35d5da 100644 --- a/openapi3filter/options.go +++ b/openapi3filter/options.go @@ -1,5 +1,7 @@ package openapi3filter +import "github.com/getkin/kin-openapi/openapi3" + // DefaultOptions do not set an AuthenticationFunc. // A spec with security schemes defined will not pass validation // unless an AuthenticationFunc is defined. @@ -25,4 +27,15 @@ type Options struct { // Indicates whether default values are set in the // request. If true, then they are not set SkipSettingDefaults bool + + customSchemaErrorFunc CustomSchemaErrorFunc +} + +// CustomSchemaErrorFunc allows for custom the schema error message. +type CustomSchemaErrorFunc func(err *openapi3.SchemaError) string + +// WithCustomSchemaErrorFunc sets a function to override the schema error message. +// If the passed function returns an empty string, it returns to the previous Error() implementation. +func (o *Options) WithCustomSchemaErrorFunc(f CustomSchemaErrorFunc) { + o.customSchemaErrorFunc = f } diff --git a/openapi3filter/options_test.go b/openapi3filter/options_test.go new file mode 100644 index 000000000..12737114d --- /dev/null +++ b/openapi3filter/options_test.go @@ -0,0 +1,83 @@ +package openapi3filter_test + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" +) + +func ExampleOptions_WithCustomSchemaErrorFunc() { + const spec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: 0.0.1 +paths: + /some: + post: + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + field: + title: Some field + type: integer + responses: + '200': + description: Created +` + + loader := openapi3.NewLoader() + doc, err := loader.LoadFromData([]byte(spec)) + if err != nil { + panic(err) + } + + err = doc.Validate(loader.Context) + if err != nil { + panic(err) + } + + router, err := gorillamux.NewRouter(doc) + if err != nil { + panic(err) + } + + opts := &openapi3filter.Options{} + + opts.WithCustomSchemaErrorFunc(func(err *openapi3.SchemaError) string { + return fmt.Sprintf(`field "%s" must be an integer`, err.Schema.Title) + }) + + req, err := http.NewRequest(http.MethodPost, "/some", strings.NewReader(`{"field":"not integer"}`)) + if err != nil { + panic(err) + } + + req.Header.Add("Content-Type", "application/json") + + route, pathParams, err := router.FindRoute(req) + if err != nil { + panic(err) + } + + validationInput := &openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + Options: opts, + } + err = openapi3filter.ValidateRequest(context.Background(), validationInput) + + fmt.Println(err.Error()) + + // Output: request body has an error: doesn't match the schema: field "Some field" must be an integer +} diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index 83bea98ad..4acb9ff1f 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -178,6 +178,9 @@ func ValidateParameter(ctx context.Context, input *RequestValidationInput, param opts = make([]openapi3.SchemaValidationOption, 0, 1) opts = append(opts, openapi3.MultiErrors()) } + if options.customSchemaErrorFunc != nil { + opts = append(opts, openapi3.SetSchemaErrorMessageCustomizer(options.customSchemaErrorFunc)) + } if err = schema.VisitJSON(value, opts...); err != nil { return &RequestError{Input: input, Parameter: parameter, Err: err} } @@ -264,6 +267,9 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } + if options.customSchemaErrorFunc != nil { + opts = append(opts, openapi3.SetSchemaErrorMessageCustomizer(options.customSchemaErrorFunc)) + } // Validate JSON with the schema if err := contentType.Schema.Value.VisitJSON(value, opts...); err != nil { diff --git a/openapi3filter/validate_response.go b/openapi3filter/validate_response.go index e90b5d60e..abcbb4e9d 100644 --- a/openapi3filter/validate_response.go +++ b/openapi3filter/validate_response.go @@ -66,6 +66,9 @@ func ValidateResponse(ctx context.Context, input *ResponseValidationInput) error if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } + if options.customSchemaErrorFunc != nil { + opts = append(opts, openapi3.SetSchemaErrorMessageCustomizer(options.customSchemaErrorFunc)) + } headers := make([]string, 0, len(response.Headers)) for k := range response.Headers {