From d215316d0b254671a7bbd208dc67e6a95d20af53 Mon Sep 17 00:00:00 2001 From: Fabian Martin Date: Mon, 12 Sep 2022 13:18:55 +0200 Subject: [PATCH 1/2] fix: Generic Fields does not handle Arrays in the .swaggo file - Detect array type and transform schema into spec.ArrayProperty if needed fixes https://github.com/swaggo/swag/issues/1318, https://github.com/swaggo/swag/issues/1320 --- generics_test.go | 5 ++-- parser.go | 38 +++++++++++++++++++------ testdata/generics_basic/.swaggo | 3 +- testdata/generics_basic/expected.json | 14 ++++++++- testdata/generics_basic/types/string.go | 12 ++++++-- 5 files changed, 57 insertions(+), 15 deletions(-) diff --git a/generics_test.go b/generics_test.go index 9383e05f5..089e0d7df 100644 --- a/generics_test.go +++ b/generics_test.go @@ -31,8 +31,9 @@ func TestParseGenericsBasic(t *testing.T) { p := New() p.Overrides = map[string]string{ - "types.Field[string]": "string", - "types.DoubleField[string,string]": "string", + "types.Field[string]": "string", + "types.DoubleField[string,string]": "[]string", + "types.TrippleField[string,string]": "[][]string", } err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) diff --git a/parser.go b/parser.go index a47187989..7af770e26 100644 --- a/parser.go +++ b/parser.go @@ -871,21 +871,22 @@ func convertFromSpecificToPrimitive(typeName string) (string, error) { } func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) (*spec.Schema, error) { + var arrayDepth = 0 if override, ok := parser.Overrides[typeName]; ok { parser.debug.Printf("Override detected for %s: using %s instead", typeName, override) - typeName = override + arrayDepth, typeName = getArrayDepth(override) } if IsInterfaceLike(typeName) { - return &spec.Schema{}, nil + return transformToArray(&spec.Schema{}, arrayDepth), nil } if IsGolangPrimitiveType(typeName) { - return PrimitiveSchema(TransToValidSchemeType(typeName)), nil + return transformToArray(PrimitiveSchema(TransToValidSchemeType(typeName)), arrayDepth), nil } schemaType, err := convertFromSpecificToPrimitive(typeName) if err == nil { - return PrimitiveSchema(schemaType), nil + return transformToArray(PrimitiveSchema(schemaType), arrayDepth), nil } typeSpecDef := parser.packages.FindTypeSpec(typeName, file, parser.ParseDependency) @@ -901,13 +902,18 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( } parser.debug.Printf("Override detected for %s: using %s instead", typeSpecDef.FullPath(), override) + arrayDepth, override = getArrayDepth(override) separator := strings.LastIndex(override, ".") if separator == -1 { // treat as a swaggertype tag parts := strings.Split(override, ",") - return BuildCustomSchema(parts) + s, err := BuildCustomSchema(parts) + if err != nil { + return nil, err + } + return transformToArray(s, arrayDepth), nil } typeSpecDef = parser.packages.findTypeSpec(override[0:separator], override[separator+1:]) @@ -920,7 +926,7 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( schema, err = parser.ParseDefinition(typeSpecDef) if err != nil { if err == ErrRecursiveParseStruct && ref { - return parser.getRefTypeSchema(typeSpecDef, schema), nil + return transformToArray(parser.getRefTypeSchema(typeSpecDef, schema), arrayDepth), nil } return nil, err @@ -928,10 +934,10 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( } if ref && len(schema.Schema.Type) > 0 && schema.Schema.Type[0] == OBJECT { - return parser.getRefTypeSchema(typeSpecDef, schema), nil + return transformToArray(parser.getRefTypeSchema(typeSpecDef, schema), arrayDepth), nil } - return schema.Schema, nil + return transformToArray(schema.Schema, arrayDepth), nil } func (parser *Parser) renameRefSchemas() { @@ -1607,3 +1613,19 @@ func (parser *Parser) addTestType(typename string) { Schema: PrimitiveSchema(OBJECT), } } + +func getArrayDepth(typename string) (int, string) { + var d = 0 + for strings.HasPrefix(typename, "[]") { + typename = typename[2:] + d++ + } + return d, typename +} + +func transformToArray(s *spec.Schema, arrayDepth int) *spec.Schema { + for i := 0; i < arrayDepth; i++ { + s = spec.ArrayProperty(s) + } + return s +} diff --git a/testdata/generics_basic/.swaggo b/testdata/generics_basic/.swaggo index 8c55f566d..30766d073 100644 --- a/testdata/generics_basic/.swaggo +++ b/testdata/generics_basic/.swaggo @@ -1,2 +1,3 @@ replace types.Field[string] string -replace types.DoubleField[string,string] string \ No newline at end of file +replace types.DoubleField[string,string] []string +replace types.TrippleField[string,string] [][]string \ No newline at end of file diff --git a/testdata/generics_basic/expected.json b/testdata/generics_basic/expected.json index 8124cc621..8f56bf54f 100644 --- a/testdata/generics_basic/expected.json +++ b/testdata/generics_basic/expected.json @@ -181,8 +181,20 @@ "types.Hello": { "type": "object", "properties": { + "myNewArrayField": { + "type": "array", + "items": { + "type": "array", + "items": { + "type": "string" + } + } + }, "myNewField": { - "type": "string" + "type": "array", + "items": { + "type": "string" + } }, "myStringField1": { "type": "string" diff --git a/testdata/generics_basic/types/string.go b/testdata/generics_basic/types/string.go index c7e251bc9..d90bcd1e7 100644 --- a/testdata/generics_basic/types/string.go +++ b/testdata/generics_basic/types/string.go @@ -9,8 +9,14 @@ type DoubleField[T1 any, T2 any] struct { Value2 T2 } +type TrippleField[T1 any, T2 any] struct { + Value1 T1 + Value2 T2 +} + type Hello struct { - MyStringField1 Field[*string] `json:"myStringField1"` - MyStringField2 Field[string] `json:"myStringField2"` - MyArrayField DoubleField[*string, string] `json:"myNewField"` + MyStringField1 Field[*string] `json:"myStringField1"` + MyStringField2 Field[string] `json:"myStringField2"` + MyArrayField DoubleField[*string, string] `json:"myNewField"` + MyArrayDepthField TrippleField[*string, string] `json:"myNewArrayField"` } From c9e9d4e470933ca44ff7cf48de9f6e5570e01a39 Mon Sep 17 00:00:00 2001 From: Fabian Martin Date: Mon, 12 Sep 2022 20:29:01 +0200 Subject: [PATCH 2/2] refactor: use existing code to parse new type definition - use methods from Operation to get schema - methods parseObjectSchema and parseCombinedObjectSchema disconnected from struct --- operation.go | 20 ++++++++++++-------- parser.go | 38 ++++++++------------------------------ 2 files changed, 20 insertions(+), 38 deletions(-) diff --git a/operation.go b/operation.go index 71c4e47ca..86ffcf372 100644 --- a/operation.go +++ b/operation.go @@ -824,6 +824,10 @@ var responsePattern = regexp.MustCompile(`^([\w,]+)\s+([\w{}]+)\s+([\w\-.\\{}=,\ var combinedPattern = regexp.MustCompile(`^([\w\-./\[\]]+){(.*)}$`) func (operation *Operation) parseObjectSchema(refType string, astFile *ast.File) (*spec.Schema, error) { + return parseObjectSchema(operation.parser, refType, astFile) +} + +func parseObjectSchema(parser *Parser, refType string, astFile *ast.File) (*spec.Schema, error) { switch { case refType == NIL: return nil, nil @@ -838,7 +842,7 @@ func (operation *Operation) parseObjectSchema(refType string, astFile *ast.File) case IsPrimitiveType(refType): return PrimitiveSchema(refType), nil case strings.HasPrefix(refType, "[]"): - schema, err := operation.parseObjectSchema(refType[2:], astFile) + schema, err := parseObjectSchema(parser, refType[2:], astFile) if err != nil { return nil, err } @@ -856,17 +860,17 @@ func (operation *Operation) parseObjectSchema(refType string, astFile *ast.File) return spec.MapProperty(nil), nil } - schema, err := operation.parseObjectSchema(refType, astFile) + schema, err := parseObjectSchema(parser, refType, astFile) if err != nil { return nil, err } return spec.MapProperty(schema), nil case strings.Contains(refType, "{"): - return operation.parseCombinedObjectSchema(refType, astFile) + return parseCombinedObjectSchema(parser, refType, astFile) default: - if operation.parser != nil { // checking refType has existing in 'TypeDefinitions' - schema, err := operation.parser.getTypeSchema(refType, astFile, true) + if parser != nil { // checking refType has existing in 'TypeDefinitions' + schema, err := parser.getTypeSchema(refType, astFile, true) if err != nil { return nil, err } @@ -896,13 +900,13 @@ func parseFields(s string) []string { }) } -func (operation *Operation) parseCombinedObjectSchema(refType string, astFile *ast.File) (*spec.Schema, error) { +func parseCombinedObjectSchema(parser *Parser, refType string, astFile *ast.File) (*spec.Schema, error) { matches := combinedPattern.FindStringSubmatch(refType) if len(matches) != 3 { return nil, fmt.Errorf("invalid type: %s", refType) } - schema, err := operation.parseObjectSchema(matches[1], astFile) + schema, err := parseObjectSchema(parser, matches[1], astFile) if err != nil { return nil, err } @@ -912,7 +916,7 @@ func (operation *Operation) parseCombinedObjectSchema(refType string, astFile *a for _, field := range fields { keyVal := strings.SplitN(field, "=", 2) if len(keyVal) == 2 { - schema, err := operation.parseObjectSchema(keyVal[1], astFile) + schema, err := parseObjectSchema(parser, keyVal[1], astFile) if err != nil { return nil, err } diff --git a/parser.go b/parser.go index 7af770e26..63586217a 100644 --- a/parser.go +++ b/parser.go @@ -871,22 +871,21 @@ func convertFromSpecificToPrimitive(typeName string) (string, error) { } func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) (*spec.Schema, error) { - var arrayDepth = 0 if override, ok := parser.Overrides[typeName]; ok { parser.debug.Printf("Override detected for %s: using %s instead", typeName, override) - arrayDepth, typeName = getArrayDepth(override) + return parseObjectSchema(parser, override, file) } if IsInterfaceLike(typeName) { - return transformToArray(&spec.Schema{}, arrayDepth), nil + return &spec.Schema{}, nil } if IsGolangPrimitiveType(typeName) { - return transformToArray(PrimitiveSchema(TransToValidSchemeType(typeName)), arrayDepth), nil + return PrimitiveSchema(TransToValidSchemeType(typeName)), nil } schemaType, err := convertFromSpecificToPrimitive(typeName) if err == nil { - return transformToArray(PrimitiveSchema(schemaType), arrayDepth), nil + return PrimitiveSchema(schemaType), nil } typeSpecDef := parser.packages.FindTypeSpec(typeName, file, parser.ParseDependency) @@ -902,18 +901,13 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( } parser.debug.Printf("Override detected for %s: using %s instead", typeSpecDef.FullPath(), override) - arrayDepth, override = getArrayDepth(override) separator := strings.LastIndex(override, ".") if separator == -1 { // treat as a swaggertype tag parts := strings.Split(override, ",") - s, err := BuildCustomSchema(parts) - if err != nil { - return nil, err - } - return transformToArray(s, arrayDepth), nil + return BuildCustomSchema(parts) } typeSpecDef = parser.packages.findTypeSpec(override[0:separator], override[separator+1:]) @@ -926,7 +920,7 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( schema, err = parser.ParseDefinition(typeSpecDef) if err != nil { if err == ErrRecursiveParseStruct && ref { - return transformToArray(parser.getRefTypeSchema(typeSpecDef, schema), arrayDepth), nil + return parser.getRefTypeSchema(typeSpecDef, schema), nil } return nil, err @@ -934,10 +928,10 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( } if ref && len(schema.Schema.Type) > 0 && schema.Schema.Type[0] == OBJECT { - return transformToArray(parser.getRefTypeSchema(typeSpecDef, schema), arrayDepth), nil + return parser.getRefTypeSchema(typeSpecDef, schema), nil } - return transformToArray(schema.Schema, arrayDepth), nil + return schema.Schema, nil } func (parser *Parser) renameRefSchemas() { @@ -1613,19 +1607,3 @@ func (parser *Parser) addTestType(typename string) { Schema: PrimitiveSchema(OBJECT), } } - -func getArrayDepth(typename string) (int, string) { - var d = 0 - for strings.HasPrefix(typename, "[]") { - typename = typename[2:] - d++ - } - return d, typename -} - -func transformToArray(s *spec.Schema, arrayDepth int) *spec.Schema { - for i := 0; i < arrayDepth; i++ { - s = spec.ArrayProperty(s) - } - return s -}