From 64c435106afbc4d31f4c6fade748d6becd61cfe1 Mon Sep 17 00:00:00 2001 From: Fabian Martin Date: Mon, 12 Sep 2022 20:29:01 +0200 Subject: [PATCH] refactor: use existing code to parse new type definition - use methods from Operation to get schema - methods parseObjectSchema and parseCombinedObjectSchema disconnected from struct --- operation.go | 22 +++++++++++++++------- parser.go | 38 ++++++++------------------------------ 2 files changed, 23 insertions(+), 37 deletions(-) diff --git a/operation.go b/operation.go index 71c4e47ca..6c99903c9 100644 --- a/operation.go +++ b/operation.go @@ -824,6 +824,10 @@ var responsePattern = regexp.MustCompile(`^([\w,]+)\s+([\w{}]+)\s+([\w\-.\\{}=,\ var combinedPattern = regexp.MustCompile(`^([\w\-./\[\]]+){(.*)}$`) func (operation *Operation) parseObjectSchema(refType string, astFile *ast.File) (*spec.Schema, error) { + return parseObjectSchema(operation.parser, refType, astFile) +} + +func parseObjectSchema(parser *Parser, refType string, astFile *ast.File) (*spec.Schema, error) { switch { case refType == NIL: return nil, nil @@ -838,7 +842,7 @@ func (operation *Operation) parseObjectSchema(refType string, astFile *ast.File) case IsPrimitiveType(refType): return PrimitiveSchema(refType), nil case strings.HasPrefix(refType, "[]"): - schema, err := operation.parseObjectSchema(refType[2:], astFile) + schema, err := parseObjectSchema(parser, refType[2:], astFile) if err != nil { return nil, err } @@ -856,17 +860,17 @@ func (operation *Operation) parseObjectSchema(refType string, astFile *ast.File) return spec.MapProperty(nil), nil } - schema, err := operation.parseObjectSchema(refType, astFile) + schema, err := parseObjectSchema(parser, refType, astFile) if err != nil { return nil, err } return spec.MapProperty(schema), nil case strings.Contains(refType, "{"): - return operation.parseCombinedObjectSchema(refType, astFile) + return parseCombinedObjectSchema(parser, refType, astFile) default: - if operation.parser != nil { // checking refType has existing in 'TypeDefinitions' - schema, err := operation.parser.getTypeSchema(refType, astFile, true) + if parser != nil { // checking refType has existing in 'TypeDefinitions' + schema, err := parser.getTypeSchema(refType, astFile, true) if err != nil { return nil, err } @@ -897,12 +901,16 @@ func parseFields(s string) []string { } func (operation *Operation) parseCombinedObjectSchema(refType string, astFile *ast.File) (*spec.Schema, error) { + return parseCombinedObjectSchema(operation.parser, refType, astFile) +} + +func parseCombinedObjectSchema(parser *Parser, refType string, astFile *ast.File) (*spec.Schema, error) { matches := combinedPattern.FindStringSubmatch(refType) if len(matches) != 3 { return nil, fmt.Errorf("invalid type: %s", refType) } - schema, err := operation.parseObjectSchema(matches[1], astFile) + schema, err := parseObjectSchema(parser, matches[1], astFile) if err != nil { return nil, err } @@ -912,7 +920,7 @@ func (operation *Operation) parseCombinedObjectSchema(refType string, astFile *a for _, field := range fields { keyVal := strings.SplitN(field, "=", 2) if len(keyVal) == 2 { - schema, err := operation.parseObjectSchema(keyVal[1], astFile) + schema, err := parseObjectSchema(parser, keyVal[1], astFile) if err != nil { return nil, err } diff --git a/parser.go b/parser.go index 7af770e26..63586217a 100644 --- a/parser.go +++ b/parser.go @@ -871,22 +871,21 @@ func convertFromSpecificToPrimitive(typeName string) (string, error) { } func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) (*spec.Schema, error) { - var arrayDepth = 0 if override, ok := parser.Overrides[typeName]; ok { parser.debug.Printf("Override detected for %s: using %s instead", typeName, override) - arrayDepth, typeName = getArrayDepth(override) + return parseObjectSchema(parser, override, file) } if IsInterfaceLike(typeName) { - return transformToArray(&spec.Schema{}, arrayDepth), nil + return &spec.Schema{}, nil } if IsGolangPrimitiveType(typeName) { - return transformToArray(PrimitiveSchema(TransToValidSchemeType(typeName)), arrayDepth), nil + return PrimitiveSchema(TransToValidSchemeType(typeName)), nil } schemaType, err := convertFromSpecificToPrimitive(typeName) if err == nil { - return transformToArray(PrimitiveSchema(schemaType), arrayDepth), nil + return PrimitiveSchema(schemaType), nil } typeSpecDef := parser.packages.FindTypeSpec(typeName, file, parser.ParseDependency) @@ -902,18 +901,13 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( } parser.debug.Printf("Override detected for %s: using %s instead", typeSpecDef.FullPath(), override) - arrayDepth, override = getArrayDepth(override) separator := strings.LastIndex(override, ".") if separator == -1 { // treat as a swaggertype tag parts := strings.Split(override, ",") - s, err := BuildCustomSchema(parts) - if err != nil { - return nil, err - } - return transformToArray(s, arrayDepth), nil + return BuildCustomSchema(parts) } typeSpecDef = parser.packages.findTypeSpec(override[0:separator], override[separator+1:]) @@ -926,7 +920,7 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( schema, err = parser.ParseDefinition(typeSpecDef) if err != nil { if err == ErrRecursiveParseStruct && ref { - return transformToArray(parser.getRefTypeSchema(typeSpecDef, schema), arrayDepth), nil + return parser.getRefTypeSchema(typeSpecDef, schema), nil } return nil, err @@ -934,10 +928,10 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( } if ref && len(schema.Schema.Type) > 0 && schema.Schema.Type[0] == OBJECT { - return transformToArray(parser.getRefTypeSchema(typeSpecDef, schema), arrayDepth), nil + return parser.getRefTypeSchema(typeSpecDef, schema), nil } - return transformToArray(schema.Schema, arrayDepth), nil + return schema.Schema, nil } func (parser *Parser) renameRefSchemas() { @@ -1613,19 +1607,3 @@ func (parser *Parser) addTestType(typename string) { Schema: PrimitiveSchema(OBJECT), } } - -func getArrayDepth(typename string) (int, string) { - var d = 0 - for strings.HasPrefix(typename, "[]") { - typename = typename[2:] - d++ - } - return d, typename -} - -func transformToArray(s *spec.Schema, arrayDepth int) *spec.Schema { - for i := 0; i < arrayDepth; i++ { - s = spec.ArrayProperty(s) - } - return s -}