From 3010f54b768e1f76f33319ef2ec14d3aad8ca242 Mon Sep 17 00:00:00 2001 From: Jamie Tanna Date: Thu, 19 May 2022 20:39:42 +0100 Subject: [PATCH 1/2] Add support for Gorilla generation As part of #465, it'd be handy to have gorilla/mux, a commonly used HTTP server as a generated server. This is very similar to Chi, as it is also `net/http` compliant, and allows us to mostly copy-paste the code, with very minor tweaks for Gorilla-specific routing needs. --- README.md | 3 + cmd/oapi-codegen/oapi-codegen.go | 4 +- pkg/codegen/codegen.go | 15 ++ pkg/codegen/configuration.go | 13 +- pkg/codegen/operations.go | 6 + pkg/codegen/template_helpers.go | 1 + .../templates/gorilla/gorilla-interface.tmpl | 7 + .../templates/gorilla/gorilla-middleware.tmpl | 251 ++++++++++++++++++ .../templates/gorilla/gorilla-register.tmpl | 49 ++++ pkg/codegen/utils.go | 15 ++ pkg/codegen/utils_test.go | 17 ++ 11 files changed, 374 insertions(+), 7 deletions(-) create mode 100644 pkg/codegen/templates/gorilla/gorilla-interface.tmpl create mode 100644 pkg/codegen/templates/gorilla/gorilla-middleware.tmpl create mode 100644 pkg/codegen/templates/gorilla/gorilla-register.tmpl diff --git a/README.md b/README.md index 0bf46d95b..d7297bd18 100644 --- a/README.md +++ b/README.md @@ -255,6 +255,9 @@ func SetupHandler() { http.Handle("/", Handler(&myApi)) } ``` + +Alternatively, [Gorilla](https://github.com/gorilla/mux) is also 100% compatible with `net/http` and can be generated with `-generate gorilla`. + #### Additional Properties in type definitions diff --git a/cmd/oapi-codegen/oapi-codegen.go b/cmd/oapi-codegen/oapi-codegen.go index fe43be6a8..ab18aa4e4 100644 --- a/cmd/oapi-codegen/oapi-codegen.go +++ b/cmd/oapi-codegen/oapi-codegen.go @@ -86,7 +86,7 @@ func main() { // All flags below are deprecated, and will be removed in a future release. Please do not // update their behavior. flag.StringVar(&flagGenerate, "generate", "types,client,server,spec", - `Comma-separated list of code to generate; valid options: "types", "client", "chi-server", "server", "gin", "spec", "skip-fmt", "skip-prune"`) + `Comma-separated list of code to generate; valid options: "types", "client", "chi-server", "server", "gin", "gorilla", "spec", "skip-fmt", "skip-prune"`) flag.StringVar(&flagIncludeTags, "include-tags", "", "Only include operations with the given tags. Comma-separated list of tags.") flag.StringVar(&flagExcludeTags, "exclude-tags", "", "Exclude operations that are tagged with the given tags. Comma-separated list of tags.") flag.StringVar(&flagTemplatesDir, "templates", "", "Path to directory containing user templates") @@ -321,6 +321,8 @@ func newConfigFromOldConfig(c oldConfiguration) configuration { opts.Generate.EchoServer = true case "gin": opts.Generate.GinServer = true + case "gorilla": + opts.Generate.GorillaServer = true case "types": opts.Generate.Models = true case "spec": diff --git a/pkg/codegen/codegen.go b/pkg/codegen/codegen.go index 0fc91dc8c..ed43e2dcb 100644 --- a/pkg/codegen/codegen.go +++ b/pkg/codegen/codegen.go @@ -176,6 +176,14 @@ func Generate(spec *openapi3.T, opts Configuration) (string, error) { } } + var gorillaServerOut string + if opts.Generate.GorillaServer { + gorillaServerOut, err = GenerateGorillaServer(t, ops) + if err != nil { + return "", fmt.Errorf("error generating Go handlers for Paths: %w", err) + } + } + var clientOut string if opts.Generate.Client { clientOut, err = GenerateClient(t, ops) @@ -256,6 +264,13 @@ func Generate(spec *openapi3.T, opts Configuration) (string, error) { } } + if opts.Generate.GorillaServer { + _, err = w.WriteString(gorillaServerOut) + if err != nil { + return "", fmt.Errorf("error writing server path handlers: %w", err) + } + } + if opts.Generate.EmbeddedSpec { _, err = w.WriteString(inlinedSpec) if err != nil { diff --git a/pkg/codegen/configuration.go b/pkg/codegen/configuration.go index 084ab5b7f..5dd98eb75 100644 --- a/pkg/codegen/configuration.go +++ b/pkg/codegen/configuration.go @@ -16,12 +16,13 @@ type Configuration struct { // GenerateOptions specifies which supported output formats to generate. type GenerateOptions struct { - ChiServer bool `yaml:"chi-server,omitempty"` // ChiServer specifies whether to generate chi server boilerplate - EchoServer bool `yaml:"echo-server,omitempty"` // EchoServer specifies whether to generate echo server boilerplate - GinServer bool `yaml:"gin-server,omitempty"` // GinServer specifies whether to generate echo server boilerplate - Client bool `yaml:"client,omitempty"` // Client specifies whether to generate client boilerplate - Models bool `yaml:"models,omitempty"` // Models specifies whether to generate type definitions - EmbeddedSpec bool `yaml:"embedded-spec,omitempty"` // Whether to embed the swagger spec in the generated code + ChiServer bool `yaml:"chi-server,omitempty"` // ChiServer specifies whether to generate chi server boilerplate + EchoServer bool `yaml:"echo-server,omitempty"` // EchoServer specifies whether to generate echo server boilerplate + GinServer bool `yaml:"gin-server,omitempty"` // GinServer specifies whether to generate echo server boilerplate + GorillaServer bool `yaml:"gorilla-server,omitempty"` // GorillaServer specifies whether to generate Gorilla server boilerplate + Client bool `yaml:"client,omitempty"` // Client specifies whether to generate client boilerplate + Models bool `yaml:"models,omitempty"` // Models specifies whether to generate type definitions + EmbeddedSpec bool `yaml:"embedded-spec,omitempty"` // Whether to embed the swagger spec in the generated code } // CompatibilityOptions specifies backward compatibility settings for the diff --git a/pkg/codegen/operations.go b/pkg/codegen/operations.go index 3196d5cc6..e858482a9 100644 --- a/pkg/codegen/operations.go +++ b/pkg/codegen/operations.go @@ -683,6 +683,12 @@ func GenerateGinServer(t *template.Template, operations []OperationDefinition) ( return GenerateTemplates([]string{"gin/gin-interface.tmpl", "gin/gin-wrappers.tmpl", "gin/gin-register.tmpl"}, t, operations) } +// GenerateGinServer This function generates all the go code for the ServerInterface as well as +// all the wrapper functions around our handlers. +func GenerateGorillaServer(t *template.Template, operations []OperationDefinition) (string, error) { + return GenerateTemplates([]string{"gorilla/gorilla-interface.tmpl", "gorilla/gorilla-middleware.tmpl", "gorilla/gorilla-register.tmpl"}, t, operations) +} + // Uses the template engine to generate the function which registers our wrappers // as Echo path handlers. func GenerateClient(t *template.Template, ops []OperationDefinition) (string, error) { diff --git a/pkg/codegen/template_helpers.go b/pkg/codegen/template_helpers.go index eeb781020..3ff5e1e0f 100644 --- a/pkg/codegen/template_helpers.go +++ b/pkg/codegen/template_helpers.go @@ -274,6 +274,7 @@ var TemplateFunctions = template.FuncMap{ "swaggerUriToEchoUri": SwaggerUriToEchoUri, "swaggerUriToChiUri": SwaggerUriToChiUri, "swaggerUriToGinUri": SwaggerUriToGinUri, + "swaggerUriToGorillaUri": SwaggerUriToGorillaUri, "lcFirst": LowercaseFirstCharacter, "ucFirst": UppercaseFirstCharacter, "camelCase": ToCamelCase, diff --git a/pkg/codegen/templates/gorilla/gorilla-interface.tmpl b/pkg/codegen/templates/gorilla/gorilla-interface.tmpl new file mode 100644 index 000000000..79a51fd75 --- /dev/null +++ b/pkg/codegen/templates/gorilla/gorilla-interface.tmpl @@ -0,0 +1,7 @@ +// ServerInterface represents all server handlers. +type ServerInterface interface { +{{range .}}{{.SummaryAsComment }} +// ({{.Method}} {{.Path}}) +{{.OperationId}}(w http.ResponseWriter, r *http.Request{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params {{.OperationId}}Params{{end}}) +{{end}} +} diff --git a/pkg/codegen/templates/gorilla/gorilla-middleware.tmpl b/pkg/codegen/templates/gorilla/gorilla-middleware.tmpl new file mode 100644 index 000000000..a08bc5e9f --- /dev/null +++ b/pkg/codegen/templates/gorilla/gorilla-middleware.tmpl @@ -0,0 +1,251 @@ +// ServerInterfaceWrapper converts contexts to parameters. +type ServerInterfaceWrapper struct { + Handler ServerInterface + HandlerMiddlewares []MiddlewareFunc + ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error) +} + +type MiddlewareFunc func(http.HandlerFunc) http.HandlerFunc + +{{range .}}{{$opid := .OperationId}} + +// {{$opid}} operation middleware +func (siw *ServerInterfaceWrapper) {{$opid}}(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + {{if or .RequiresParamObject (gt (len .PathParams) 0) }} + var err error + {{end}} + + {{range .PathParams}}// ------------- Path parameter "{{.ParamName}}" ------------- + var {{$varName := .GoVariableName}}{{$varName}} {{.TypeDef}} + + {{if .IsPassThrough}} + {{$varName}} = mux.Vars(r)["{{.ParamName}}"] + {{end}} + {{if .IsJson}} + err = json.Unmarshal([]byte(mux.Vars(r)["{{.ParamName}}"]), &{{$varName}}) + if err != nil { + siw.ErrorHandlerFunc(w, r, &UnmarshalingParamError{ParamName: "{{.ParamName}}", Err: err}) + return + } + {{end}} + {{if .IsStyled}} + err = runtime.BindStyledParameter("{{.Style}}",{{.Explode}}, "{{.ParamName}}", mux.Vars(r)["{{.ParamName}}"], &{{$varName}}) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "{{.ParamName}}", Err: err}) + return + } + {{end}} + + {{end}} + +{{range .SecurityDefinitions}} + ctx = context.WithValue(ctx, {{.ProviderName | ucFirst}}Scopes, {{toStringArray .Scopes}}) +{{end}} + + {{if .RequiresParamObject}} + // Parameter object where we will unmarshal all parameters from the context + var params {{.OperationId}}Params + + {{range $paramIdx, $param := .QueryParams}}// ------------- {{if .Required}}Required{{else}}Optional{{end}} query parameter "{{.ParamName}}" ------------- + if paramValue := r.URL.Query().Get("{{.ParamName}}"); paramValue != "" { + + {{if .IsPassThrough}} + params.{{.GoName}} = {{if not .Required}}&{{end}}paramValue + {{end}} + + {{if .IsJson}} + var value {{.TypeDef}} + err = json.Unmarshal([]byte(paramValue), &value) + if err != nil { + siw.ErrorHandlerFunc(w, r, &UnmarshalingParamError{ParamName: "{{.ParamName}}", Err: err}) + return + } + + params.{{.GoName}} = {{if not .Required}}&{{end}}value + {{end}} + }{{if .Required}} else { + siw.ErrorHandlerFunc(w, r, &RequiredParamError{ParamName: "{{.ParamName}}"}) + return + }{{end}} + {{if .IsStyled}} + err = runtime.BindQueryParameter("{{.Style}}", {{.Explode}}, {{.Required}}, "{{.ParamName}}", r.URL.Query(), ¶ms.{{.GoName}}) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "{{.ParamName}}", Err: err}) + return + } + {{end}} + {{end}} + + {{if .HeaderParams}} + headers := r.Header + + {{range .HeaderParams}}// ------------- {{if .Required}}Required{{else}}Optional{{end}} header parameter "{{.ParamName}}" ------------- + if valueList, found := headers[http.CanonicalHeaderKey("{{.ParamName}}")]; found { + var {{.GoName}} {{.TypeDef}} + n := len(valueList) + if n != 1 { + siw.ErrorHandlerFunc(w, r, &TooManyValuesForParamError{ParamName: "{{.ParamName}}", Count: n}) + return + } + + {{if .IsPassThrough}} + params.{{.GoName}} = {{if not .Required}}&{{end}}valueList[0] + {{end}} + + {{if .IsJson}} + err = json.Unmarshal([]byte(valueList[0]), &{{.GoName}}) + if err != nil { + siw.ErrorHandlerFunc(w, r, &UnmarshalingParamError{ParamName: "{{.ParamName}}", Err: err}) + return + } + {{end}} + + {{if .IsStyled}} + err = runtime.BindStyledParameterWithLocation("{{.Style}}",{{.Explode}}, "{{.ParamName}}", runtime.ParamLocationHeader, valueList[0], &{{.GoName}}) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "{{.ParamName}}", Err: err}) + return + } + {{end}} + + params.{{.GoName}} = {{if not .Required}}&{{end}}{{.GoName}} + + } {{if .Required}}else { + err := fmt.Errorf("Header parameter {{.ParamName}} is required, but not found") + siw.ErrorHandlerFunc(w, r, &RequiredHeaderError{ParamName: "{{.ParamName}}", Err: err}) + return + }{{end}} + + {{end}} + {{end}} + + {{range .CookieParams}} + var cookie *http.Cookie + + if cookie, err = r.Cookie("{{.ParamName}}"); err == nil { + + {{- if .IsPassThrough}} + params.{{.GoName}} = {{if not .Required}}&{{end}}cookie.Value + {{end}} + + {{- if .IsJson}} + var value {{.TypeDef}} + var decoded string + decoded, err := url.QueryUnescape(cookie.Value) + if err != nil { + err = fmt.Errorf("Error unescaping cookie parameter '{{.ParamName}}'") + siw.ErrorHandlerFunc(w, r, &UnescapedCookieParamError{ParamName: "{{.ParamName}}", Err: err}) + return + } + + err = json.Unmarshal([]byte(decoded), &value) + if err != nil { + siw.ErrorHandlerFunc(w, r, &UnmarshalingParamError{ParamName: "{{.ParamName}}", Err: err}) + return + } + + params.{{.GoName}} = {{if not .Required}}&{{end}}value + {{end}} + + {{- if .IsStyled}} + var value {{.TypeDef}} + err = runtime.BindStyledParameter("simple",{{.Explode}}, "{{.ParamName}}", cookie.Value, &value) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "{{.ParamName}}", Err: err}) + return + } + params.{{.GoName}} = {{if not .Required}}&{{end}}value + {{end}} + + } + + {{- if .Required}} else { + siw.ErrorHandlerFunc(w, r, &RequiredParamError{ParamName: "{{.ParamName}}"}) + return + } + {{- end}} + {{end}} + {{end}} + + var handler = func(w http.ResponseWriter, r *http.Request) { + siw.Handler.{{.OperationId}}(w, r{{genParamNames .PathParams}}{{if .RequiresParamObject}}, params{{end}}) +} + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler(w, r.WithContext(ctx)) +} +{{end}} + +type UnescapedCookieParamError struct { + ParamName string + Err error +} + +func (e *UnescapedCookieParamError) Error() string { + return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName) +} + +func (e *UnescapedCookieParamError) Unwrap() error { + return e.Err +} + +type UnmarshalingParamError struct { + ParamName string + Err error +} + +func (e *UnmarshalingParamError) Error() string { + return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error()) +} + +func (e *UnmarshalingParamError) Unwrap() error { + return e.Err +} + +type RequiredParamError struct { + ParamName string +} + +func (e *RequiredParamError) Error() string { + return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName) +} + +type RequiredHeaderError struct { + ParamName string + Err error +} + +func (e *RequiredHeaderError) Error() string { + return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName) +} + +func (e *RequiredHeaderError) Unwrap() error { + return e.Err +} + +type InvalidParamFormatError struct { + ParamName string + Err error +} + +func (e *InvalidParamFormatError) Error() string { + return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error()) +} + +func (e *InvalidParamFormatError) Unwrap() error { + return e.Err +} + +type TooManyValuesForParamError struct { + ParamName string + Count int +} + +func (e *TooManyValuesForParamError) Error() string { + return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count) +} + diff --git a/pkg/codegen/templates/gorilla/gorilla-register.tmpl b/pkg/codegen/templates/gorilla/gorilla-register.tmpl new file mode 100644 index 000000000..b019e1dbb --- /dev/null +++ b/pkg/codegen/templates/gorilla/gorilla-register.tmpl @@ -0,0 +1,49 @@ +// Handler creates http.Handler with routing matching OpenAPI spec. +func Handler(si ServerInterface) http.Handler { + return HandlerWithOptions(si, GorillaServerOptions{}) +} + +type GorillaServerOptions struct { + BaseURL string + BaseRouter *mux.Router + Middlewares []MiddlewareFunc + ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error) +} + +// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux. +func HandlerFromMux(si ServerInterface, r *mux.Router) http.Handler { + return HandlerWithOptions(si, GorillaServerOptions { + BaseRouter: r, + }) +} + +func HandlerFromMuxWithBaseURL(si ServerInterface, r *mux.Router, baseURL string) http.Handler { + return HandlerWithOptions(si, GorillaServerOptions { + BaseURL: baseURL, + BaseRouter: r, + }) +} + +// HandlerWithOptions creates http.Handler with additional options +func HandlerWithOptions(si ServerInterface, options GorillaServerOptions) http.Handler { +r := options.BaseRouter + +if r == nil { +r = mux.NewRouter() +} +if options.ErrorHandlerFunc == nil { + options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) { + http.Error(w, err.Error(), http.StatusBadRequest) + } +} +{{if .}}wrapper := ServerInterfaceWrapper{ +Handler: si, +HandlerMiddlewares: options.Middlewares, +ErrorHandlerFunc: options.ErrorHandlerFunc, +} +{{end}} +{{range .}} +r.HandleFunc(options.BaseURL+"{{.Path | swaggerUriToGorillaUri }}", wrapper.{{.OperationId}}).Methods("{{.Method }}") +{{end}} +return r +} diff --git a/pkg/codegen/utils.go b/pkg/codegen/utils.go index 034442e57..c713002b4 100644 --- a/pkg/codegen/utils.go +++ b/pkg/codegen/utils.go @@ -325,6 +325,21 @@ func SwaggerUriToGinUri(uri string) string { return pathParamRE.ReplaceAllString(uri, ":$1") } +// This function converts a swagger style path URI with parameters to a +// Gorilla compatible path URI. We need to replace all of Swagger parameters with +// ":param". Valid input parameters are: +// {param} +// {param*} +// {.param} +// {.param*} +// {;param} +// {;param*} +// {?param} +// {?param*} +func SwaggerUriToGorillaUri(uri string) string { + return pathParamRE.ReplaceAllString(uri, "{$1}") +} + // Returns the argument names, in order, in a given URI string, so for // /path/{param1}/{.param2*}/{?param3}, it would return param1, param2, param3 func OrderedParamsFromUri(uri string) []string { diff --git a/pkg/codegen/utils_test.go b/pkg/codegen/utils_test.go index b391852a9..0b13ee8b1 100644 --- a/pkg/codegen/utils_test.go +++ b/pkg/codegen/utils_test.go @@ -263,6 +263,23 @@ func TestSwaggerUriToGinUri(t *testing.T) { assert.Equal(t, "/path/:arg/foo", SwaggerUriToGinUri("/path/{?arg*}/foo")) } +func TestSwaggerUriToGorillaUri(t *testing.T) { // TODO + assert.Equal(t, "/path", SwaggerUriToGorillaUri("/path")) + assert.Equal(t, "/path/{arg}", SwaggerUriToGorillaUri("/path/{arg}")) + assert.Equal(t, "/path/{arg1}/{arg2}", SwaggerUriToGorillaUri("/path/{arg1}/{arg2}")) + assert.Equal(t, "/path/{arg1}/{arg2}/foo", SwaggerUriToGorillaUri("/path/{arg1}/{arg2}/foo")) + + // Make sure all the exploded and alternate formats match too + assert.Equal(t, "/path/{arg}/foo", SwaggerUriToGorillaUri("/path/{arg}/foo")) + assert.Equal(t, "/path/{arg}/foo", SwaggerUriToGorillaUri("/path/{arg*}/foo")) + assert.Equal(t, "/path/{arg}/foo", SwaggerUriToGorillaUri("/path/{.arg}/foo")) + assert.Equal(t, "/path/{arg}/foo", SwaggerUriToGorillaUri("/path/{.arg*}/foo")) + assert.Equal(t, "/path/{arg}/foo", SwaggerUriToGorillaUri("/path/{;arg}/foo")) + assert.Equal(t, "/path/{arg}/foo", SwaggerUriToGorillaUri("/path/{;arg*}/foo")) + assert.Equal(t, "/path/{arg}/foo", SwaggerUriToGorillaUri("/path/{?arg}/foo")) + assert.Equal(t, "/path/{arg}/foo", SwaggerUriToGorillaUri("/path/{?arg*}/foo")) +} + func TestOrderedParamsFromUri(t *testing.T) { result := OrderedParamsFromUri("/path/{param1}/{.param2}/{;param3*}/foo") assert.EqualValues(t, []string{"param1", "param2", "param3"}, result) From a13f9aa39040d80bb5c51bc6b4b186b1536d772e Mon Sep 17 00:00:00 2001 From: Jamie Tanna Date: Fri, 20 May 2022 16:55:45 +0100 Subject: [PATCH 2/2] Add example project for Gorilla As part of #465, we should produce an example version of the Petstore API using Gorilla, to validate that this works, as well as showing sample usage to consumers. --- .../petstore-expanded/gorilla/api/cfg.yaml | 6 + .../gorilla/api/petstore.gen.go | 427 ++++++++++++++++++ .../petstore-expanded/gorilla/api/petstore.go | 128 ++++++ .../petstore-expanded/gorilla/petstore.go | 53 +++ .../gorilla/petstore_test.go | 168 +++++++ go.mod | 1 + 6 files changed, 783 insertions(+) create mode 100644 examples/petstore-expanded/gorilla/api/cfg.yaml create mode 100644 examples/petstore-expanded/gorilla/api/petstore.gen.go create mode 100644 examples/petstore-expanded/gorilla/api/petstore.go create mode 100644 examples/petstore-expanded/gorilla/petstore.go create mode 100644 examples/petstore-expanded/gorilla/petstore_test.go diff --git a/examples/petstore-expanded/gorilla/api/cfg.yaml b/examples/petstore-expanded/gorilla/api/cfg.yaml new file mode 100644 index 000000000..7f1a2d759 --- /dev/null +++ b/examples/petstore-expanded/gorilla/api/cfg.yaml @@ -0,0 +1,6 @@ +package: api +generate: + gorilla-server: true + models: true + embedded-spec: true +output: petstore.gen.go diff --git a/examples/petstore-expanded/gorilla/api/petstore.gen.go b/examples/petstore-expanded/gorilla/api/petstore.gen.go new file mode 100644 index 000000000..5ddb9e1a6 --- /dev/null +++ b/examples/petstore-expanded/gorilla/api/petstore.gen.go @@ -0,0 +1,427 @@ +// Package api provides primitives to interact with the openapi HTTP API. +// +// Code generated by github.com/deepmap/oapi-codegen version (devel) DO NOT EDIT. +package api + +import ( + "bytes" + "compress/gzip" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "path" + "strings" + + "github.com/deepmap/oapi-codegen/pkg/runtime" + "github.com/getkin/kin-openapi/openapi3" + "github.com/gorilla/mux" +) + +// Error defines model for Error. +type Error struct { + // Error code + Code int32 `json:"code"` + + // Error message + Message string `json:"message"` +} + +// NewPet defines model for NewPet. +type NewPet struct { + // Name of the pet + Name string `json:"name"` + + // Type of the pet + Tag *string `json:"tag,omitempty"` +} + +// Pet defines model for Pet. +type Pet struct { + // Unique id of the pet + Id int64 `json:"id"` + + // Name of the pet + Name string `json:"name"` + + // Type of the pet + Tag *string `json:"tag,omitempty"` +} + +// FindPetsParams defines parameters for FindPets. +type FindPetsParams struct { + // tags to filter by + Tags *[]string `form:"tags,omitempty" json:"tags,omitempty"` + + // maximum number of results to return + Limit *int32 `form:"limit,omitempty" json:"limit,omitempty"` +} + +// AddPetJSONBody defines parameters for AddPet. +type AddPetJSONBody = NewPet + +// AddPetJSONRequestBody defines body for AddPet for application/json ContentType. +type AddPetJSONRequestBody = AddPetJSONBody + +// ServerInterface represents all server handlers. +type ServerInterface interface { + // Returns all pets + // (GET /pets) + FindPets(w http.ResponseWriter, r *http.Request, params FindPetsParams) + // Creates a new pet + // (POST /pets) + AddPet(w http.ResponseWriter, r *http.Request) + // Deletes a pet by ID + // (DELETE /pets/{id}) + DeletePet(w http.ResponseWriter, r *http.Request, id int64) + // Returns a pet by ID + // (GET /pets/{id}) + FindPetByID(w http.ResponseWriter, r *http.Request, id int64) +} + +// ServerInterfaceWrapper converts contexts to parameters. +type ServerInterfaceWrapper struct { + Handler ServerInterface + HandlerMiddlewares []MiddlewareFunc + ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error) +} + +type MiddlewareFunc func(http.HandlerFunc) http.HandlerFunc + +// FindPets operation middleware +func (siw *ServerInterfaceWrapper) FindPets(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var err error + + // Parameter object where we will unmarshal all parameters from the context + var params FindPetsParams + + // ------------- Optional query parameter "tags" ------------- + if paramValue := r.URL.Query().Get("tags"); paramValue != "" { + + } + + err = runtime.BindQueryParameter("form", true, false, "tags", r.URL.Query(), ¶ms.Tags) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "tags", Err: err}) + return + } + + // ------------- Optional query parameter "limit" ------------- + if paramValue := r.URL.Query().Get("limit"); paramValue != "" { + + } + + err = runtime.BindQueryParameter("form", true, false, "limit", r.URL.Query(), ¶ms.Limit) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "limit", Err: err}) + return + } + + var handler = func(w http.ResponseWriter, r *http.Request) { + siw.Handler.FindPets(w, r, params) + } + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler(w, r.WithContext(ctx)) +} + +// AddPet operation middleware +func (siw *ServerInterfaceWrapper) AddPet(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var handler = func(w http.ResponseWriter, r *http.Request) { + siw.Handler.AddPet(w, r) + } + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler(w, r.WithContext(ctx)) +} + +// DeletePet operation middleware +func (siw *ServerInterfaceWrapper) DeletePet(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var err error + + // ------------- Path parameter "id" ------------- + var id int64 + + err = runtime.BindStyledParameter("simple", false, "id", mux.Vars(r)["id"], &id) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "id", Err: err}) + return + } + + var handler = func(w http.ResponseWriter, r *http.Request) { + siw.Handler.DeletePet(w, r, id) + } + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler(w, r.WithContext(ctx)) +} + +// FindPetByID operation middleware +func (siw *ServerInterfaceWrapper) FindPetByID(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var err error + + // ------------- Path parameter "id" ------------- + var id int64 + + err = runtime.BindStyledParameter("simple", false, "id", mux.Vars(r)["id"], &id) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "id", Err: err}) + return + } + + var handler = func(w http.ResponseWriter, r *http.Request) { + siw.Handler.FindPetByID(w, r, id) + } + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler(w, r.WithContext(ctx)) +} + +type UnescapedCookieParamError struct { + ParamName string + Err error +} + +func (e *UnescapedCookieParamError) Error() string { + return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName) +} + +func (e *UnescapedCookieParamError) Unwrap() error { + return e.Err +} + +type UnmarshalingParamError struct { + ParamName string + Err error +} + +func (e *UnmarshalingParamError) Error() string { + return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error()) +} + +func (e *UnmarshalingParamError) Unwrap() error { + return e.Err +} + +type RequiredParamError struct { + ParamName string +} + +func (e *RequiredParamError) Error() string { + return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName) +} + +type RequiredHeaderError struct { + ParamName string + Err error +} + +func (e *RequiredHeaderError) Error() string { + return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName) +} + +func (e *RequiredHeaderError) Unwrap() error { + return e.Err +} + +type InvalidParamFormatError struct { + ParamName string + Err error +} + +func (e *InvalidParamFormatError) Error() string { + return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error()) +} + +func (e *InvalidParamFormatError) Unwrap() error { + return e.Err +} + +type TooManyValuesForParamError struct { + ParamName string + Count int +} + +func (e *TooManyValuesForParamError) Error() string { + return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count) +} + +// Handler creates http.Handler with routing matching OpenAPI spec. +func Handler(si ServerInterface) http.Handler { + return HandlerWithOptions(si, GorillaServerOptions{}) +} + +type GorillaServerOptions struct { + BaseURL string + BaseRouter *mux.Router + Middlewares []MiddlewareFunc + ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error) +} + +// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux. +func HandlerFromMux(si ServerInterface, r *mux.Router) http.Handler { + return HandlerWithOptions(si, GorillaServerOptions{ + BaseRouter: r, + }) +} + +func HandlerFromMuxWithBaseURL(si ServerInterface, r *mux.Router, baseURL string) http.Handler { + return HandlerWithOptions(si, GorillaServerOptions{ + BaseURL: baseURL, + BaseRouter: r, + }) +} + +// HandlerWithOptions creates http.Handler with additional options +func HandlerWithOptions(si ServerInterface, options GorillaServerOptions) http.Handler { + r := options.BaseRouter + + if r == nil { + r = mux.NewRouter() + } + if options.ErrorHandlerFunc == nil { + options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) { + http.Error(w, err.Error(), http.StatusBadRequest) + } + } + wrapper := ServerInterfaceWrapper{ + Handler: si, + HandlerMiddlewares: options.Middlewares, + ErrorHandlerFunc: options.ErrorHandlerFunc, + } + + r.HandleFunc(options.BaseURL+"/pets", wrapper.FindPets).Methods("GET") + + r.HandleFunc(options.BaseURL+"/pets", wrapper.AddPet).Methods("POST") + + r.HandleFunc(options.BaseURL+"/pets/{id}", wrapper.DeletePet).Methods("DELETE") + + r.HandleFunc(options.BaseURL+"/pets/{id}", wrapper.FindPetByID).Methods("GET") + + return r +} + +// Base64 encoded, gzipped, json marshaled Swagger object +var swaggerSpec = []string{ + + "H4sIAAAAAAAC/+RXW48budH9KwV+32OnNbEXedBTvB4vICBrT+LdvKznoYZdkmrBSw9Z1FgY6L8HRbZu", + "I3k2QYIgQV506WY1T51zqlj9bGz0YwwUJJv5s8l2TR7rzw8pxaQ/xhRHSsJUL9s4kH4PlG3iUTgGM2+L", + "od7rzDImj2LmhoO8fWM6I9uR2l9aUTK7znjKGVfffND+9iE0S+KwMrtdZxI9Fk40mPkvZtpwv/x+15mP", + "9HRHcok7oL+y3Uf0BHEJsiYYSS437Izg6jLup+34etwLoHV3hTdhQ+c+Lc38l2fz/4mWZm7+b3YUYjap", + "MJty2XUvk+HhEtLPgR8LAQ/nuE7F+MN3V8R4gZQHc7+73+llDsvYJA+CtuImj+zM3ODIQuj/mJ9wtaLU", + "czTdRLH53K7Bu7sF/EToTWdK0qC1yDifzU5idt2LJN5BRj86qsGyRoGSKQNqMlliIsAMGIC+tmUSYSAf", + "Q5aEQrAklJIoA4dKwaeRgj7pbX8DeSTLS7ZYt+qMY0sh09Eb5t2Idk3wpr85g5zns9nT01OP9XYf02o2", + "xebZnxbvP3z8/OF3b/qbfi3eVcNQ8vnT8jOlDVu6lvesLpmpGCzulLO7KU3TmQ2l3Ej5fX/T3+iT40gB", + "RzZz87Ze6syIsq6OmClB+mPVDHZO619ISgoZ0LnKJCxT9JWhvM1CvlGt/0umBGsl2VrKGSR+CR/RQ6YB", + "bAwDewpSPFCWHn5EshQwg5AfY4KMKxbhDBlHptBBIAtpHYMtGTL5kwUsgJ6kh3cUCAOgwCrhhgcELKtC", + "HaAFRlsc19Ae3peEDywlQRw4gouJfAcxBUwEtCIBcjShC2Q7sCXlkrUgHFkpuYfbwhk8g5Q0cu5gLG7D", + "AZPuRSlq0h0IB8tDCQIbTFwy/FqyxB4WAdZoYa0gMGeC0aEQwsBWilc6Fq2kNBcceORsOawAg2g2x9wd", + "r4rDQ+bjGhNJwj2Juh58dJSFCdiPlAZWpv7KG/QtIXT8WNDDwKjMJMzwqLltyLFAiAEkJolJKeElheGw", + "ew93CSlTEIVJgf0RQEkBYRNdkREFNhQooAJu5OqHx5L0GYtwfPKS0sT6Ei07zmeb1B30ozvqayHHAR2p", + "sEOnPFpKKJqYfvfwueSRwsDKskM1zxBdTJ06MJMVdXPNslpFs+5gQ2u2xSFoY0tD8eD4gVLs4ceYHhio", + "cPZxOJVBb1djO7QcGPsv4Uv4TENVomRYkprPxYeYagDFo2NSkVR8D1obHusDJ/I5uw6onFVLkxxcUR+q", + "O3u4W2Mm51phjJSm8EpzlZcEllgsP5RGOO730XWn8Rtyk3S8oZSwO99a6wR46A6FGPhh3cPPAiM5R0Eo", + "67kxxlxIK2lfRD0oFbivAi26PZf7J+3Tqkx2FcjBFqEEC5I4Sz2WNixIPfxQsiUgqd1gKHyoAu0U2ZKj", + "xBVO8+8+wKtbClbz2OIzBvC40pTJTWr18OfSQn10qltTj0rzzhFKd2g+gMVqkbSVkz1b2pM5piZzqEY1", + "iwoMHLojlKlwA2feA86KwbKUgRVqzghF9j6bhGw7nZFW9+vh7lSYytyEcUwkXPxJ52qmKd2Jv7X19l/0", + "iNORoR53i8HMzQ8cBj1f6rGRlABKuc4g54eF4Er7PizZCSV42BodBczcPBZK2+M5r+tMN42MdSoR8vUM", + "upyh2gVMCbf6P8u2Hns6nNTx5hyBx6/stY0X/0BJ55lEuTipsFI9y76BybFnOQP1m8Po7l4HoDxqa6no", + "39zc7KceCm1aG0c3DQ6zX7NCfL6W9mujXJvjXhCxu5h/RhLYg2nT0RKLk38Iz2sw2lB/ZeMS6OuorVV7", + "cFvTmVy8x7S9MkAotjHmK6PG+0QodWQL9KRr97NYnWv0DG7YdYmOc87FJxouzPpuUK+aNptSlu/jsP2X", + "sbCfqy9puCNRj+Ew6NcBtjmdkSUV2v2TnvlNq/z3WONC8Hq/zqOzZx52zSKO5MrrV7uusZnDytV3FnhA", + "bbOxuWZxC7loTlc8clujm01e7WiLW+0hY9N2wjL1Dx2gj+2Dhwulv9VLrr9LXfaS7y6zViANxfCfJOTt", + "QYyqwhYWtwrv9ReKc8UOOi5uv3X8fL+t9/5+vZYkdv1vk+t/toxfKNrUr0sobfYynb3H71/J+5MXW307", + "3d3v/hYAAP//wO3O5VcSAAA=", +} + +// GetSwagger returns the content of the embedded swagger specification file +// or error if failed to decode +func decodeSpec() ([]byte, error) { + zipped, err := base64.StdEncoding.DecodeString(strings.Join(swaggerSpec, "")) + if err != nil { + return nil, fmt.Errorf("error base64 decoding spec: %s", err) + } + zr, err := gzip.NewReader(bytes.NewReader(zipped)) + if err != nil { + return nil, fmt.Errorf("error decompressing spec: %s", err) + } + var buf bytes.Buffer + _, err = buf.ReadFrom(zr) + if err != nil { + return nil, fmt.Errorf("error decompressing spec: %s", err) + } + + return buf.Bytes(), nil +} + +var rawSpec = decodeSpecCached() + +// a naive cached of a decoded swagger spec +func decodeSpecCached() func() ([]byte, error) { + data, err := decodeSpec() + return func() ([]byte, error) { + return data, err + } +} + +// Constructs a synthetic filesystem for resolving external references when loading openapi specifications. +func PathToRawSpec(pathToFile string) map[string]func() ([]byte, error) { + var res = make(map[string]func() ([]byte, error)) + if len(pathToFile) > 0 { + res[pathToFile] = rawSpec + } + + return res +} + +// GetSwagger returns the Swagger specification corresponding to the generated code +// in this file. The external references of Swagger specification are resolved. +// The logic of resolving external references is tightly connected to "import-mapping" feature. +// Externally referenced files must be embedded in the corresponding golang packages. +// Urls can be supported but this task was out of the scope. +func GetSwagger() (swagger *openapi3.T, err error) { + var resolvePath = PathToRawSpec("") + + loader := openapi3.NewLoader() + loader.IsExternalRefsAllowed = true + loader.ReadFromURIFunc = func(loader *openapi3.Loader, url *url.URL) ([]byte, error) { + var pathToFile = url.String() + pathToFile = path.Clean(pathToFile) + getSpec, ok := resolvePath[pathToFile] + if !ok { + err1 := fmt.Errorf("path not found: %s", pathToFile) + return nil, err1 + } + return getSpec() + } + var specData []byte + specData, err = rawSpec() + if err != nil { + return + } + swagger, err = loader.LoadFromData(specData) + if err != nil { + return + } + return +} diff --git a/examples/petstore-expanded/gorilla/api/petstore.go b/examples/petstore-expanded/gorilla/api/petstore.go new file mode 100644 index 000000000..ea8c34af6 --- /dev/null +++ b/examples/petstore-expanded/gorilla/api/petstore.go @@ -0,0 +1,128 @@ +//go:generate go run github.com/deepmap/oapi-codegen/cmd/oapi-codegen --config=cfg.yaml ../../petstore-expanded.yaml + +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "sync" +) + +type PetStore struct { + Pets map[int64]Pet + NextId int64 + Lock sync.Mutex +} + +// Make sure we conform to ServerInterface + +var _ ServerInterface = (*PetStore)(nil) + +func NewPetStore() *PetStore { + return &PetStore{ + Pets: make(map[int64]Pet), + NextId: 1000, + } +} + +// This function wraps sending of an error in the Error format, and +// handling the failure to marshal that. +func sendPetstoreError(w http.ResponseWriter, code int, message string) { + petErr := Error{ + Code: int32(code), + Message: message, + } + w.WriteHeader(code) + json.NewEncoder(w).Encode(petErr) +} + +// Here, we implement all of the handlers in the ServerInterface +func (p *PetStore) FindPets(w http.ResponseWriter, r *http.Request, params FindPetsParams) { + p.Lock.Lock() + defer p.Lock.Unlock() + + var result []Pet + + for _, pet := range p.Pets { + if params.Tags != nil { + // If we have tags, filter pets by tag + for _, t := range *params.Tags { + if pet.Tag != nil && (*pet.Tag == t) { + result = append(result, pet) + } + } + } else { + // Add all pets if we're not filtering + result = append(result, pet) + } + + if params.Limit != nil { + l := int(*params.Limit) + if len(result) >= l { + // We're at the limit + break + } + } + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(result) +} + +func (p *PetStore) AddPet(w http.ResponseWriter, r *http.Request) { + // We expect a NewPet object in the request body. + var newPet NewPet + if err := json.NewDecoder(r.Body).Decode(&newPet); err != nil { + sendPetstoreError(w, http.StatusBadRequest, "Invalid format for NewPet") + return + } + + // We now have a pet, let's add it to our "database". + + // We're always asynchronous, so lock unsafe operations below + p.Lock.Lock() + defer p.Lock.Unlock() + + // We handle pets, not NewPets, which have an additional ID field + var pet Pet + pet.Name = newPet.Name + pet.Tag = newPet.Tag + pet.Id = p.NextId + p.NextId = p.NextId + 1 + + // Insert into map + p.Pets[pet.Id] = pet + + // Now, we have to return the NewPet + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(pet) +} + +func (p *PetStore) FindPetByID(w http.ResponseWriter, r *http.Request, id int64) { + p.Lock.Lock() + defer p.Lock.Unlock() + + pet, found := p.Pets[id] + if !found { + sendPetstoreError(w, http.StatusNotFound, fmt.Sprintf("Could not find pet with ID %d", id)) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(pet) +} + +func (p *PetStore) DeletePet(w http.ResponseWriter, r *http.Request, id int64) { + p.Lock.Lock() + defer p.Lock.Unlock() + + _, found := p.Pets[id] + if !found { + sendPetstoreError(w, http.StatusNotFound, fmt.Sprintf("Could not find pet with ID %d", id)) + return + } + delete(p.Pets, id) + + w.WriteHeader(http.StatusNoContent) +} diff --git a/examples/petstore-expanded/gorilla/petstore.go b/examples/petstore-expanded/gorilla/petstore.go new file mode 100644 index 000000000..b31239ba9 --- /dev/null +++ b/examples/petstore-expanded/gorilla/petstore.go @@ -0,0 +1,53 @@ +// This is an example of implementing the Pet Store from the OpenAPI documentation +// found at: +// https://github.com/OAI/OpenAPI-Specification/blob/master/examples/v3.0/petstore.yaml + +package main + +import ( + "flag" + "fmt" + "log" + "net/http" + "os" + + api "github.com/deepmap/oapi-codegen/examples/petstore-expanded/gorilla/api" + middleware "github.com/deepmap/oapi-codegen/pkg/chi-middleware" + "github.com/gorilla/mux" +) + +func main() { + var port = flag.Int("port", 8080, "Port for test HTTP server") + flag.Parse() + + swagger, err := api.GetSwagger() + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading swagger spec\n: %s", err) + os.Exit(1) + } + + // Clear out the servers array in the swagger spec, that skips validating + // that server names match. We don't know how this thing will be run. + swagger.Servers = nil + + // Create an instance of our handler which satisfies the generated interface + petStore := api.NewPetStore() + + // This is how you set up a basic Gorilla router + r := mux.NewRouter() + + // Use our validation middleware to check all requests against the + // OpenAPI schema. + r.Use(middleware.OapiRequestValidator(swagger)) + + // We now register our petStore above as the handler for the interface + api.HandlerFromMux(petStore, r) + + s := &http.Server{ + Handler: r, + Addr: fmt.Sprintf("0.0.0.0:%d", *port), + } + + // And we serve HTTP until the world ends. + log.Fatal(s.ListenAndServe()) +} diff --git a/examples/petstore-expanded/gorilla/petstore_test.go b/examples/petstore-expanded/gorilla/petstore_test.go new file mode 100644 index 000000000..c32532607 --- /dev/null +++ b/examples/petstore-expanded/gorilla/petstore_test.go @@ -0,0 +1,168 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/deepmap/oapi-codegen/examples/petstore-expanded/gorilla/api" + middleware "github.com/deepmap/oapi-codegen/pkg/chi-middleware" + "github.com/deepmap/oapi-codegen/pkg/testutil" +) + +func doGet(t *testing.T, mux *mux.Router, url string) *httptest.ResponseRecorder { + response := testutil.NewRequest().Get(url).WithAcceptJson().GoWithHTTPHandler(t, mux) + return response.Recorder +} + +func TestPetStore(t *testing.T) { + var err error + + // Get the swagger description of our API + swagger, err := api.GetSwagger() + require.NoError(t, err) + + // Clear out the servers array in the swagger spec, that skips validating + // that server names match. We don't know how this thing will be run. + swagger.Servers = nil + + // This is how you set up a basic Gorilla router + r := mux.NewRouter() + + // Use our validation middleware to check all requests against the + // OpenAPI schema. + r.Use(middleware.OapiRequestValidator(swagger)) + + store := api.NewPetStore() + api.HandlerFromMux(store, r) + + t.Run("Add pet", func(t *testing.T) { + tag := "TagOfSpot" + newPet := api.NewPet{ + Name: "Spot", + Tag: &tag, + } + + rr := testutil.NewRequest().Post("/pets").WithJsonBody(newPet).GoWithHTTPHandler(t, r).Recorder + assert.Equal(t, http.StatusCreated, rr.Code) + + var resultPet api.Pet + err = json.NewDecoder(rr.Body).Decode(&resultPet) + assert.NoError(t, err, "error unmarshaling response") + assert.Equal(t, newPet.Name, resultPet.Name) + assert.Equal(t, *newPet.Tag, *resultPet.Tag) + }) + + t.Run("Find pet by ID", func(t *testing.T) { + pet := api.Pet{ + Id: 100, + } + + store.Pets[pet.Id] = pet + rr := doGet(t, r, fmt.Sprintf("/pets/%d", pet.Id)) + + var resultPet api.Pet + err = json.NewDecoder(rr.Body).Decode(&resultPet) + assert.NoError(t, err, "error getting pet") + assert.Equal(t, pet, resultPet) + }) + + t.Run("Pet not found", func(t *testing.T) { + rr := doGet(t, r, "/pets/27179095781") + assert.Equal(t, http.StatusNotFound, rr.Code) + + var petError api.Error + err = json.NewDecoder(rr.Body).Decode(&petError) + assert.NoError(t, err, "error getting response", err) + assert.Equal(t, int32(http.StatusNotFound), petError.Code) + }) + + t.Run("List all pets", func(t *testing.T) { + store.Pets = map[int64]api.Pet{ + 1: api.Pet{}, + 2: api.Pet{}, + } + + // Now, list all pets, we should have two + rr := doGet(t, r, "/pets") + assert.Equal(t, http.StatusOK, rr.Code) + + var petList []api.Pet + err = json.NewDecoder(rr.Body).Decode(&petList) + assert.NoError(t, err, "error getting response", err) + assert.Equal(t, 2, len(petList)) + }) + + t.Run("Filter pets by tag", func(t *testing.T) { + tag := "TagOfFido" + + store.Pets = map[int64]api.Pet{ + 1: { + Tag: &tag, + }, + 2: {}, + } + + // Filter pets by tag, we should have 1 + rr := doGet(t, r, "/pets?tags=TagOfFido") + assert.Equal(t, http.StatusOK, rr.Code) + + var petList []api.Pet + err = json.NewDecoder(rr.Body).Decode(&petList) + assert.NoError(t, err, "error getting response", err) + assert.Equal(t, 1, len(petList)) + }) + + t.Run("Filter pets by tag", func(t *testing.T) { + store.Pets = map[int64]api.Pet{ + 1: api.Pet{}, + 2: api.Pet{}, + } + + // Filter pets by non existent tag, we should have 0 + rr := doGet(t, r, "/pets?tags=NotExists") + assert.Equal(t, http.StatusOK, rr.Code) + + var petList []api.Pet + err = json.NewDecoder(rr.Body).Decode(&petList) + assert.NoError(t, err, "error getting response", err) + assert.Equal(t, 0, len(petList)) + }) + + t.Run("Delete pets", func(t *testing.T) { + store.Pets = map[int64]api.Pet{ + 1: api.Pet{}, + 2: api.Pet{}, + } + + // Let's delete non-existent pet + rr := testutil.NewRequest().Delete("/pets/7").GoWithHTTPHandler(t, r).Recorder + assert.Equal(t, http.StatusNotFound, rr.Code) + + var petError api.Error + err = json.NewDecoder(rr.Body).Decode(&petError) + assert.NoError(t, err, "error unmarshaling PetError") + assert.Equal(t, int32(http.StatusNotFound), petError.Code) + + // Now, delete both real pets + rr = testutil.NewRequest().Delete("/pets/1").GoWithHTTPHandler(t, r).Recorder + assert.Equal(t, http.StatusNoContent, rr.Code) + + rr = testutil.NewRequest().Delete("/pets/2").GoWithHTTPHandler(t, r).Recorder + assert.Equal(t, http.StatusNoContent, rr.Code) + + // Should have no pets left. + var petList []api.Pet + rr = doGet(t, r, "/pets") + assert.Equal(t, http.StatusOK, rr.Code) + err = json.NewDecoder(rr.Body).Decode(&petList) + assert.NoError(t, err, "error getting response", err) + assert.Equal(t, 0, len(petList)) + }) +} diff --git a/go.mod b/go.mod index 7d46fee5b..6e2b42be6 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/golangci/lint-1 v0.0.0-20181222135242-d2cdd8c08219 github.com/google/uuid v1.3.0 + github.com/gorilla/mux v1.8.0 github.com/json-iterator/go v1.1.12 // indirect github.com/labstack/echo/v4 v4.7.2 github.com/lestrrat-go/blackmagic v1.0.1 // indirect