Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement new field: x-go-type-import #633

Merged
merged 15 commits into from Jun 29, 2022
54 changes: 53 additions & 1 deletion README.md
Expand Up @@ -533,7 +533,59 @@ which help you to use the various OpenAPI 3 Authentication mechanism.
```
Name string `json:"name" tag1:"value1" tag2:"value2"`
```

- `x-go-type-import`: adds extra Go imports to your generated code. It can help you, when you want to
choose your own import package for `x-go-type`.

```yaml
schemas:
Pet:
properties:
age:
x-go-type: myuuid.UUID
x-go-type-import:
name: myuuid
path: github.com/google/uuid
```
After code generation you will get this:
```go
import (
...
myuuid "github.com/google/uuid"
)

//Pet defines model for Pet.
type Pet struct {
Age *myuuid.UUID `json:"age,omitempty"`
}

```
`name` is an optional parameter. Example:

```yaml
components:
schemas:
Pet:
properties:
age:
x-go-type: uuid.UUID
x-go-type-import:
path: github.com/google/uuid
required:
- age
```

After code generation you will get this result:

```go
import (
"github.com/google/uuid"
)

// Pet defines model for Pet.
type Pet struct {
Age uuid.UUID `json:"age"`
}
```


## Using `oapi-codegen`
Expand Down
101 changes: 100 additions & 1 deletion pkg/codegen/codegen.go
Expand Up @@ -139,6 +139,7 @@ func Generate(spec *openapi3.T, opts Configuration) (string, error) {
}

var typeDefinitions, constantDefinitions string
var xGoTypeImports map[string]goImport
if opts.Generate.Models {
typeDefinitions, err = GenerateTypeDefinitions(t, spec, ops, opts.OutputOptions.ExcludeSchemas)
if err != nil {
Expand All @@ -150,6 +151,10 @@ func Generate(spec *openapi3.T, opts Configuration) (string, error) {
return "", fmt.Errorf("error generating constants: %w", err)
}

xGoTypeImports, err = GetTypeDefinitionsImports(spec, opts.OutputOptions.ExcludeSchemas)
if err != nil {
return "", fmt.Errorf("error getting type definition imports: %w", err)
}
}

var echoServerOut string
Expand Down Expand Up @@ -211,7 +216,7 @@ func Generate(spec *openapi3.T, opts Configuration) (string, error) {
var buf bytes.Buffer
w := bufio.NewWriter(&buf)

externalImports := importMapping.GoImports()
externalImports := append(importMapping.GoImports(), importMap(xGoTypeImports).GoImports()...)
importsOut, err := GenerateImports(t, externalImports, opts.PackageName)
if err != nil {
return "", fmt.Errorf("error generating imports: %w", err)
Expand Down Expand Up @@ -763,3 +768,97 @@ func LoadTemplates(src embed.FS, t *template.Template) error {
return nil
})
}

func GetTypeDefinitionsImports(swagger *openapi3.T, excludeSchemas []string) (map[string]goImport, error) {
res := map[string]goImport{}
schemaImports, err := GetSchemaImports(swagger.Components.Schemas, excludeSchemas)
if err != nil {
return nil, err
}

reqBodiesImports, err := GetRequestBodiesImports(swagger.Components.RequestBodies)
if err != nil {
return nil, err
}

responsesImports, err := GetResponsesImports(swagger.Components.Responses)
if err != nil {
return nil, err
}

for _, imprts := range []map[string]goImport{schemaImports, reqBodiesImports, responsesImports} {
for k, v := range imprts {
res[k] = v
}
}
return res, nil
}

func GetSchemaImports(schemas map[string]*openapi3.SchemaRef, excludeSchemas []string) (map[string]goImport, error) {
var err error
res := map[string]goImport{}
excludeSchemasMap := make(map[string]bool)
for _, schema := range excludeSchemas {
excludeSchemasMap[schema] = true
}
for _, schemaName := range SortedSchemaKeys(schemas) {
if _, ok := excludeSchemasMap[schemaName]; ok {
continue
}
schema := schemas[schemaName].Value

if schema == nil || schema.Properties == nil {
continue
}

res, err = GetImports(schema.Properties)
if err != nil {
return nil, err
}
}
return res, nil
}

func GetRequestBodiesImports(bodies map[string]*openapi3.RequestBodyRef) (map[string]goImport, error) {
var res map[string]goImport
var err error
for _, requestBodyName := range SortedRequestBodyKeys(bodies) {
requestBodyRef := bodies[requestBodyName]
response := requestBodyRef.Value
jsonBody, found := response.Content["application/json"]
if found {
schema := jsonBody.Schema
if schema == nil || schema.Value == nil || schema.Value.Properties == nil {
continue
}

res, err = GetImports(schema.Value.Properties)
if err != nil {
return nil, err
}
}
}
return res, nil
}

func GetResponsesImports(responses map[string]*openapi3.ResponseRef) (map[string]goImport, error) {
var res map[string]goImport
var err error
for _, responseName := range SortedResponsesKeys(responses) {
responseOrRef := responses[responseName]
response := responseOrRef.Value
jsonResponse, found := response.Content["application/json"]
if found {
schema := jsonResponse.Schema
if schema == nil || schema.Value == nil || schema.Value.Properties == nil {
continue
}

res, err = GetImports(schema.Value.Properties)
if err != nil {
return nil, err
}
}
}
return res, nil
}
54 changes: 47 additions & 7 deletions pkg/codegen/codegen_test.go
Expand Up @@ -8,6 +8,10 @@ import (
"net/http"
"testing"

"github.com/stretchr/testify/require"

"github.com/deepmap/oapi-codegen/pkg/util"

"github.com/getkin/kin-openapi/openapi3"
"github.com/golangci/lint-1"
"github.com/stretchr/testify/assert"
Expand All @@ -16,6 +20,13 @@ import (
examplePetstore "github.com/deepmap/oapi-codegen/examples/petstore-expanded/echo/api"
)

func checkLint(t *testing.T, filename string, code []byte) {
linter := new(lint.Linter)
problems, err := linter.Lint("test.gen.go", code)
assert.NoError(t, err)
assert.Len(t, problems, 0)
}

func TestExamplePetStoreCodeGeneration(t *testing.T) {

// Input vars for code generation:
Expand Down Expand Up @@ -58,10 +69,7 @@ func TestExamplePetStoreCodeGeneration(t *testing.T) {
`)

// Make sure the generated code is valid:
linter := new(lint.Linter)
problems, err := linter.Lint("test.gen.go", []byte(code))
assert.NoError(t, err)
assert.Len(t, problems, 0)
checkLint(t, "test.gen.go", []byte(code))
}

func TestExamplePetStoreCodeGenerationWithUserTemplates(t *testing.T) {
Expand Down Expand Up @@ -179,10 +187,42 @@ type GetTestByNameResponse struct {
assert.Contains(t, code, "DeadSince *time.Time `json:\"dead_since,omitempty\" tag1:\"value1\" tag2:\"value2\"`")

// Make sure the generated code is valid:
linter := new(lint.Linter)
problems, err := linter.Lint("test.gen.go", []byte(code))
checkLint(t, "test.gen.go", []byte(code))
}

func TestXGoTypeImport(t *testing.T) {
packageName := "api"
opts := Configuration{
PackageName: packageName,
Generate: GenerateOptions{
Models: true,
},
}
spec := "test_specs/x-go-type-import-pet.yaml"
swagger, err := util.LoadSwagger(spec)
require.NoError(t, err)

// Run our code generation:
code, err := Generate(swagger, opts)
assert.NoError(t, err)
assert.Len(t, problems, 0)
assert.NotEmpty(t, code)

// Check that we have valid (formattable) code:
_, err = format.Source([]byte(code))
assert.NoError(t, err)

// Check that we have a package:
assert.Contains(t, code, "package api")

// Check import
assert.Contains(t, code, `myuuid "github.com/google/uuid"`)

// Check generated struct
assert.Contains(t, code, "type Pet struct {\n\tAge myuuid.UUID `json:\"age\"`\n}")

// Make sure the generated code is valid:
checkLint(t, "test.gen.go", []byte(code))

}

//go:embed test_spec.yaml
Expand Down
47 changes: 47 additions & 0 deletions pkg/codegen/test_specs/x-go-type-import-pet.yaml
@@ -0,0 +1,47 @@
openapi: "3.0.0"
info:
version: 1.0.0
title: Swagger Petstore
description: A sample API that uses a petstore as an example to demonstrate features in the OpenAPI 3.0 specification
termsOfService: http://swagger.io/terms/
contact:
name: Swagger API Team
email: apiteam@swagger.io
url: http://swagger.io
license:
name: Apache 2.0
url: https://www.apache.org/licenses/LICENSE-2.0.html
servers:
- url: http://petstore.swagger.io/api
paths:
/pets/{id}:
get:
summary: Returns a pet by ID
description: Returns a pet based on a single ID
operationId: findPetByID
parameters:
- name: id
in: path
description: ID of pet to fetch
required: true
schema:
type: integer
format: int64
responses:
'200':
description: pet response
content:
application/json:
schema:
$ref: '#/components/schemas/Pet'
components:
schemas:
Pet:
properties:
age:
x-go-type: myuuid.UUID
x-go-type-import:
path: github.com/google/uuid
name: myuuid
required:
- age
26 changes: 26 additions & 0 deletions pkg/codegen/utils.go
Expand Up @@ -14,6 +14,7 @@
package codegen

import (
"encoding/json"
"fmt"
"net/url"
"regexp"
Expand Down Expand Up @@ -790,3 +791,28 @@ func findSchemaNameByRefPath(refPath string, spec *openapi3.T) (string, error) {
}
return "", nil
}

func GetImports(dict map[string]*openapi3.SchemaRef) (map[string]goImport, error) {
res := map[string]goImport{}
for _, v := range dict {
if v == nil || v.Value == nil {
continue
}

if v.Value.Extensions["x-go-type-import"] == nil || v.Value.Extensions["x-go-type"] == nil {
continue
}
goTypeImportExt := v.Value.Extensions["x-go-type-import"]

if raw, ok := goTypeImportExt.(json.RawMessage); ok {
gi := goImport{}
if err := json.Unmarshal(raw, &gi); err != nil {
return nil, err
}
res[gi.String()] = gi
} else {
continue
}
}
return res, nil
}