Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add basic generics support #1225

Merged
merged 3 commits into from Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,7 +1,7 @@
# Dockerfile References: https://docs.docker.com/engine/reference/builder/

# Start from the latest golang base image
FROM golang:1.17-alpine as builder
FROM golang:1.18.3-alpine as builder

# Set the Current Working Directory inside the container
WORKDIR /app
Expand Down
109 changes: 109 additions & 0 deletions generics.go
@@ -0,0 +1,109 @@
//go:build go1.18
// +build go1.18

package swag

import (
"go/ast"
"strings"
)

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
fullName := typeSpecDef.FullName()

if typeSpecDef.TypeSpec.TypeParams != nil {
fullName = fullName + "["
for i, typeParam := range typeSpecDef.TypeSpec.TypeParams.List {
if i > 0 {
fullName = fullName + "-"
}

fullName = fullName + typeParam.Names[0].Name
}
fullName = fullName + "]"
}

return fullName
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef {
genericParams := strings.Split(strings.TrimRight(fullGenericForm, "]"), "[")
if len(genericParams) == 1 {
return nil
}

genericParams = strings.Split(genericParams[1], ",")
for i, p := range genericParams {
genericParams[i] = strings.TrimSpace(p)
}
genericParamTypeDefs := map[string]*TypeSpecDef{}

if len(genericParams) != len(original.TypeSpec.TypeParams.List) {
return nil
}

for i, genericParam := range genericParams {
tdef, ok := pkgDefs.uniqueDefinitions[genericParam]
if !ok {
return nil
}

genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = tdef
}

parametrizedTypeSpec := &TypeSpecDef{
File: original.File,
PkgPath: original.PkgPath,
TypeSpec: &ast.TypeSpec{
Doc: original.TypeSpec.Doc,
Comment: original.TypeSpec.Comment,
Assign: original.TypeSpec.Assign,
},
}

ident := &ast.Ident{
NamePos: original.TypeSpec.Name.NamePos,
Obj: original.TypeSpec.Name.Obj,
}

genNameParts := strings.Split(fullGenericForm, "[")
if strings.Contains(genNameParts[0], ".") {
genNameParts[0] = strings.Split(genNameParts[0], ".")[1]
}

ident.Name = genNameParts[0] + "-" + strings.Replace(strings.Join(genericParams, "-"), ".", "_", -1)
ident.Name = strings.Replace(strings.Replace(ident.Name, "\t", "", -1), " ", "", -1)

parametrizedTypeSpec.TypeSpec.Name = ident

origStructType := original.TypeSpec.Type.(*ast.StructType)

newStructTypeDef := &ast.StructType{
Struct: origStructType.Struct,
Incomplete: origStructType.Incomplete,
Fields: &ast.FieldList{
Opening: origStructType.Fields.Opening,
Closing: origStructType.Fields.Closing,
},
}

for _, field := range origStructType.Fields.List {
newField := &ast.Field{
Doc: field.Doc,
Names: field.Names,
Tag: field.Tag,
Comment: field.Comment,
}
if genTypeSpec, ok := genericParamTypeDefs[field.Type.(*ast.Ident).Name]; ok {
newField.Type = genTypeSpec.TypeSpec.Type
} else {
newField.Type = field.Type
}

newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField)
}

parametrizedTypeSpec.TypeSpec.Type = newStructTypeDef

return parametrizedTypeSpec
}
12 changes: 12 additions & 0 deletions generics_other.go
@@ -0,0 +1,12 @@
//go:build !go1.18
// +build !go1.18

package swag

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef {
return original
}
210 changes: 210 additions & 0 deletions generics_test.go
@@ -0,0 +1,210 @@
//go:build go1.18
// +build go1.18

package swag

import (
"encoding/json"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func TestParseGenericsBasic(t *testing.T) {
t.Parallel()

expected := `{
"swagger": "2.0",
"info": {
"description": "This is a sample server Petstore server.",
"title": "Swagger Example API",
"contact": {},
"version": "1.0"
},
"host": "localhost:4000",
"basePath": "/api",
"paths": {
"/posts/{post_id}": {
"get": {
"description": "get string by ID",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"summary": "Add a new pet to the store",
"parameters": [
{
"type": "integer",
"format": "int64",
"description": "Some ID",
"name": "post_id",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/web.GenericResponse-web_Post"
}
},
"222": {
"description": "",
"schema": {
"$ref": "#/definitions/web.GenericResponseMulti-web_Post-web_Post"
}
},
"400": {
"description": "We need ID!!",
"schema": {
"$ref": "#/definitions/web.APIError"
}
},
"404": {
"description": "Can not find ID",
"schema": {
"$ref": "#/definitions/web.APIError"
}
}
}
}
}
},
"definitions": {
"web.APIError": {
"description": "API error with information about it",
"type": "object",
"properties": {
"createdAt": {
"description": "Error time",
"type": "string"
},
"error": {
"description": "Error an Api error",
"type": "string"
},
"errorCtx": {
"description": "Error ` + "`context`" + ` tick comment",
"type": "string"
},
"errorNo": {
"description": "Error ` + "`number`" + ` tick comment",
"type": "integer"
}
}
},
"web.GenericResponse-web_Post": {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"data": {
"description": "Post data",
"type": "object",
"properties": {
"name": {
"description": "Post tag",
"type": "array",
"items": {
"type": "string"
}
}
}
},
"id": {
"type": "integer",
"format": "int64",
"example": 1
},
"name": {
"description": "Post name",
"type": "string",
"example": "poti"
}
}
},
"status": {
"type": "string"
}
}
},
"web.GenericResponseMulti-web_Post-web_Post": {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"data": {
"description": "Post data",
"type": "object",
"properties": {
"name": {
"description": "Post tag",
"type": "array",
"items": {
"type": "string"
}
}
}
},
"id": {
"type": "integer",
"format": "int64",
"example": 1
},
"name": {
"description": "Post name",
"type": "string",
"example": "poti"
}
}
},
"meta": {
"type": "object",
"properties": {
"data": {
"description": "Post data",
"type": "object",
"properties": {
"name": {
"description": "Post tag",
"type": "array",
"items": {
"type": "string"
}
}
}
},
"id": {
"type": "integer",
"format": "int64",
"example": 1
},
"name": {
"description": "Post name",
"type": "string",
"example": "poti"
}
}
},
"status": {
"type": "string"
}
}
}
}
}`

searchDir := "testdata/generics_basic"
p := New()
err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth)
assert.NoError(t, err)
b, _ := json.MarshalIndent(p.swagger, "", " ")
os.WriteFile("testdata/generics_basic/swagger.json", b, 0644)
assert.Equal(t, expected, string(b))
}
15 changes: 12 additions & 3 deletions operation.go
Expand Up @@ -385,7 +385,7 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
if objectType == PRIMITIVE {
param.Schema = PrimitiveSchema(refType)
} else {
schema, err := operation.parseAPIObjectSchema(objectType, refType, astFile)
schema, err := operation.parseAPIObjectSchema(commentLine, objectType, refType, astFile)
if err != nil {
return err
}
Expand Down Expand Up @@ -933,7 +933,16 @@ func (operation *Operation) parseCombinedObjectSchema(refType string, astFile *a
}), nil
}

func (operation *Operation) parseAPIObjectSchema(schemaType, refType string, astFile *ast.File) (*spec.Schema, error) {
func (operation *Operation) parseAPIObjectSchema(commentLine, schemaType, refType string, astFile *ast.File) (*spec.Schema, error) {
if strings.HasSuffix(refType, ",") && strings.Contains(refType, "[") {
// regexp may have broken generics syntax. find closing bracket and add it back
allMatchesLenOffset := strings.Index(commentLine, refType) + len(refType)
lostPartEndIdx := strings.Index(commentLine[allMatchesLenOffset:], "]")
if lostPartEndIdx >= 0 {
refType += commentLine[allMatchesLenOffset : allMatchesLenOffset+lostPartEndIdx+1]
}
}

switch schemaType {
case OBJECT:
if !strings.HasPrefix(refType, "[]") {
Expand Down Expand Up @@ -969,7 +978,7 @@ func (operation *Operation) ParseResponseComment(commentLine string, astFile *as

description := strings.Trim(matches[4], "\"")

schema, err := operation.parseAPIObjectSchema(strings.Trim(matches[2], "{}"), matches[3], astFile)
schema, err := operation.parseAPIObjectSchema(commentLine, strings.Trim(matches[2], "{}"), matches[3], astFile)
if err != nil {
return err
}
Expand Down