diff --git a/_testdata/sample.json b/_testdata/sample.json index 844098342..12e825bf5 100644 --- a/_testdata/sample.json +++ b/_testdata/sample.json @@ -927,6 +927,9 @@ "testAny": { "$ref": "#/components/schemas/AnyTest" }, + "testAnyOf": { + "$ref": "#/components/schemas/AnyOfTest" + }, "testDate": { "type": "string", "format": "date" @@ -1060,6 +1063,31 @@ } ] }, + "AnyOfTest": { + "description": "Type for testing some anyOf cases from Jaeger operator API schema", + "required": [ + "medium", + "sizeLimit" + ], + "properties": { + "medium": { + "type": "string" + }, + "sizeLimit": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "string" + } + ], + "pattern": "^(\\+|-)?(([0-9]+(\\.[0-9]*)?)|(\\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\\+|-)?(([0-9]+(\\.[0-9]*)?)|(\\.[0-9]+))))?$", + "x-kubernetes-int-or-string": true + } + }, + "type": "object" + }, "OneVariantHasNoUniqueFields": { "oneOf": [ { @@ -1123,6 +1151,7 @@ "properties": { "empty": {}, "any_map": { + "type": "object", "additionalProperties": true }, "any_array": { diff --git a/gen/schema_gen.go b/gen/schema_gen.go index 2e8aa9f8e..7c585bd8a 100644 --- a/gen/schema_gen.go +++ b/gen/schema_gen.go @@ -1,10 +1,6 @@ package gen import ( - "fmt" - "path" - "regexp" - "sort" "strings" "unicode" "unicode/utf8" @@ -60,13 +56,6 @@ func (g *schemaGen) generate(name string, schema *oas.Schema) (_ *ir.Type, err e name = "R" + name } - switch { - case len(schema.AnyOf) > 0: - return nil, &ErrNotImplemented{"anyOf"} - case len(schema.AllOf) > 0: - return nil, &ErrNotImplemented{"allOf"} - } - side := func(t *ir.Type) *ir.Type { if t.Schema != nil { if ref := t.Schema.Ref; ref != "" { @@ -86,6 +75,23 @@ func (g *schemaGen) generate(name string, schema *oas.Schema) (_ *ir.Type, err e return t } + switch { + case len(schema.AnyOf) > 0: + t, err := g.anyOf(name, schema) + if err != nil { + return nil, errors.Wrap(err, "anyOf") + } + return side(t), nil + case len(schema.AllOf) > 0: + return nil, &ErrNotImplemented{"allOf"} + case len(schema.OneOf) > 0: + t, err := g.oneOf(name, schema) + if err != nil { + return nil, errors.Wrap(err, "oneOf") + } + return side(t), nil + } + switch schema.Type { case oas.Object: kind := ir.KindStruct @@ -116,13 +122,15 @@ func (g *schemaGen) generate(name string, schema *oas.Schema) (_ *ir.Type, err e }) } - if schema.Item != nil { - s.Item, err = g.generate(name+"Item", schema.Item) - if err != nil { - return nil, errors.Wrap(err, "item") + if schema.AdditionalProperties { + if schema.Item != nil { + s.Item, err = g.generate(name+"Item", schema.Item) + if err != nil { + return nil, errors.Wrap(err, "item") + } + } else { + s.Item = ir.Any() } - } else { - s.Item = ir.Any() } return s, nil @@ -133,12 +141,7 @@ func (g *schemaGen) generate(name string, schema *oas.Schema) (_ *ir.Type, err e NilSemantic: ir.NilInvalid, } - if schema.MaxItems != nil { - array.Validators.Array.SetMaxLength(int(*schema.MaxItems)) - } - if schema.MinItems != nil { - array.Validators.Array.SetMinLength(int(*schema.MinItems)) - } + array.Validators.SetArray(schema) ret := side(array) if schema.Item != nil { @@ -160,352 +163,18 @@ func (g *schemaGen) generate(name string, schema *oas.Schema) (_ *ir.Type, err e switch schema.Type { case oas.String: - if schema.Pattern != "" { - t.Validators.String.Regex, err = regexp.Compile(schema.Pattern) - if err != nil { - return nil, errors.Wrap(err, "pattern") - } - } - if schema.MaxLength != nil { - t.Validators.String.SetMaxLength(int(*schema.MaxLength)) - } - if schema.MinLength != nil { - t.Validators.String.SetMinLength(int(*schema.MinLength)) - } - if schema.Format == oas.FormatEmail { - t.Validators.String.Email = true - } - if schema.Format == oas.FormatHostname { - t.Validators.String.Hostname = true + if err := t.Validators.SetString(schema); err != nil { + return nil, errors.Wrap(err, "string validator") } case oas.Integer, oas.Number: - if schema.MultipleOf != nil { - t.Validators.Int.MultipleOf = *schema.MultipleOf - t.Validators.Int.MultipleOfSet = true - } - if schema.Maximum != nil { - t.Validators.Int.Max = *schema.Maximum - t.Validators.Int.MaxSet = true - } - if schema.Minimum != nil { - t.Validators.Int.Min = *schema.Minimum - t.Validators.Int.MinSet = true - } - t.Validators.Int.MaxExclusive = schema.ExclusiveMaximum - t.Validators.Int.MinExclusive = schema.ExclusiveMinimum + t.Validators.SetInt(schema) } return side(t), nil case oas.Empty: - if len(schema.OneOf) == 0 { - return side(ir.Any()), nil - } - t, err := g.oneOf(name, schema) - if err != nil { - return nil, errors.Wrap(err, "oneOf") - } - return side(t), nil + return side(ir.Any()), nil default: panic("unreachable") } } - -func (g *schemaGen) primitive(name string, schema *oas.Schema) (*ir.Type, error) { - t, err := parseSimple(schema) - if err != nil { - return nil, err - } - - if len(schema.Enum) > 0 { - if !t.Is(ir.KindPrimitive) { - return nil, errors.Errorf("unsupported enum type: %q", schema.Type) - } - - hasDuplicateNames := func() bool { - names := map[string]struct{}{} - for _, v := range schema.Enum { - vstr := fmt.Sprintf("%v", v) - if vstr == "" { - vstr = "Empty" - } - - k := pascalSpecial(name, vstr) - if _, ok := names[k]; ok { - return true - } - names[k] = struct{}{} - } - - return false - }() - - var variants []*ir.EnumVariant - for _, v := range schema.Enum { - vstr := fmt.Sprintf("%v", v) - if vstr == "" { - vstr = "Empty" - } - - var variantName string - if hasDuplicateNames { - variantName = name + "_" + vstr - } else { - variantName = pascalSpecial(name, vstr) - } - - variants = append(variants, &ir.EnumVariant{ - Name: variantName, - Value: v, - }) - } - - return &ir.Type{ - Kind: ir.KindEnum, - Name: name, - Primitive: t.Primitive, - EnumVariants: variants, - Schema: schema, - }, nil - } - - return t, nil -} - -func parseSimple(schema *oas.Schema) (*ir.Type, error) { - mapping := map[oas.SchemaType]map[oas.Format]ir.PrimitiveType{ - oas.Integer: { - oas.FormatInt32: ir.Int32, - oas.FormatInt64: ir.Int64, - oas.FormatNone: ir.Int, - }, - oas.Number: { - oas.FormatFloat: ir.Float32, - oas.FormatDouble: ir.Float64, - oas.FormatNone: ir.Float64, - oas.FormatInt32: ir.Int32, - oas.FormatInt64: ir.Int64, - }, - oas.String: { - oas.FormatByte: ir.ByteSlice, - oas.FormatDateTime: ir.Time, - oas.FormatDate: ir.Time, - oas.FormatTime: ir.Time, - oas.FormatDuration: ir.Duration, - oas.FormatUUID: ir.UUID, - oas.FormatIP: ir.IP, - oas.FormatIPv4: ir.IP, - oas.FormatIPv6: ir.IP, - oas.FormatURI: ir.URL, - oas.FormatPassword: ir.String, - oas.FormatNone: ir.String, - }, - oas.Boolean: { - oas.FormatNone: ir.Bool, - }, - } - - t, found := mapping[schema.Type][schema.Format] - if !found { - // Return string type for unknown string formats. - if schema.Type == oas.String { - return ir.Primitive(ir.String, schema), nil - } - return nil, errors.Errorf("unexpected %q format: %q", schema.Type, schema.Format) - } - - return ir.Primitive(t, schema), nil -} - -func (g *schemaGen) oneOf(name string, schema *oas.Schema) (*ir.Type, error) { - sum := &ir.Type{ - Name: name, - Kind: ir.KindSum, - Schema: schema, - } - names := map[string]struct{}{} - for i, s := range schema.OneOf { - t, err := g.generate(fmt.Sprintf("%s%d", name, i), s) - if err != nil { - return nil, errors.Wrapf(err, "oneOf[%d]", i) - } - t.Name = variantFieldName(t) - if _, ok := names[t.Name]; ok { - return nil, errors.Wrap(&ErrNotImplemented{ - Name: "sum types with same names", - }, name) - } - names[t.Name] = struct{}{} - sum.SumOf = append(sum.SumOf, t) - } - - // 1st case: explicit discriminator. - if d := schema.Discriminator; d != nil { - sum.SumSpec.Discriminator = schema.Discriminator.PropertyName - for k, v := range schema.Discriminator.Mapping { - // Explicit mapping. - var found bool - for _, s := range sum.SumOf { - if path.Base(s.Schema.Ref) == v { - found = true - sum.SumSpec.Mapping = append(sum.SumSpec.Mapping, ir.SumSpecMap{ - Key: k, - Type: s.Name, - }) - } - } - if !found { - return nil, errors.Errorf("discriminator: unable to map %s to %s", k, v) - } - } - if len(sum.SumSpec.Mapping) == 0 { - // Implicit mapping, defaults to type name. - for _, s := range sum.SumOf { - sum.SumSpec.Mapping = append(sum.SumSpec.Mapping, ir.SumSpecMap{ - Key: path.Base(s.Schema.Ref), - Type: s.Name, - }) - } - } - sort.SliceStable(sum.SumSpec.Mapping, func(i, j int) bool { - a := sum.SumSpec.Mapping[i] - b := sum.SumSpec.Mapping[j] - return strings.Compare(a.Key, b.Key) < 0 - }) - return sum, nil - } - - // 2nd case: distinguish by serialization type. - var ( - // Collect map of variant kinds. - typeMap = map[ir.TypeDiscriminator]struct{}{} - // If all variants have different kinds, so - // we can distinguish them by JSON type. - canUseTypeDiscriminator = true - ) - for _, s := range sum.SumOf { - var kind ir.TypeDiscriminator - kind.Set(s) - if _, ok := typeMap[kind]; ok { - // Type kind is not unique, so we can distinguish variants by type. - canUseTypeDiscriminator = false - break - } - typeMap[kind] = struct{}{} - } - if canUseTypeDiscriminator { - sum.SumSpec.TypeDiscriminator = true - return sum, nil - } - - // 3rd case: distinguish by unique fields. - var ( - // Determine unique fields for each SumOf variant. - uniq = map[string]map[string]struct{}{} - ) - for _, s := range sum.SumOf { - uniq[s.Name] = map[string]struct{}{} - if !s.Is(ir.KindMap, ir.KindStruct) { - return nil, errors.Wrapf(&ErrNotImplemented{Name: "discriminator inference"}, - "oneOf %s: variant %s: no unique fields, "+ - "unable to parse without discriminator", sum.Name, s.Name, - ) - } - for _, f := range s.Fields { - uniq[s.Name][f.Name] = struct{}{} - } - } - { - // Collect fields that common for at least 2 variants. - commonFields := map[string]struct{}{} - for _, variant := range sum.SumOf { - k := variant.Name - fields := uniq[k] - for _, otherVariant := range sum.SumOf { - otherK := otherVariant.Name - if otherK == k { - continue - } - otherFields := uniq[otherK] - for otherField := range otherFields { - if _, has := fields[otherField]; has { - // variant and otherVariant have common field otherField. - commonFields[otherField] = struct{}{} - } - } - } - } - // Delete such fields. - for field := range commonFields { - for _, variant := range sum.SumOf { - delete(uniq[variant.Name], field) - } - } - - // Check that at most one type has no unique fields. - metNoUniqueFields := false - for _, variant := range sum.SumOf { - k := variant.Name - if len(uniq[k]) == 0 { - if metNoUniqueFields { - // Unable to deterministically select sub-schema only on fields. - return nil, errors.Wrapf(&ErrNotImplemented{Name: "discriminator inference"}, - "oneOf %s: variant %s: no unique fields, "+ - "unable to parse without discriminator", sum.Name, k, - ) - } - - // Set mapping without unique fields as default - sum.SumSpec.DefaultMapping = k - metNoUniqueFields = true - } - } - - } - type sumVariant struct { - Name string - Unique []string - } - var variants []sumVariant - for _, s := range sum.SumOf { - k := s.Name - f := uniq[k] - v := sumVariant{ - Name: k, - } - for fieldName := range f { - v.Unique = append(v.Unique, fieldName) - } - sort.Strings(v.Unique) - variants = append(variants, v) - } - sort.SliceStable(variants, func(i, j int) bool { - a := variants[i] - b := variants[j] - return strings.Compare(a.Name, b.Name) < 0 - }) - for _, v := range variants { - for _, s := range sum.SumOf { - if s.Name != v.Name { - continue - } - if len(s.SumSpec.Unique) > 0 { - continue - } - for _, f := range s.Fields { - var skip bool - for _, n := range v.Unique { - if n == f.Name { - skip = true // not unique - break - } - } - if !skip { - continue - } - s.SumSpec.Unique = append(s.SumSpec.Unique, f) - } - } - } - return sum, nil -} diff --git a/gen/schema_gen_primitive.go b/gen/schema_gen_primitive.go new file mode 100644 index 000000000..1ea124088 --- /dev/null +++ b/gen/schema_gen_primitive.go @@ -0,0 +1,116 @@ +package gen + +import ( + "fmt" + + "github.com/go-faster/errors" + + "github.com/ogen-go/ogen/internal/ir" + "github.com/ogen-go/ogen/internal/oas" +) + +func (g *schemaGen) primitive(name string, schema *oas.Schema) (*ir.Type, error) { + t, err := parseSimple(schema) + if err != nil { + return nil, err + } + + if len(schema.Enum) > 0 { + if !t.Is(ir.KindPrimitive) { + return nil, errors.Errorf("unsupported enum type: %q", schema.Type) + } + + hasDuplicateNames := func() bool { + names := map[string]struct{}{} + for _, v := range schema.Enum { + vstr := fmt.Sprintf("%v", v) + if vstr == "" { + vstr = "Empty" + } + + k := pascalSpecial(name, vstr) + if _, ok := names[k]; ok { + return true + } + names[k] = struct{}{} + } + + return false + }() + + var variants []*ir.EnumVariant + for _, v := range schema.Enum { + vstr := fmt.Sprintf("%v", v) + if vstr == "" { + vstr = "Empty" + } + + var variantName string + if hasDuplicateNames { + variantName = name + "_" + vstr + } else { + variantName = pascalSpecial(name, vstr) + } + + variants = append(variants, &ir.EnumVariant{ + Name: variantName, + Value: v, + }) + } + + return &ir.Type{ + Kind: ir.KindEnum, + Name: name, + Primitive: t.Primitive, + EnumVariants: variants, + Schema: schema, + }, nil + } + + return t, nil +} + +func parseSimple(schema *oas.Schema) (*ir.Type, error) { + mapping := map[oas.SchemaType]map[oas.Format]ir.PrimitiveType{ + oas.Integer: { + oas.FormatInt32: ir.Int32, + oas.FormatInt64: ir.Int64, + oas.FormatNone: ir.Int, + }, + oas.Number: { + oas.FormatFloat: ir.Float32, + oas.FormatDouble: ir.Float64, + oas.FormatNone: ir.Float64, + oas.FormatInt32: ir.Int32, + oas.FormatInt64: ir.Int64, + }, + oas.String: { + oas.FormatByte: ir.ByteSlice, + oas.FormatDateTime: ir.Time, + oas.FormatDate: ir.Time, + oas.FormatTime: ir.Time, + oas.FormatDuration: ir.Duration, + oas.FormatUUID: ir.UUID, + oas.FormatIP: ir.IP, + oas.FormatIPv4: ir.IP, + oas.FormatIPv6: ir.IP, + oas.FormatURI: ir.URL, + oas.FormatPassword: ir.String, + oas.FormatNone: ir.String, + }, + oas.Boolean: { + oas.FormatNone: ir.Bool, + }, + } + + t, found := mapping[schema.Type][schema.Format] + if !found { + // Return string type for unknown string formats. + if schema.Type == oas.String { + return ir.Primitive(ir.String, schema), nil + } + return nil, errors.Errorf("unexpected %q format: %q", schema.Type, schema.Format) + } + + return ir.Primitive(t, schema), nil +} diff --git a/gen/schema_gen_sum.go b/gen/schema_gen_sum.go new file mode 100644 index 000000000..6a885e304 --- /dev/null +++ b/gen/schema_gen_sum.go @@ -0,0 +1,260 @@ +package gen + +import ( + "fmt" + "path" + "sort" + "strings" + + "github.com/go-faster/errors" + + "github.com/ogen-go/ogen/internal/ir" + "github.com/ogen-go/ogen/internal/oas" +) + +func canUseTypeDiscriminator(sum []*ir.Type) bool { + // Collect map of variant kinds. + typeMap := map[ir.TypeDiscriminator]struct{}{} + + for _, s := range sum { + if s.IsAny() { + // Cannot make typed sum with Any. + return false + } + + var kind ir.TypeDiscriminator + kind.Set(s) + if _, ok := typeMap[kind]; ok { + // Type kind is not unique, so we can distinguish variants by type. + return false + } + typeMap[kind] = struct{}{} + } + return true +} + +func (g *schemaGen) collectSumVariants(name string, schemas []*oas.Schema) (sum []*ir.Type, _ error) { + names := map[string]struct{}{} + for i, s := range schemas { + t, err := g.generate(fmt.Sprintf("%s%d", name, i), s) + if err != nil { + return nil, errors.Wrapf(err, "oneOf[%d]", i) + } + t.Name = variantFieldName(t) + if _, ok := names[t.Name]; ok { + return nil, errors.Wrap(&ErrNotImplemented{ + Name: "sum types with same names", + }, name) + } + names[t.Name] = struct{}{} + sum = append(sum, t) + } + return sum, nil +} + +func (g *schemaGen) anyOf(name string, schema *oas.Schema) (*ir.Type, error) { + sum := &ir.Type{ + Name: name, + Kind: ir.KindSum, + Schema: schema, + } + { + variants, err := g.collectSumVariants(name, schema.AnyOf) + if err != nil { + return nil, errors.Wrap(err, "collect variants") + } + sum.SumOf = variants + } + + // Here we try to create sum type from anyOf for variants with JSON type-based discriminator. + if canUseTypeDiscriminator(sum.SumOf) { + sum.SumSpec.TypeDiscriminator = true + for _, v := range sum.SumOf { + switch v.Kind { + case ir.KindPrimitive, ir.KindEnum: + switch { + case v.IsNumeric() && !v.Validators.Int.Set(): + v.Validators.SetInt(schema) + case !v.Validators.String.Set(): + if err := v.Validators.SetString(schema); err != nil { + return nil, errors.Wrap(err, "string validator") + } + } + case ir.KindArray: + if !v.Validators.Array.Set() { + v.Validators.SetArray(schema) + } + } + } + return sum, nil + } + return nil, &ErrNotImplemented{"complex anyOf"} +} + +func (g *schemaGen) oneOf(name string, schema *oas.Schema) (*ir.Type, error) { + sum := &ir.Type{ + Name: name, + Kind: ir.KindSum, + Schema: schema, + } + { + variants, err := g.collectSumVariants(name, schema.OneOf) + if err != nil { + return nil, errors.Wrap(err, "collect variants") + } + sum.SumOf = variants + } + + // 1st case: explicit discriminator. + if d := schema.Discriminator; d != nil { + sum.SumSpec.Discriminator = schema.Discriminator.PropertyName + for k, v := range schema.Discriminator.Mapping { + // Explicit mapping. + var found bool + for _, s := range sum.SumOf { + if path.Base(s.Schema.Ref) == v { + found = true + sum.SumSpec.Mapping = append(sum.SumSpec.Mapping, ir.SumSpecMap{ + Key: k, + Type: s.Name, + }) + } + } + if !found { + return nil, errors.Errorf("discriminator: unable to map %s to %s", k, v) + } + } + if len(sum.SumSpec.Mapping) == 0 { + // Implicit mapping, defaults to type name. + for _, s := range sum.SumOf { + sum.SumSpec.Mapping = append(sum.SumSpec.Mapping, ir.SumSpecMap{ + Key: path.Base(s.Schema.Ref), + Type: s.Name, + }) + } + } + sort.SliceStable(sum.SumSpec.Mapping, func(i, j int) bool { + a := sum.SumSpec.Mapping[i] + b := sum.SumSpec.Mapping[j] + return strings.Compare(a.Key, b.Key) < 0 + }) + return sum, nil + } + + // 2nd case: distinguish by serialization type. + if canUseTypeDiscriminator(sum.SumOf) { + sum.SumSpec.TypeDiscriminator = true + return sum, nil + } + + // 3rd case: distinguish by unique fields. + var ( + // Determine unique fields for each SumOf variant. + uniq = map[string]map[string]struct{}{} + ) + for _, s := range sum.SumOf { + uniq[s.Name] = map[string]struct{}{} + if !s.Is(ir.KindMap, ir.KindStruct) { + return nil, errors.Wrapf(&ErrNotImplemented{Name: "discriminator inference"}, + "oneOf %s: variant %s: no unique fields, "+ + "unable to parse without discriminator", sum.Name, s.Name, + ) + } + for _, f := range s.Fields { + uniq[s.Name][f.Name] = struct{}{} + } + } + { + // Collect fields that common for at least 2 variants. + commonFields := map[string]struct{}{} + for _, variant := range sum.SumOf { + k := variant.Name + fields := uniq[k] + for _, otherVariant := range sum.SumOf { + otherK := otherVariant.Name + if otherK == k { + continue + } + otherFields := uniq[otherK] + for otherField := range otherFields { + if _, has := fields[otherField]; has { + // variant and otherVariant have common field otherField. + commonFields[otherField] = struct{}{} + } + } + } + } + // Delete such fields. + for field := range commonFields { + for _, variant := range sum.SumOf { + delete(uniq[variant.Name], field) + } + } + + // Check that at most one type has no unique fields. + metNoUniqueFields := false + for _, variant := range sum.SumOf { + k := variant.Name + if len(uniq[k]) == 0 { + if metNoUniqueFields { + // Unable to deterministically select sub-schema only on fields. + return nil, errors.Wrapf(&ErrNotImplemented{Name: "discriminator inference"}, + "oneOf %s: variant %s: no unique fields, "+ + "unable to parse without discriminator", sum.Name, k, + ) + } + + // Set mapping without unique fields as default + sum.SumSpec.DefaultMapping = k + metNoUniqueFields = true + } + } + + } + type sumVariant struct { + Name string + Unique []string + } + var variants []sumVariant + for _, s := range sum.SumOf { + k := s.Name + f := uniq[k] + v := sumVariant{ + Name: k, + } + for fieldName := range f { + v.Unique = append(v.Unique, fieldName) + } + sort.Strings(v.Unique) + variants = append(variants, v) + } + sort.SliceStable(variants, func(i, j int) bool { + a := variants[i] + b := variants[j] + return strings.Compare(a.Name, b.Name) < 0 + }) + for _, v := range variants { + for _, s := range sum.SumOf { + if s.Name != v.Name { + continue + } + if len(s.SumSpec.Unique) > 0 { + continue + } + for _, f := range s.Fields { + var skip bool + for _, n := range v.Unique { + if n == f.Name { + skip = true // not unique + break + } + } + if !skip { + continue + } + s.SumSpec.Unique = append(s.SumSpec.Unique, f) + } + } + } + return sum, nil +} diff --git a/gen/write.go b/gen/write.go index 2b5ee853d..a4c6ada49 100644 --- a/gen/write.go +++ b/gen/write.go @@ -50,6 +50,12 @@ func (t TemplateConfig) RegexStrings() (r []string) { for _, f := range typ.Fields { addRegex(f.Type) } + for _, f := range typ.SumOf { + addRegex(f) + } + addRegex(typ.AliasTo) + addRegex(typ.PointerTo) + addRegex(typ.Item) } for _, typ := range t.Types { diff --git a/gen_test.go b/gen_test.go index 1efc3e3f8..6b4f3d7c0 100644 --- a/gen_test.go +++ b/gen_test.go @@ -79,7 +79,7 @@ func TestGenerate(t *testing.T) { )) t.Run("GitHub", g("api.github.com.json", "complex parameter types", - "anyOf", + "complex anyOf", "allOf", "discriminator inference", "sum types with same names", diff --git a/internal/ir/validation.go b/internal/ir/validation.go index 5f6220b33..e23fb7a57 100644 --- a/internal/ir/validation.go +++ b/internal/ir/validation.go @@ -1,6 +1,11 @@ package ir import ( + "regexp" + + "github.com/go-faster/errors" + + "github.com/ogen-go/ogen/internal/oas" "github.com/ogen-go/ogen/validate" ) @@ -10,6 +15,51 @@ type Validators struct { Array validate.Array } +func (v *Validators) SetString(schema *oas.Schema) (err error) { + if schema.Pattern != "" { + v.String.Regex, err = regexp.Compile(schema.Pattern) + if err != nil { + return errors.Wrap(err, "pattern") + } + } + if schema.MaxLength != nil { + v.String.SetMaxLength(int(*schema.MaxLength)) + } + if schema.MinLength != nil { + v.String.SetMinLength(int(*schema.MinLength)) + } + if schema.Format == oas.FormatEmail { + v.String.Email = true + } + if schema.Format == oas.FormatHostname { + v.String.Hostname = true + } + return nil +} + +func (v *Validators) SetInt(schema *oas.Schema) { + if schema.MultipleOf != nil { + v.Int.SetMultipleOf(*schema.MultipleOf) + } + if schema.Maximum != nil { + v.Int.SetMaximum(*schema.Maximum) + } + if schema.Minimum != nil { + v.Int.SetMinimum(*schema.Minimum) + } + v.Int.MaxExclusive = schema.ExclusiveMaximum + v.Int.MinExclusive = schema.ExclusiveMinimum +} + +func (v *Validators) SetArray(schema *oas.Schema) { + if schema.MaxItems != nil { + v.Array.SetMaxLength(int(*schema.MaxItems)) + } + if schema.MinItems != nil { + v.Array.SetMinLength(int(*schema.MinItems)) + } +} + func (t *Type) NeedValidation() bool { return t.needValidation(&walkpath{}) } diff --git a/internal/oas/parser/schema_parser.go b/internal/oas/parser/schema_parser.go index 15b859915..742bf9b64 100644 --- a/internal/oas/parser/schema_parser.go +++ b/internal/oas/parser/schema_parser.go @@ -90,7 +90,26 @@ func (p *schemaParser) parse(schema *ogen.Schema, hook func(*oas.Schema) *oas.Sc return nil, errors.Wrapf(err, "anyOf") } - return hook(&oas.Schema{AnyOf: schemas}), nil + return hook(&oas.Schema{ + AnyOf: schemas, + // Object validators + MaxProperties: schema.MaxProperties, + MinProperties: schema.MinProperties, + // Array validators + MinItems: schema.MinItems, + MaxItems: schema.MaxItems, + UniqueItems: schema.UniqueItems, + // Number validators + Minimum: schema.Minimum, + Maximum: schema.Maximum, + ExclusiveMinimum: schema.ExclusiveMinimum, + ExclusiveMaximum: schema.ExclusiveMaximum, + MultipleOf: schema.MultipleOf, + // String validators + MaxLength: schema.MaxLength, + MinLength: schema.MinLength, + Pattern: schema.Pattern, + }), nil case len(schema.AllOf) > 0: schemas, err := p.parseMany(schema.AllOf) if err != nil { diff --git a/internal/sample_api/oas_cfg_gen.go b/internal/sample_api/oas_cfg_gen.go index 890c9ac3f..3be0571fa 100644 --- a/internal/sample_api/oas_cfg_gen.go +++ b/internal/sample_api/oas_cfg_gen.go @@ -65,6 +65,7 @@ var ( ) var regexMap = map[string]*regexp.Regexp{ + "^(\\+|-)?(([0-9]+(\\.[0-9]*)?)|(\\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\\+|-)?(([0-9]+(\\.[0-9]*)?)|(\\.[0-9]+))))?$": regexp.MustCompile("^(\\+|-)?(([0-9]+(\\.[0-9]*)?)|(\\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\\+|-)?(([0-9]+(\\.[0-9]*)?)|(\\.[0-9]+))))?$"), "^\\d-\\d$": regexp.MustCompile("^\\d-\\d$"), } diff --git a/internal/sample_api/oas_json_gen.go b/internal/sample_api/oas_json_gen.go index 16855154c..9bd23bdcb 100644 --- a/internal/sample_api/oas_json_gen.go +++ b/internal/sample_api/oas_json_gen.go @@ -64,6 +64,148 @@ var ( _ = sync.Pool{} ) +// Encode implements json.Marshaler. +func (s AnyOfTest) Encode(e *jx.Writer) { + e.ObjStart() + var ( + first = true + _ = first + ) + { + if !first { + e.Comma() + } + first = false + + e.RawStr("\"medium\"" + ":") + e.Str(s.Medium) + } + { + e.Comma() + + e.RawStr("\"sizeLimit\"" + ":") + s.SizeLimit.Encode(e) + } + e.ObjEnd() +} + +var jsonFieldsNameOfAnyOfTest = [2]string{ + 0: "medium", + 1: "sizeLimit", +} + +// Decode decodes AnyOfTest from json. +func (s *AnyOfTest) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode AnyOfTest to nil") + } + var requiredBitSet [1]uint8 + + if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { + switch string(k) { + case "medium": + requiredBitSet[0] |= 1 << 0 + + if err := func() error { + v, err := d.Str() + s.Medium = string(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"medium\"") + } + case "sizeLimit": + requiredBitSet[0] |= 1 << 1 + + if err := func() error { + if err := s.SizeLimit.Decode(d); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"sizeLimit\"") + } + default: + return d.Skip() + } + return nil + }); err != nil { + return errors.Wrap(err, "decode AnyOfTest") + } + var failures []validate.FieldError + for i, mask := range [1]uint8{ + 0b00000011, + } { + if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { + // Mask only required fields and check equality to mask using XOR. + // + // If XOR result is not zero, result is not equal to expected, so some fields are missed. + // Bits of fields which would be set are actually bits of missed fields. + missed := bits.OnesCount8(result) + for bitN := 0; bitN < missed; bitN++ { + bitIdx := bits.TrailingZeros8(result) + fieldIdx := i*8 + bitIdx + var name string + if fieldIdx < len(jsonFieldsNameOfAnyOfTest) { + name = jsonFieldsNameOfAnyOfTest[fieldIdx] + } else { + name = strconv.Itoa(fieldIdx) + } + failures = append(failures, validate.FieldError{ + Name: name, + Error: validate.ErrFieldRequired, + }) + // Reset bit. + result &^= 1 << bitIdx + } + } + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + + return nil +} + +// Encode encodes AnyOfTestSizeLimit as json. +func (s AnyOfTestSizeLimit) Encode(e *jx.Writer) { + switch s.Type { + case IntAnyOfTestSizeLimit: + e.Int(s.Int) + case StringAnyOfTestSizeLimit: + e.Str(s.String) + } +} + +// Decode decodes AnyOfTestSizeLimit from json. +func (s *AnyOfTestSizeLimit) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode AnyOfTestSizeLimit to nil") + } + // Sum type type_discriminator. + switch t := d.Next(); t { + case jx.Number: + v, err := d.Int() + s.Int = int(v) + if err != nil { + return err + } + s.Type = IntAnyOfTestSizeLimit + case jx.String: + v, err := d.Str() + s.String = string(v) + if err != nil { + return err + } + s.Type = StringAnyOfTestSizeLimit + default: + return errors.Errorf("unexpected json type %q", t) + } + return nil +} + // Encode implements json.Marshaler. func (s AnyTest) Encode(e *jx.Writer) { e.ObjStart() @@ -2762,6 +2904,31 @@ func (s *OneVariantHasNoUniqueFields1) Decode(d *jx.Decoder) error { return nil } +// Encode encodes AnyOfTest as json. +func (o OptAnyOfTest) Encode(e *jx.Writer) { + if !o.Set { + return + } + o.Value.Encode(e) +} + +// Decode decodes AnyOfTest from json. +func (o *OptAnyOfTest) Decode(d *jx.Decoder) error { + if o == nil { + return errors.New("invalid: unable to decode OptAnyOfTest to nil") + } + switch d.Next() { + case jx.Object: + o.Set = true + if err := o.Value.Decode(d); err != nil { + return err + } + return nil + default: + return errors.Errorf("unexpected type %q while reading OptAnyOfTest", d.Next()) + } +} + // Encode encodes AnyTest as json. func (o OptAnyTest) Encode(e *jx.Writer) { if !o.Set { @@ -3598,6 +3765,15 @@ func (s Pet) Encode(e *jx.Writer) { s.TestAny.Encode(e) } } + { + if s.TestAnyOf.Set { + e.Comma() + } + if s.TestAnyOf.Set { + e.RawStr("\"testAnyOf\"" + ":") + s.TestAnyOf.Encode(e) + } + } { if s.TestDate.Set { e.Comma() @@ -3637,7 +3813,7 @@ func (s Pet) Encode(e *jx.Writer) { e.ObjEnd() } -var jsonFieldsNameOfPet = [28]string{ +var jsonFieldsNameOfPet = [29]string{ 0: "primary", 1: "id", 2: "unique_id", @@ -3662,10 +3838,11 @@ var jsonFieldsNameOfPet = [28]string{ 21: "testMap", 22: "testMapWithProps", 23: "testAny", - 24: "testDate", - 25: "testDuration", - 26: "testTime", - 27: "testDateTime", + 24: "testAnyOf", + 25: "testDate", + 26: "testDuration", + 27: "testTime", + 28: "testDateTime", } // Decode decodes Pet from json. @@ -3985,6 +4162,17 @@ func (s *Pet) Decode(d *jx.Decoder) error { }(); err != nil { return errors.Wrap(err, "decode field \"testAny\"") } + case "testAnyOf": + + if err := func() error { + s.TestAnyOf.Reset() + if err := s.TestAnyOf.Decode(d); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"testAnyOf\"") + } case "testDate": if err := func() error { diff --git a/internal/sample_api/oas_schemas_gen.go b/internal/sample_api/oas_schemas_gen.go index bd16cd128..ab190af96 100644 --- a/internal/sample_api/oas_schemas_gen.go +++ b/internal/sample_api/oas_schemas_gen.go @@ -64,6 +64,77 @@ var ( _ = sync.Pool{} ) +// Type for testing some anyOf cases from Jaeger operator API schema. +// Ref: #/components/schemas/AnyOfTest +type AnyOfTest struct { + Medium string `json:"medium"` + SizeLimit AnyOfTestSizeLimit `json:"sizeLimit"` +} + +// AnyOfTestSizeLimit represents sum type. +type AnyOfTestSizeLimit struct { + Type AnyOfTestSizeLimitType // switch on this field + Int int + String string +} + +// AnyOfTestSizeLimitType is oneOf type of AnyOfTestSizeLimit. +type AnyOfTestSizeLimitType string + +// Possible values for AnyOfTestSizeLimitType. +const ( + IntAnyOfTestSizeLimit AnyOfTestSizeLimitType = "int" + StringAnyOfTestSizeLimit AnyOfTestSizeLimitType = "string" +) + +// IsInt reports whether AnyOfTestSizeLimit is int. +func (s AnyOfTestSizeLimit) IsInt() bool { return s.Type == IntAnyOfTestSizeLimit } + +// IsString reports whether AnyOfTestSizeLimit is string. +func (s AnyOfTestSizeLimit) IsString() bool { return s.Type == StringAnyOfTestSizeLimit } + +// SetInt sets AnyOfTestSizeLimit to int. +func (s *AnyOfTestSizeLimit) SetInt(v int) { + s.Type = IntAnyOfTestSizeLimit + s.Int = v +} + +// GetInt returns int and true boolean if AnyOfTestSizeLimit is int. +func (s AnyOfTestSizeLimit) GetInt() (v int, ok bool) { + if !s.IsInt() { + return v, false + } + return s.Int, true +} + +// NewIntAnyOfTestSizeLimit returns new AnyOfTestSizeLimit from int. +func NewIntAnyOfTestSizeLimit(v int) AnyOfTestSizeLimit { + var s AnyOfTestSizeLimit + s.SetInt(v) + return s +} + +// SetString sets AnyOfTestSizeLimit to string. +func (s *AnyOfTestSizeLimit) SetString(v string) { + s.Type = StringAnyOfTestSizeLimit + s.String = v +} + +// GetString returns string and true boolean if AnyOfTestSizeLimit is string. +func (s AnyOfTestSizeLimit) GetString() (v string, ok bool) { + if !s.IsString() { + return v, false + } + return s.String, true +} + +// NewStringAnyOfTestSizeLimit returns new AnyOfTestSizeLimit from string. +func NewStringAnyOfTestSizeLimit(v string) AnyOfTestSizeLimit { + var s AnyOfTestSizeLimit + s.SetString(v) + return s +} + // Ref: #/components/schemas/AnyTest type AnyTest struct { Empty jx.Raw `json:"empty"` @@ -667,6 +738,52 @@ type OneVariantHasNoUniqueFields1 struct { // OneofBugOK is response for OneofBug operation. type OneofBugOK struct{} +// NewOptAnyOfTest returns new OptAnyOfTest with value set to v. +func NewOptAnyOfTest(v AnyOfTest) OptAnyOfTest { + return OptAnyOfTest{ + Value: v, + Set: true, + } +} + +// OptAnyOfTest is optional AnyOfTest. +type OptAnyOfTest struct { + Value AnyOfTest + Set bool +} + +// IsSet returns true if OptAnyOfTest was set. +func (o OptAnyOfTest) IsSet() bool { return o.Set } + +// Reset unsets value. +func (o *OptAnyOfTest) Reset() { + var v AnyOfTest + o.Value = v + o.Set = false +} + +// SetTo sets value to v. +func (o *OptAnyOfTest) SetTo(v AnyOfTest) { + o.Set = true + o.Value = v +} + +// Get returns value and boolean that denotes whether value was set. +func (o OptAnyOfTest) Get() (v AnyOfTest, ok bool) { + if !o.Set { + return v, false + } + return o.Value, true +} + +// Or returns value if set, or given parameter if does not. +func (o OptAnyOfTest) Or(d AnyOfTest) AnyOfTest { + if v, ok := o.Get(); ok { + return v + } + return d +} + // NewOptAnyTest returns new OptAnyTest with value set to v. func NewOptAnyTest(v AnyTest) OptAnyTest { return OptAnyTest{ @@ -1770,6 +1887,7 @@ type Pet struct { TestMap OptStringStringMap `json:"testMap"` TestMapWithProps OptMapWithProperties `json:"testMapWithProps"` TestAny OptAnyTest `json:"testAny"` + TestAnyOf OptAnyOfTest `json:"testAnyOf"` TestDate OptTime `json:"testDate"` TestDuration OptDuration `json:"testDuration"` TestTime OptTime `json:"testTime"` diff --git a/internal/sample_api/oas_validators_gen.go b/internal/sample_api/oas_validators_gen.go index 97c58d41a..d0b039268 100644 --- a/internal/sample_api/oas_validators_gen.go +++ b/internal/sample_api/oas_validators_gen.go @@ -64,6 +64,46 @@ var ( _ = sync.Pool{} ) +func (s AnyOfTest) Validate() error { + var failures []validate.FieldError + if err := func() error { + if err := s.SizeLimit.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + failures = append(failures, validate.FieldError{ + Name: "sizeLimit", + Error: err, + }) + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + return nil +} +func (s AnyOfTestSizeLimit) Validate() error { + switch s.Type { + case IntAnyOfTestSizeLimit: + return nil // no validation needed + case StringAnyOfTestSizeLimit: + if err := (validate.String{ + MinLength: 0, + MinLengthSet: false, + MaxLength: 0, + MaxLengthSet: false, + Email: false, + Hostname: false, + Regex: regexMap["^(\\+|-)?(([0-9]+(\\.[0-9]*)?)|(\\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\\+|-)?(([0-9]+(\\.[0-9]*)?)|(\\.[0-9]+))))?$"], + }).Validate(string(s.String)); err != nil { + return errors.Wrap(err, "string") + } + return nil + default: + return errors.Errorf("invalid type %q", s.Type) + } +} + func (s ArrayTest) Validate() error { var failures []validate.FieldError if err := func() error { @@ -575,6 +615,25 @@ func (s Pet) Validate() error { Error: err, }) } + if err := func() error { + if s.TestAnyOf.Set { + if err := func() error { + if err := s.TestAnyOf.Value.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return err + } + } + return nil + return nil + }(); err != nil { + failures = append(failures, validate.FieldError{ + Name: "testAnyOf", + Error: err, + }) + } if len(failures) > 0 { return &validate.Error{Fields: failures} } diff --git a/json_test.go b/json_test.go index 2901dd6e3..1c047ac4d 100644 --- a/json_test.go +++ b/json_test.go @@ -483,6 +483,30 @@ func TestJSONSum(t *testing.T) { }) } }) + t.Run("AnyOfTestSizeLimit", func(t *testing.T) { + for i, tc := range []struct { + Input string + Expected api.AnyOfTestSizeLimitType + Error bool + }{ + {`10`, api.IntAnyOfTestSizeLimit, false}, + {`"10"`, api.StringAnyOfTestSizeLimit, false}, + {`true`, "", true}, + {`null`, "", true}, + } { + // Make range value copy to prevent data races. + tc := tc + t.Run(fmt.Sprintf("Test%d", i+1), func(t *testing.T) { + checker := require.NoError + if tc.Error { + checker = require.Error + } + r := api.AnyOfTestSizeLimit{} + checker(t, r.Decode(jx.DecodeStr(tc.Input))) + require.Equal(t, tc.Expected, r.Type) + }) + } + }) } func TestJSONAny(t *testing.T) { diff --git a/validate_test.go b/validate_test.go index 95134ab10..13734862f 100644 --- a/validate_test.go +++ b/validate_test.go @@ -118,3 +118,35 @@ func TestValidateMap(t *testing.T) { }) } } + +func TestValidateSum(t *testing.T) { + for i, tc := range []struct { + Input string + Error bool + }{ + { + `{"medium": "text", "sizeLimit": "aboba"}`, + true, + }, + { + `{"medium": "text", "sizeLimit": 10}`, + false, + }, + { + `{"medium": "text", "sizeLimit": "10"}`, + false, + }, + } { + tc := tc + t.Run(fmt.Sprintf("Test%d", i+1), func(t *testing.T) { + m := api.AnyOfTest{} + require.NoError(t, m.Decode(jx.DecodeStr(tc.Input))) + + checker := require.NoError + if tc.Error { + checker = require.Error + } + checker(t, m.Validate()) + }) + } +}