diff --git a/Dockerfile b/Dockerfile index 65746bdde..170d0c699 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ # Dockerfile References: https://docs.docker.com/engine/reference/builder/ # Start from the latest golang base image -FROM golang:1.17-alpine as builder +FROM golang:1.18.3-alpine as builder # Set the Current Working Directory inside the container WORKDIR /app diff --git a/generics.go b/generics.go new file mode 100644 index 000000000..7a7ad0207 --- /dev/null +++ b/generics.go @@ -0,0 +1,109 @@ +//go:build go1.18 +// +build go1.18 + +package swag + +import ( + "go/ast" + "strings" +) + +func typeSpecFullName(typeSpecDef *TypeSpecDef) string { + fullName := typeSpecDef.FullName() + + if typeSpecDef.TypeSpec.TypeParams != nil { + fullName = fullName + "[" + for i, typeParam := range typeSpecDef.TypeSpec.TypeParams.List { + if i > 0 { + fullName = fullName + "-" + } + + fullName = fullName + typeParam.Names[0].Name + } + fullName = fullName + "]" + } + + return fullName +} + +func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef { + genericParams := strings.Split(strings.TrimRight(fullGenericForm, "]"), "[") + if len(genericParams) == 1 { + return nil + } + + genericParams = strings.Split(genericParams[1], ",") + for i, p := range genericParams { + genericParams[i] = strings.TrimSpace(p) + } + genericParamTypeDefs := map[string]*TypeSpecDef{} + + if len(genericParams) != len(original.TypeSpec.TypeParams.List) { + return nil + } + + for i, genericParam := range genericParams { + tdef, ok := pkgDefs.uniqueDefinitions[genericParam] + if !ok { + return nil + } + + genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = tdef + } + + parametrizedTypeSpec := &TypeSpecDef{ + File: original.File, + PkgPath: original.PkgPath, + TypeSpec: &ast.TypeSpec{ + Doc: original.TypeSpec.Doc, + Comment: original.TypeSpec.Comment, + Assign: original.TypeSpec.Assign, + }, + } + + ident := &ast.Ident{ + NamePos: original.TypeSpec.Name.NamePos, + Obj: original.TypeSpec.Name.Obj, + } + + genNameParts := strings.Split(fullGenericForm, "[") + if strings.Contains(genNameParts[0], ".") { + genNameParts[0] = strings.Split(genNameParts[0], ".")[1] + } + + ident.Name = genNameParts[0] + "-" + strings.Replace(strings.Join(genericParams, "-"), ".", "_", -1) + ident.Name = strings.Replace(strings.Replace(ident.Name, "\t", "", -1), " ", "", -1) + + parametrizedTypeSpec.TypeSpec.Name = ident + + origStructType := original.TypeSpec.Type.(*ast.StructType) + + newStructTypeDef := &ast.StructType{ + Struct: origStructType.Struct, + Incomplete: origStructType.Incomplete, + Fields: &ast.FieldList{ + Opening: origStructType.Fields.Opening, + Closing: origStructType.Fields.Closing, + }, + } + + for _, field := range origStructType.Fields.List { + newField := &ast.Field{ + Doc: field.Doc, + Names: field.Names, + Tag: field.Tag, + Comment: field.Comment, + } + if genTypeSpec, ok := genericParamTypeDefs[field.Type.(*ast.Ident).Name]; ok { + newField.Type = genTypeSpec.TypeSpec.Type + } else { + newField.Type = field.Type + } + + newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField) + } + + parametrizedTypeSpec.TypeSpec.Type = newStructTypeDef + + return parametrizedTypeSpec +} diff --git a/generics_other.go b/generics_other.go new file mode 100644 index 000000000..a695023b1 --- /dev/null +++ b/generics_other.go @@ -0,0 +1,12 @@ +//go:build !go1.18 +// +build !go1.18 + +package swag + +func typeSpecFullName(typeSpecDef *TypeSpecDef) string { + return typeSpecDef.FullName() +} + +func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef { + return original +} diff --git a/generics_test.go b/generics_test.go new file mode 100644 index 000000000..d1dc1164e --- /dev/null +++ b/generics_test.go @@ -0,0 +1,210 @@ +//go:build go1.18 +// +build go1.18 + +package swag + +import ( + "encoding/json" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseGenericsBasic(t *testing.T) { + t.Parallel() + + expected := `{ + "swagger": "2.0", + "info": { + "description": "This is a sample server Petstore server.", + "title": "Swagger Example API", + "contact": {}, + "version": "1.0" + }, + "host": "localhost:4000", + "basePath": "/api", + "paths": { + "/posts/{post_id}": { + "get": { + "description": "get string by ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "summary": "Add a new pet to the store", + "parameters": [ + { + "type": "integer", + "format": "int64", + "description": "Some ID", + "name": "post_id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/web.GenericResponse-web_Post" + } + }, + "222": { + "description": "", + "schema": { + "$ref": "#/definitions/web.GenericResponseMulti-web_Post-web_Post" + } + }, + "400": { + "description": "We need ID!!", + "schema": { + "$ref": "#/definitions/web.APIError" + } + }, + "404": { + "description": "Can not find ID", + "schema": { + "$ref": "#/definitions/web.APIError" + } + } + } + } + } + }, + "definitions": { + "web.APIError": { + "description": "API error with information about it", + "type": "object", + "properties": { + "createdAt": { + "description": "Error time", + "type": "string" + }, + "error": { + "description": "Error an Api error", + "type": "string" + }, + "errorCtx": { + "description": "Error ` + "`context`" + ` tick comment", + "type": "string" + }, + "errorNo": { + "description": "Error ` + "`number`" + ` tick comment", + "type": "integer" + } + } + }, + "web.GenericResponse-web_Post": { + "type": "object", + "properties": { + "data": { + "type": "object", + "properties": { + "data": { + "description": "Post data", + "type": "object", + "properties": { + "name": { + "description": "Post tag", + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "id": { + "type": "integer", + "format": "int64", + "example": 1 + }, + "name": { + "description": "Post name", + "type": "string", + "example": "poti" + } + } + }, + "status": { + "type": "string" + } + } + }, + "web.GenericResponseMulti-web_Post-web_Post": { + "type": "object", + "properties": { + "data": { + "type": "object", + "properties": { + "data": { + "description": "Post data", + "type": "object", + "properties": { + "name": { + "description": "Post tag", + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "id": { + "type": "integer", + "format": "int64", + "example": 1 + }, + "name": { + "description": "Post name", + "type": "string", + "example": "poti" + } + } + }, + "meta": { + "type": "object", + "properties": { + "data": { + "description": "Post data", + "type": "object", + "properties": { + "name": { + "description": "Post tag", + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "id": { + "type": "integer", + "format": "int64", + "example": 1 + }, + "name": { + "description": "Post name", + "type": "string", + "example": "poti" + } + } + }, + "status": { + "type": "string" + } + } + } + } +}` + + searchDir := "testdata/generics_basic" + p := New() + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, _ := json.MarshalIndent(p.swagger, "", " ") + os.WriteFile("testdata/generics_basic/swagger.json", b, 0644) + assert.Equal(t, expected, string(b)) +} diff --git a/operation.go b/operation.go index e8675917d..20e7cefc5 100644 --- a/operation.go +++ b/operation.go @@ -385,7 +385,7 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F if objectType == PRIMITIVE { param.Schema = PrimitiveSchema(refType) } else { - schema, err := operation.parseAPIObjectSchema(objectType, refType, astFile) + schema, err := operation.parseAPIObjectSchema(commentLine, objectType, refType, astFile) if err != nil { return err } @@ -933,7 +933,16 @@ func (operation *Operation) parseCombinedObjectSchema(refType string, astFile *a }), nil } -func (operation *Operation) parseAPIObjectSchema(schemaType, refType string, astFile *ast.File) (*spec.Schema, error) { +func (operation *Operation) parseAPIObjectSchema(commentLine, schemaType, refType string, astFile *ast.File) (*spec.Schema, error) { + if strings.HasSuffix(refType, ",") && strings.Contains(refType, "[") { + // regexp may have broken generics syntax. find closing bracket and add it back + allMatchesLenOffset := strings.Index(commentLine, refType) + len(refType) + lostPartEndIdx := strings.Index(commentLine[allMatchesLenOffset:], "]") + if lostPartEndIdx >= 0 { + refType += commentLine[allMatchesLenOffset : allMatchesLenOffset+lostPartEndIdx+1] + } + } + switch schemaType { case OBJECT: if !strings.HasPrefix(refType, "[]") { @@ -969,7 +978,7 @@ func (operation *Operation) ParseResponseComment(commentLine string, astFile *as description := strings.Trim(matches[4], "\"") - schema, err := operation.parseAPIObjectSchema(strings.Trim(matches[2], "{}"), matches[3], astFile) + schema, err := operation.parseAPIObjectSchema(commentLine, strings.Trim(matches[2], "{}"), matches[3], astFile) if err != nil { return err } diff --git a/packages.go b/packages.go index fc7a6bd2b..445d09576 100644 --- a/packages.go +++ b/packages.go @@ -134,7 +134,8 @@ func (pkgDefs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packag pkgDefs.uniqueDefinitions = make(map[string]*TypeSpecDef) } - fullName := typeSpecDef.FullName() + fullName := typeSpecFullName(typeSpecDef) + anotherTypeDef, ok := pkgDefs.uniqueDefinitions[fullName] if ok { if typeSpecDef.PkgPath == anotherTypeDef.PkgPath { @@ -292,7 +293,7 @@ func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File return pkgDefs.uniqueDefinitions[typeName] } - parts := strings.Split(typeName, ".") + parts := strings.Split(strings.Split(typeName, "[")[0], ".") if len(parts) > 1 { isAliasPkgName := func(file *ast.File, pkgName string) bool { if file != nil && file.Imports != nil { @@ -328,6 +329,22 @@ func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File } } + if strings.Contains(typeName, "[") { + // joinedParts differs from typeName in that it does not contain any type parameters + joinedParts := strings.Join(parts, ".") + for tName, tSpec := range pkgDefs.uniqueDefinitions { + if !strings.Contains(tName, "[") { + continue + } + + if strings.Contains(tName, joinedParts) { + if parametrized := pkgDefs.parametrizeStruct(tSpec, typeName); parametrized != nil { + return parametrized + } + } + } + } + return pkgDefs.findTypeSpec(pkgPath, parts[1]) } diff --git a/testdata/generics_basic/api/api.go b/testdata/generics_basic/api/api.go new file mode 100644 index 000000000..3881f2230 --- /dev/null +++ b/testdata/generics_basic/api/api.go @@ -0,0 +1,22 @@ +package api + +import ( + "net/http" + + "github.com/swaggo/swag/testdata/generics_basic/web" +) + +// @Summary Add a new pet to the store +// @Description get string by ID +// @Accept json +// @Produce json +// @Param post_id path int true "Some ID" Format(int64) +// @Success 200 {object} web.GenericResponse[web.Post] +// @Success 222 {object} web.GenericResponseMulti[web.Post, web.Post] +// @Failure 400 {object} web.APIError "We need ID!!" +// @Failure 404 {object} web.APIError "Can not find ID" +// @Router /posts/{post_id} [get] +func GetPost(w http.ResponseWriter, r *http.Request) { + //write your code + _ = web.GenericResponse[web.Post]{} +} diff --git a/testdata/generics_basic/main.go b/testdata/generics_basic/main.go new file mode 100644 index 000000000..cff47d013 --- /dev/null +++ b/testdata/generics_basic/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "net/http" + + "github.com/swaggo/swag/testdata/generics_basic/api" +) + +// @title Swagger Example API +// @version 1.0 +// @description This is a sample server Petstore server. +// @host localhost:4000 +// @basePath /api +func main() { + http.HandleFunc("/posts/", api.GetPost) + http.ListenAndServe(":8080", nil) +} diff --git a/testdata/generics_basic/web/handler.go b/testdata/generics_basic/web/handler.go new file mode 100644 index 000000000..011910708 --- /dev/null +++ b/testdata/generics_basic/web/handler.go @@ -0,0 +1,42 @@ +package web + +import ( + "time" +) + +type GenericResponse[T any] struct { + Data T + + Status string +} + +type GenericResponseMulti[T any, X any] struct { + Data T + Meta X + + Status string +} + +type Post struct { + ID int `json:"id" example:"1" format:"int64"` + // Post name + Name string `json:"name" example:"poti"` + // Post data + Data struct { + // Post tag + Tag []string `json:"name"` + } `json:"data"` +} + +// APIError +// @Description API error +// @Description with information about it +// Other some summary +type APIError struct { + // Error an Api error + Error string // Error this is Line comment + // Error `number` tick comment + ErrorNo int64 + ErrorCtx string // Error `context` tick comment + CreatedAt time.Time // Error time +}