diff --git a/README.md b/README.md index 07224f8c5..58fe1c93c 100644 --- a/README.md +++ b/README.md @@ -574,7 +574,59 @@ which help you to use the various OpenAPI 3 Authentication mechanism. ``` Name string `json:"name" tag1:"value1" tag2:"value2"` ``` - +- `x-go-type-import`: adds extra Go imports to your generated code. It can help you, when you want to + choose your own import package for `x-go-type`. + + ```yaml + schemas: + Pet: + properties: + age: + x-go-type: myuuid.UUID + x-go-type-import: + name: myuuid + path: github.com/google/uuid + ``` + After code generation you will get this: + ```go + import ( + ... + myuuid "github.com/google/uuid" + ) + + //Pet defines model for Pet. + type Pet struct { + Age *myuuid.UUID `json:"age,omitempty"` + } + + ``` + `name` is an optional parameter. Example: + + ```yaml + components: + schemas: + Pet: + properties: + age: + x-go-type: uuid.UUID + x-go-type-import: + path: github.com/google/uuid + required: + - age + ``` + + After code generation you will get this result: + + ```go + import ( + "github.com/google/uuid" + ) + + // Pet defines model for Pet. + type Pet struct { + Age uuid.UUID `json:"age"` + } + ``` ## Using `oapi-codegen` diff --git a/pkg/codegen/codegen.go b/pkg/codegen/codegen.go index e87a2e1cc..6d9f345b3 100644 --- a/pkg/codegen/codegen.go +++ b/pkg/codegen/codegen.go @@ -139,6 +139,7 @@ func Generate(spec *openapi3.T, opts Configuration) (string, error) { } var typeDefinitions, constantDefinitions string + var xGoTypeImports map[string]goImport if opts.Generate.Models { typeDefinitions, err = GenerateTypeDefinitions(t, spec, ops, opts.OutputOptions.ExcludeSchemas) if err != nil { @@ -150,6 +151,10 @@ func Generate(spec *openapi3.T, opts Configuration) (string, error) { return "", fmt.Errorf("error generating constants: %w", err) } + xGoTypeImports, err = GetTypeDefinitionsImports(spec, opts.OutputOptions.ExcludeSchemas) + if err != nil { + return "", fmt.Errorf("error getting type definition imports: %w", err) + } } var echoServerOut string @@ -228,7 +233,7 @@ func Generate(spec *openapi3.T, opts Configuration) (string, error) { var buf bytes.Buffer w := bufio.NewWriter(&buf) - externalImports := importMapping.GoImports() + externalImports := append(importMapping.GoImports(), importMap(xGoTypeImports).GoImports()...) importsOut, err := GenerateImports(t, externalImports, opts.PackageName) if err != nil { return "", fmt.Errorf("error generating imports: %w", err) @@ -787,3 +792,97 @@ func LoadTemplates(src embed.FS, t *template.Template) error { return nil }) } + +func GetTypeDefinitionsImports(swagger *openapi3.T, excludeSchemas []string) (map[string]goImport, error) { + res := map[string]goImport{} + schemaImports, err := GetSchemaImports(swagger.Components.Schemas, excludeSchemas) + if err != nil { + return nil, err + } + + reqBodiesImports, err := GetRequestBodiesImports(swagger.Components.RequestBodies) + if err != nil { + return nil, err + } + + responsesImports, err := GetResponsesImports(swagger.Components.Responses) + if err != nil { + return nil, err + } + + for _, imprts := range []map[string]goImport{schemaImports, reqBodiesImports, responsesImports} { + for k, v := range imprts { + res[k] = v + } + } + return res, nil +} + +func GetSchemaImports(schemas map[string]*openapi3.SchemaRef, excludeSchemas []string) (map[string]goImport, error) { + var err error + res := map[string]goImport{} + excludeSchemasMap := make(map[string]bool) + for _, schema := range excludeSchemas { + excludeSchemasMap[schema] = true + } + for _, schemaName := range SortedSchemaKeys(schemas) { + if _, ok := excludeSchemasMap[schemaName]; ok { + continue + } + schema := schemas[schemaName].Value + + if schema == nil || schema.Properties == nil { + continue + } + + res, err = GetImports(schema.Properties) + if err != nil { + return nil, err + } + } + return res, nil +} + +func GetRequestBodiesImports(bodies map[string]*openapi3.RequestBodyRef) (map[string]goImport, error) { + var res map[string]goImport + var err error + for _, requestBodyName := range SortedRequestBodyKeys(bodies) { + requestBodyRef := bodies[requestBodyName] + response := requestBodyRef.Value + jsonBody, found := response.Content["application/json"] + if found { + schema := jsonBody.Schema + if schema == nil || schema.Value == nil || schema.Value.Properties == nil { + continue + } + + res, err = GetImports(schema.Value.Properties) + if err != nil { + return nil, err + } + } + } + return res, nil +} + +func GetResponsesImports(responses map[string]*openapi3.ResponseRef) (map[string]goImport, error) { + var res map[string]goImport + var err error + for _, responseName := range SortedResponsesKeys(responses) { + responseOrRef := responses[responseName] + response := responseOrRef.Value + jsonResponse, found := response.Content["application/json"] + if found { + schema := jsonResponse.Schema + if schema == nil || schema.Value == nil || schema.Value.Properties == nil { + continue + } + + res, err = GetImports(schema.Value.Properties) + if err != nil { + return nil, err + } + } + } + return res, nil +} diff --git a/pkg/codegen/codegen_test.go b/pkg/codegen/codegen_test.go index 8d333a6e4..d5840b354 100644 --- a/pkg/codegen/codegen_test.go +++ b/pkg/codegen/codegen_test.go @@ -8,6 +8,10 @@ import ( "net/http" "testing" + "github.com/stretchr/testify/require" + + "github.com/deepmap/oapi-codegen/pkg/util" + "github.com/getkin/kin-openapi/openapi3" "github.com/golangci/lint-1" "github.com/stretchr/testify/assert" @@ -16,6 +20,13 @@ import ( examplePetstore "github.com/deepmap/oapi-codegen/examples/petstore-expanded/echo/api" ) +func checkLint(t *testing.T, filename string, code []byte) { + linter := new(lint.Linter) + problems, err := linter.Lint("test.gen.go", code) + assert.NoError(t, err) + assert.Len(t, problems, 0) +} + func TestExamplePetStoreCodeGeneration(t *testing.T) { // Input vars for code generation: @@ -58,10 +69,7 @@ func TestExamplePetStoreCodeGeneration(t *testing.T) { `) // Make sure the generated code is valid: - linter := new(lint.Linter) - problems, err := linter.Lint("test.gen.go", []byte(code)) - assert.NoError(t, err) - assert.Len(t, problems, 0) + checkLint(t, "test.gen.go", []byte(code)) } func TestExamplePetStoreCodeGenerationWithUserTemplates(t *testing.T) { @@ -179,10 +187,42 @@ type GetTestByNameResponse struct { assert.Contains(t, code, "DeadSince *time.Time `json:\"dead_since,omitempty\" tag1:\"value1\" tag2:\"value2\"`") // Make sure the generated code is valid: - linter := new(lint.Linter) - problems, err := linter.Lint("test.gen.go", []byte(code)) + checkLint(t, "test.gen.go", []byte(code)) +} + +func TestXGoTypeImport(t *testing.T) { + packageName := "api" + opts := Configuration{ + PackageName: packageName, + Generate: GenerateOptions{ + Models: true, + }, + } + spec := "test_specs/x-go-type-import-pet.yaml" + swagger, err := util.LoadSwagger(spec) + require.NoError(t, err) + + // Run our code generation: + code, err := Generate(swagger, opts) assert.NoError(t, err) - assert.Len(t, problems, 0) + assert.NotEmpty(t, code) + + // Check that we have valid (formattable) code: + _, err = format.Source([]byte(code)) + assert.NoError(t, err) + + // Check that we have a package: + assert.Contains(t, code, "package api") + + // Check import + assert.Contains(t, code, `myuuid "github.com/google/uuid"`) + + // Check generated struct + assert.Contains(t, code, "type Pet struct {\n\tAge myuuid.UUID `json:\"age\"`\n}") + + // Make sure the generated code is valid: + checkLint(t, "test.gen.go", []byte(code)) + } //go:embed test_spec.yaml diff --git a/pkg/codegen/test_specs/x-go-type-import-pet.yaml b/pkg/codegen/test_specs/x-go-type-import-pet.yaml new file mode 100644 index 000000000..4fea6ff9e --- /dev/null +++ b/pkg/codegen/test_specs/x-go-type-import-pet.yaml @@ -0,0 +1,47 @@ +openapi: "3.0.0" +info: + version: 1.0.0 + title: Swagger Petstore + description: A sample API that uses a petstore as an example to demonstrate features in the OpenAPI 3.0 specification + termsOfService: http://swagger.io/terms/ + contact: + name: Swagger API Team + email: apiteam@swagger.io + url: http://swagger.io + license: + name: Apache 2.0 + url: https://www.apache.org/licenses/LICENSE-2.0.html +servers: + - url: http://petstore.swagger.io/api +paths: + /pets/{id}: + get: + summary: Returns a pet by ID + description: Returns a pet based on a single ID + operationId: findPetByID + parameters: + - name: id + in: path + description: ID of pet to fetch + required: true + schema: + type: integer + format: int64 + responses: + '200': + description: pet response + content: + application/json: + schema: + $ref: '#/components/schemas/Pet' +components: + schemas: + Pet: + properties: + age: + x-go-type: myuuid.UUID + x-go-type-import: + path: github.com/google/uuid + name: myuuid + required: + - age diff --git a/pkg/codegen/utils.go b/pkg/codegen/utils.go index 557b919f5..afa2059ee 100644 --- a/pkg/codegen/utils.go +++ b/pkg/codegen/utils.go @@ -14,6 +14,7 @@ package codegen import ( + "encoding/json" "fmt" "net/url" "regexp" @@ -801,3 +802,28 @@ func findSchemaNameByRefPath(refPath string, spec *openapi3.T) (string, error) { } return "", nil } + +func GetImports(dict map[string]*openapi3.SchemaRef) (map[string]goImport, error) { + res := map[string]goImport{} + for _, v := range dict { + if v == nil || v.Value == nil { + continue + } + + if v.Value.Extensions["x-go-type-import"] == nil || v.Value.Extensions["x-go-type"] == nil { + continue + } + goTypeImportExt := v.Value.Extensions["x-go-type-import"] + + if raw, ok := goTypeImportExt.(json.RawMessage); ok { + gi := goImport{} + if err := json.Unmarshal(raw, &gi); err != nil { + return nil, err + } + res[gi.String()] = gi + } else { + continue + } + } + return res, nil +} diff --git a/pkg/codegen/utils_test.go b/pkg/codegen/utils_test.go index 0b13ee8b1..64c8b9133 100644 --- a/pkg/codegen/utils_test.go +++ b/pkg/codegen/utils_test.go @@ -14,6 +14,8 @@ package codegen import ( + "encoding/json" + "fmt" "testing" "github.com/getkin/kin-openapi/openapi3" @@ -376,3 +378,44 @@ func TestSchemaNameToTypeName(t *testing.T) { assert.Equal(t, want, SchemaNameToTypeName(in)) } } + +func TestGetImports(t *testing.T) { + schemas := map[string]*openapi3.SchemaRef{ + "age": { + Value: &openapi3.Schema{ + ExtensionProps: openapi3.ExtensionProps{ + Extensions: map[string]interface{}{ + "x-go-type-import": json.RawMessage( + `{"name": "hello", "path": "github.com/google/uuid"}`, + ), + "x-go-type": json.RawMessage( + "hello.UUID", + ), + }, + }, + }, + }, + "name": { + Value: &openapi3.Schema{ + ExtensionProps: openapi3.ExtensionProps{ + Extensions: map[string]interface{}{"other-tag": json.RawMessage( + `bla`, + )}, + }, + }, + }, + "value": nil, + } + + expected := map[string]goImport{ + fmt.Sprintf("%s %q", "hello", "github.com/google/uuid"): { + Name: "hello", + Path: "github.com/google/uuid", + }, + } + + res, err := GetImports(schemas) + + assert.NoError(t, err) + assert.Equal(t, expected, res) +}