Skip to content

Commit

Permalink
Add support for Gorilla generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jamie Tanna committed May 20, 2022
1 parent ec460d3 commit 9f36062
Show file tree
Hide file tree
Showing 11 changed files with 386 additions and 18 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ func SetupHandler() {
http.Handle("/", Handler(&myApi))
}
```

Alternatively, [Gorilla](https://github.com/gorilla/mux) is also 100% compatible with `net/http` and can be generated with `-generate gorilla`.

</summary></details>

#### Additional Properties in type definitions
Expand Down
4 changes: 3 additions & 1 deletion cmd/oapi-codegen/oapi-codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func main() {

flag.StringVar(&flagPackageName, "package", "", "The package name for generated code")
flag.StringVar(&flagGenerate, "generate", "types,client,server,spec",
`Comma-separated list of code to generate; valid options: "types", "client", "chi-server", "server", "gin", "spec", "skip-fmt", "skip-prune"`)
`Comma-separated list of code to generate; valid options: "types", "client", "chi-server", "server", "gin", "gorilla", "spec", "skip-fmt", "skip-prune"`)
flag.StringVar(&flagOutputFile, "o", "", "Where to output generated code, stdout is default")
flag.StringVar(&flagIncludeTags, "include-tags", "", "Only include operations with the given tags. Comma-separated list of tags.")
flag.StringVar(&flagExcludeTags, "exclude-tags", "", "Exclude operations that are tagged with the given tags. Comma-separated list of tags.")
Expand Down Expand Up @@ -126,6 +126,8 @@ func main() {
opts.GenerateEchoServer = true
case "gin":
opts.GenerateGinServer = true
case "gorilla":
opts.GenerateGorillaServer = true
case "types":
opts.GenerateTypes = true
case "spec":
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/golang/protobuf v1.5.2 // indirect
github.com/golangci/lint-1 v0.0.0-20181222135242-d2cdd8c08219
github.com/google/uuid v1.3.0
github.com/gorilla/mux v1.8.0
github.com/json-iterator/go v1.1.12 // indirect
github.com/labstack/echo/v4 v4.7.2
github.com/lestrrat-go/blackmagic v1.0.1 // indirect
Expand Down
50 changes: 33 additions & 17 deletions pkg/codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,24 @@ var templates embed.FS

// Options defines the optional code to generate.
type Options struct {
GenerateChiServer bool // GenerateChiServer specifies whether to generate chi server boilerplate
GenerateEchoServer bool // GenerateEchoServer specifies whether to generate echo server boilerplate
GenerateGinServer bool // GenerateGinServer specifies whether to generate echo server boilerplate
GenerateClient bool // GenerateClient specifies whether to generate client boilerplate
GenerateTypes bool // GenerateTypes specifies whether to generate type definitions
EmbedSpec bool // Whether to embed the swagger spec in the generated code
SkipFmt bool // Whether to skip go imports on the generated code
SkipPrune bool // Whether to skip pruning unused components on the generated code
AliasTypes bool // Whether to alias types if possible
IncludeTags []string // Only include operations that have one of these tags. Ignored when empty.
ExcludeTags []string // Exclude operations that have one of these tags. Ignored when empty.
UserTemplates map[string]string // Override built-in templates from user-provided files
ImportMapping map[string]string // ImportMapping specifies the golang package path for each external reference
ExcludeSchemas []string // Exclude from generation schemas with given names. Ignored when empty.
OldMergeSchemas bool // Schema merging for allOf was changed in a big way, when true, the old way is used
OldEnumConflicts bool // When set to true, we include the object path in enum names, otherwise, rely on global de-dup
ResponseTypeSuffix string // The suffix used for responses types
GenerateChiServer bool // GenerateChiServer specifies whether to generate chi server boilerplate
GenerateEchoServer bool // GenerateEchoServer specifies whether to generate echo server boilerplate
GenerateGinServer bool // GenerateGinServer specifies whether to generate echo server boilerplate
GenerateGorillaServer bool // GenerateGorillaServer specifies whether to generate Gorilla server boilerplate
GenerateClient bool // GenerateClient specifies whether to generate client boilerplate
GenerateTypes bool // GenerateTypes specifies whether to generate type definitions
EmbedSpec bool // Whether to embed the swagger spec in the generated code
SkipFmt bool // Whether to skip go imports on the generated code
SkipPrune bool // Whether to skip pruning unused components on the generated code
AliasTypes bool // Whether to alias types if possible
IncludeTags []string // Only include operations that have one of these tags. Ignored when empty.
ExcludeTags []string // Exclude operations that have one of these tags. Ignored when empty.
UserTemplates map[string]string // Override built-in templates from user-provided files
ImportMapping map[string]string // ImportMapping specifies the golang package path for each external reference
ExcludeSchemas []string // Exclude from generation schemas with given names. Ignored when empty.
OldMergeSchemas bool // Schema merging for allOf was changed in a big way, when true, the old way is used
OldEnumConflicts bool // When set to true, we include the object path in enum names, otherwise, rely on global de-dup
ResponseTypeSuffix string // The suffix used for responses types
}

// We store options globally to simplify accessing them from all the codegen
Expand Down Expand Up @@ -193,6 +194,14 @@ func Generate(swagger *openapi3.T, packageName string, opts Options) (string, er
}
}

var gorillaServerOut string
if opts.GenerateGorillaServer {
gorillaServerOut, err = GenerateGorillaServer(t, ops)
if err != nil {
return "", fmt.Errorf("error generating Go handlers for Paths: %w", err)
}
}

var clientOut string
if opts.GenerateClient {
clientOut, err = GenerateClient(t, ops)
Expand Down Expand Up @@ -273,6 +282,13 @@ func Generate(swagger *openapi3.T, packageName string, opts Options) (string, er
}
}

if opts.GenerateGorillaServer {
_, err = w.WriteString(gorillaServerOut)
if err != nil {
return "", fmt.Errorf("error writing server path handlers: %w", err)
}
}

if opts.EmbedSpec {
_, err = w.WriteString(inlinedSpec)
if err != nil {
Expand Down
6 changes: 6 additions & 0 deletions pkg/codegen/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,12 @@ func GenerateGinServer(t *template.Template, operations []OperationDefinition) (
return GenerateTemplates([]string{"gin/gin-interface.tmpl", "gin/gin-wrappers.tmpl", "gin/gin-register.tmpl"}, t, operations)
}

// GenerateGinServer This function generates all the go code for the ServerInterface as well as
// all the wrapper functions around our handlers.
func GenerateGorillaServer(t *template.Template, operations []OperationDefinition) (string, error) {
return GenerateTemplates([]string{"gorilla/gorilla-interface.tmpl", "gorilla/gorilla-middleware.tmpl", "gorilla/gorilla-register.tmpl"}, t, operations)
}

// Uses the template engine to generate the function which registers our wrappers
// as Echo path handlers.
func GenerateClient(t *template.Template, ops []OperationDefinition) (string, error) {
Expand Down
1 change: 1 addition & 0 deletions pkg/codegen/template_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ var TemplateFunctions = template.FuncMap{
"swaggerUriToEchoUri": SwaggerUriToEchoUri,
"swaggerUriToChiUri": SwaggerUriToChiUri,
"swaggerUriToGinUri": SwaggerUriToGinUri,
"swaggerUriToGorillaUri": SwaggerUriToGorillaUri,
"lcFirst": LowercaseFirstCharacter,
"ucFirst": UppercaseFirstCharacter,
"camelCase": ToCamelCase,
Expand Down
7 changes: 7 additions & 0 deletions pkg/codegen/templates/gorilla/gorilla-interface.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// ServerInterface represents all server handlers.
type ServerInterface interface {
{{range .}}{{.SummaryAsComment }}
// ({{.Method}} {{.Path}})
{{.OperationId}}(w http.ResponseWriter, r *http.Request{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params {{.OperationId}}Params{{end}})
{{end}}
}
251 changes: 251 additions & 0 deletions pkg/codegen/templates/gorilla/gorilla-middleware.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
// ServerInterfaceWrapper converts contexts to parameters.
type ServerInterfaceWrapper struct {
Handler ServerInterface
HandlerMiddlewares []MiddlewareFunc
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}

type MiddlewareFunc func(http.HandlerFunc) http.HandlerFunc

{{range .}}{{$opid := .OperationId}}

// {{$opid}} operation middleware
func (siw *ServerInterfaceWrapper) {{$opid}}(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
{{if or .RequiresParamObject (gt (len .PathParams) 0) }}
var err error
{{end}}

{{range .PathParams}}// ------------- Path parameter "{{.ParamName}}" -------------
var {{$varName := .GoVariableName}}{{$varName}} {{.TypeDef}}

{{if .IsPassThrough}}
{{$varName}} = mux.Vars(r)["{{.ParamName}}"]
{{end}}
{{if .IsJson}}
err = json.Unmarshal([]byte(mux.Vars(r)["{{.ParamName}}"]), &{{$varName}})
if err != nil {
siw.ErrorHandlerFunc(w, r, &UnmarshalingParamError{ParamName: "{{.ParamName}}", Err: err})
return
}
{{end}}
{{if .IsStyled}}
err = runtime.BindStyledParameter("{{.Style}}",{{.Explode}}, "{{.ParamName}}", mux.Vars(r)["{{.ParamName}}"], &{{$varName}})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "{{.ParamName}}", Err: err})
return
}
{{end}}

{{end}}

{{range .SecurityDefinitions}}
ctx = context.WithValue(ctx, {{.ProviderName | ucFirst}}Scopes, {{toStringArray .Scopes}})
{{end}}

{{if .RequiresParamObject}}
// Parameter object where we will unmarshal all parameters from the context
var params {{.OperationId}}Params

{{range $paramIdx, $param := .QueryParams}}// ------------- {{if .Required}}Required{{else}}Optional{{end}} query parameter "{{.ParamName}}" -------------
if paramValue := r.URL.Query().Get("{{.ParamName}}"); paramValue != "" {

{{if .IsPassThrough}}
params.{{.GoName}} = {{if not .Required}}&{{end}}paramValue
{{end}}

{{if .IsJson}}
var value {{.TypeDef}}
err = json.Unmarshal([]byte(paramValue), &value)
if err != nil {
siw.ErrorHandlerFunc(w, r, &UnmarshalingParamError{ParamName: "{{.ParamName}}", Err: err})
return
}

params.{{.GoName}} = {{if not .Required}}&{{end}}value
{{end}}
}{{if .Required}} else {
siw.ErrorHandlerFunc(w, r, &RequiredParamError{ParamName: "{{.ParamName}}"})
return
}{{end}}
{{if .IsStyled}}
err = runtime.BindQueryParameter("{{.Style}}", {{.Explode}}, {{.Required}}, "{{.ParamName}}", r.URL.Query(), &params.{{.GoName}})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "{{.ParamName}}", Err: err})
return
}
{{end}}
{{end}}

{{if .HeaderParams}}
headers := r.Header

{{range .HeaderParams}}// ------------- {{if .Required}}Required{{else}}Optional{{end}} header parameter "{{.ParamName}}" -------------
if valueList, found := headers[http.CanonicalHeaderKey("{{.ParamName}}")]; found {
var {{.GoName}} {{.TypeDef}}
n := len(valueList)
if n != 1 {
siw.ErrorHandlerFunc(w, r, &TooManyValuesForParamError{ParamName: "{{.ParamName}}", Count: n})
return
}

{{if .IsPassThrough}}
params.{{.GoName}} = {{if not .Required}}&{{end}}valueList[0]
{{end}}

{{if .IsJson}}
err = json.Unmarshal([]byte(valueList[0]), &{{.GoName}})
if err != nil {
siw.ErrorHandlerFunc(w, r, &UnmarshalingParamError{ParamName: "{{.ParamName}}", Err: err})
return
}
{{end}}

{{if .IsStyled}}
err = runtime.BindStyledParameterWithLocation("{{.Style}}",{{.Explode}}, "{{.ParamName}}", runtime.ParamLocationHeader, valueList[0], &{{.GoName}})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "{{.ParamName}}", Err: err})
return
}
{{end}}

params.{{.GoName}} = {{if not .Required}}&{{end}}{{.GoName}}

} {{if .Required}}else {
err := fmt.Errorf("Header parameter {{.ParamName}} is required, but not found")
siw.ErrorHandlerFunc(w, r, &RequiredHeaderError{ParamName: "{{.ParamName}}", Err: err})
return
}{{end}}

{{end}}
{{end}}

{{range .CookieParams}}
var cookie *http.Cookie

if cookie, err = r.Cookie("{{.ParamName}}"); err == nil {

{{- if .IsPassThrough}}
params.{{.GoName}} = {{if not .Required}}&{{end}}cookie.Value
{{end}}

{{- if .IsJson}}
var value {{.TypeDef}}
var decoded string
decoded, err := url.QueryUnescape(cookie.Value)
if err != nil {
err = fmt.Errorf("Error unescaping cookie parameter '{{.ParamName}}'")
siw.ErrorHandlerFunc(w, r, &UnescapedCookieParamError{ParamName: "{{.ParamName}}", Err: err})
return
}

err = json.Unmarshal([]byte(decoded), &value)
if err != nil {
siw.ErrorHandlerFunc(w, r, &UnmarshalingParamError{ParamName: "{{.ParamName}}", Err: err})
return
}

params.{{.GoName}} = {{if not .Required}}&{{end}}value
{{end}}

{{- if .IsStyled}}
var value {{.TypeDef}}
err = runtime.BindStyledParameter("simple",{{.Explode}}, "{{.ParamName}}", cookie.Value, &value)
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "{{.ParamName}}", Err: err})
return
}
params.{{.GoName}} = {{if not .Required}}&{{end}}value
{{end}}

}

{{- if .Required}} else {
siw.ErrorHandlerFunc(w, r, &RequiredParamError{ParamName: "{{.ParamName}}"})
return
}
{{- end}}
{{end}}
{{end}}

var handler = func(w http.ResponseWriter, r *http.Request) {
siw.Handler.{{.OperationId}}(w, r{{genParamNames .PathParams}}{{if .RequiresParamObject}}, params{{end}})
}

for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}

handler(w, r.WithContext(ctx))
}
{{end}}

type UnescapedCookieParamError struct {
ParamName string
Err error
}

func (e *UnescapedCookieParamError) Error() string {
return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName)
}

func (e *UnescapedCookieParamError) Unwrap() error {
return e.Err
}

type UnmarshalingParamError struct {
ParamName string
Err error
}

func (e *UnmarshalingParamError) Error() string {
return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error())
}

func (e *UnmarshalingParamError) Unwrap() error {
return e.Err
}

type RequiredParamError struct {
ParamName string
}

func (e *RequiredParamError) Error() string {
return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName)
}

type RequiredHeaderError struct {
ParamName string
Err error
}

func (e *RequiredHeaderError) Error() string {
return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName)
}

func (e *RequiredHeaderError) Unwrap() error {
return e.Err
}

type InvalidParamFormatError struct {
ParamName string
Err error
}

func (e *InvalidParamFormatError) Error() string {
return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error())
}

func (e *InvalidParamFormatError) Unwrap() error {
return e.Err
}

type TooManyValuesForParamError struct {
ParamName string
Count int
}

func (e *TooManyValuesForParamError) Error() string {
return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count)
}

0 comments on commit 9f36062

Please sign in to comment.