diff --git a/routers/gorillamux/router.go b/routers/gorillamux/router.go index 811ba7d16..6977808ed 100644 --- a/routers/gorillamux/router.go +++ b/routers/gorillamux/router.go @@ -57,8 +57,6 @@ func NewRouter(doc *openapi3.T) (routers.Router, error) { muxRouter := mux.NewRouter().UseEncodedPath() r := &Router{} for _, path := range orderedPaths(doc.Paths) { - servers := servers - pathItem := doc.Paths[path] if len(pathItem.Servers) > 0 { if servers, err = makeServers(pathItem.Servers); err != nil { @@ -140,19 +138,13 @@ func makeServers(in openapi3.Servers) ([]srv, error) { if lhs := strings.TrimSuffix(serverURL, server.Variables[sVar].Default); lhs != "" { varsUpdater = func(vars map[string]string) { vars[sVar] = lhs } } - servers = append(servers, srv{ - base: server.Variables[sVar].Default, - server: server, - varsUpdater: varsUpdater, - }) - continue - } + svr, err := newSrv(serverURL, server, varsUpdater) + if err != nil { + return nil, err + } - var schemes []string - if strings.Contains(serverURL, "://") { - scheme0 := strings.Split(serverURL, "://")[0] - schemes = permutePart(scheme0, server) - serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1) + servers = append(servers, svr) + continue } // If a variable represents the port "http://domain.tld:{port}/bla" @@ -172,21 +164,11 @@ func makeServers(in openapi3.Servers) ([]srv, error) { } } - u, err := url.Parse(bEncode(serverURL)) + svr, err := newSrv(serverURL, server, varsUpdater) if err != nil { return nil, err } - path := bDecode(u.EscapedPath()) - if len(path) > 0 && path[len(path)-1] == '/' { - path = path[:len(path)-1] - } - servers = append(servers, srv{ - host: bDecode(u.Host), //u.Hostname()? - base: path, - schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624 - server: server, - varsUpdater: varsUpdater, - }) + servers = append(servers, svr) } if len(servers) == 0 { servers = append(servers, srv{}) @@ -195,6 +177,32 @@ func makeServers(in openapi3.Servers) ([]srv, error) { return servers, nil } +func newSrv(serverURL string, server *openapi3.Server, varsUpdater varsf) (srv, error) { + var schemes []string + if strings.Contains(serverURL, "://") { + scheme0 := strings.Split(serverURL, "://")[0] + schemes = permutePart(scheme0, server) + serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1) + } + + u, err := url.Parse(bEncode(serverURL)) + if err != nil { + return srv{}, err + } + path := bDecode(u.EscapedPath()) + if len(path) > 0 && path[len(path)-1] == '/' { + path = path[:len(path)-1] + } + svr := srv{ + host: bDecode(u.Host), //u.Hostname()? + base: path, + schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624 + server: server, + varsUpdater: varsUpdater, + } + return svr, nil +} + func orderedPaths(paths map[string]*openapi3.PathItem) []string { // https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#pathsObject // When matching URLs, concrete (non-templated) paths would be matched diff --git a/routers/gorillamux/router_test.go b/routers/gorillamux/router_test.go index 104056e18..3e7440063 100644 --- a/routers/gorillamux/router_test.go +++ b/routers/gorillamux/router_test.go @@ -6,6 +6,7 @@ import ( "sort" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/getkin/kin-openapi/openapi3" @@ -249,7 +250,16 @@ func TestServerPath(t *testing.T) { "http://example.com:{port}/path", map[string]string{ "port": "8088", - })}, + }), + newServerWithVariables( + "{server}", + map[string]string{ + "server": "/", + }), + newServerWithVariables( + "/", + nil, + )}, }) require.NoError(t, err) } @@ -325,6 +335,157 @@ func TestRelativeURL(t *testing.T) { require.Equal(t, "/hello", route.Path) } +func Test_makeServers(t *testing.T) { + type testStruct struct { + name string + servers openapi3.Servers + want []srv + wantErr bool + initFn func(tt *testStruct) + } + tests := []testStruct{ + { + name: "server is root path", + servers: openapi3.Servers{ + newServerWithVariables("/", nil), + }, + want: []srv{{ + schemes: nil, + host: "", + base: "", + server: nil, + varsUpdater: nil, + }}, + wantErr: false, + initFn: func(tt *testStruct) { + for i, server := range tt.servers { + tt.want[i].server = server + } + }, + }, + { + name: "server with single variable that evaluates to root path", + servers: openapi3.Servers{ + newServerWithVariables("{server}", map[string]string{"server": "/"}), + }, + want: []srv{{ + schemes: nil, + host: "", + base: "", + server: nil, + varsUpdater: nil, + }}, + wantErr: false, + initFn: func(tt *testStruct) { + for i, server := range tt.servers { + tt.want[i].server = server + } + }, + }, + { + name: "server is http://localhost:28002", + servers: openapi3.Servers{ + newServerWithVariables("http://localhost:28002", nil), + }, + want: []srv{{ + schemes: []string{"http"}, + host: "localhost:28002", + base: "", + server: nil, + varsUpdater: nil, + }}, + wantErr: false, + initFn: func(tt *testStruct) { + for i, server := range tt.servers { + tt.want[i].server = server + } + }, + }, + { + name: "server with single variable that evaluates to http://localhost:28002", + servers: openapi3.Servers{ + newServerWithVariables("{server}", map[string]string{"server": "http://localhost:28002"}), + }, + want: []srv{{ + schemes: []string{"http"}, + host: "localhost:28002", + base: "", + server: nil, + varsUpdater: nil, + }}, + wantErr: false, + initFn: func(tt *testStruct) { + for i, server := range tt.servers { + tt.want[i].server = server + } + }, + }, + { + name: "server with multiple variables that evaluates to http://localhost:28002", + servers: openapi3.Servers{ + newServerWithVariables("{scheme}://{host}:{port}", map[string]string{"scheme": "http", "host": "localhost", "port": "28002"}), + }, + want: []srv{{ + schemes: []string{"http"}, + host: "{host}:28002", + base: "", + server: nil, + varsUpdater: func(vars map[string]string) { vars["port"] = "28002" }, + }}, + wantErr: false, + initFn: func(tt *testStruct) { + for i, server := range tt.servers { + tt.want[i].server = server + } + }, + }, + { + name: "server with unparsable URL fails", + servers: openapi3.Servers{ + newServerWithVariables("exam^ple.com:443", nil), + }, + want: nil, + wantErr: true, + initFn: nil, + }, + { + name: "server with single variable that evaluates to unparsable URL fails", + servers: openapi3.Servers{ + newServerWithVariables("{server}", map[string]string{"server": "exam^ple.com:443"}), + }, + want: nil, + wantErr: true, + initFn: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.initFn != nil { + tt.initFn(&tt) + } + got, err := makeServers(tt.servers) + if (err != nil) != tt.wantErr { + t.Errorf("makeServers() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, len(tt.want), len(got), "expected and actual servers lengths are not equal") + for i := 0; i < len(tt.want); i++ { + // Unfortunately using assert.Equals or reflect.DeepEquals isn't + // an option because function pointers cannot be compared + assert.Equal(t, tt.want[i].schemes, got[i].schemes) + assert.Equal(t, tt.want[i].host, got[i].host) + assert.Equal(t, tt.want[i].host, got[i].host) + assert.Equal(t, tt.want[i].server, got[i].server) + if tt.want[i].varsUpdater == nil { + assert.Nil(t, got[i].varsUpdater, "expected and actual varsUpdater should point to same function") + } else { + assert.NotNil(t, got[i].varsUpdater, "expected and actual varsUpdater should point to same function") + } + } + }) + } +} + func newServerWithVariables(url string, variables map[string]string) *openapi3.Server { var serverVariables = map[string]*openapi3.ServerVariable{}