-
-
Notifications
You must be signed in to change notification settings - Fork 410
/
router.go
210 lines (196 loc) · 5.55 KB
/
router.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
// Package gorillamux implements a router.
//
// It differs from the legacy router:
// * it provides somewhat granular errors: "path not found", "method not allowed".
// * it handles matching routes with extensions (e.g. /books/{id}.json)
// * it handles path patterns with a different syntax (e.g. /params/{x}/{y}/{z:.*})
package gorillamux
import (
"net/http"
"net/url"
"sort"
"strings"
"github.com/gorilla/mux"
"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/routers"
)
// Compile-time proof that *Router implements routers.Router.
var _ routers.Router = (*Router)(nil)

// Router matches incoming http.Requests against the paths of an OpenAPIv3
// document. It is built by NewRouter; muxes and routes are parallel slices.
type Router struct {
	// muxes[i] is the gorilla/mux matcher for routes[i]; FindRoute tries them in order.
	muxes []*mux.Route

	// routes[i] is the route template handed back (as a copy) when muxes[i] matches.
	routes []*routers.Route
}
// NewRouter builds a gorilla/mux backed router for doc.
//
// For every path of doc (in the order given by orderedPaths) and every server
// of doc, one mux route is registered: the server's base path is prefixed to
// the OpenAPI path, the route is restricted to the path's declared methods,
// and, when the server URL provides them, to its host and schemes. If doc
// declares no servers, paths are mounted at the root with no host or scheme
// restriction.
//
// doc is assumed to have been .Validate()d. An error is returned when a server
// URL cannot be parsed or gorilla/mux rejects a route template.
// TODO: Handle/HandlerFunc + ServeHTTP (When there is a match, the route variables can be retrieved calling mux.Vars(request))
func NewRouter(doc *openapi3.T) (routers.Router, error) {
	// target describes where the routes of one OpenAPI server are mounted.
	type target struct {
		schemes    []string
		host, base string
		server     *openapi3.Server
	}

	targets := make([]target, 0, len(doc.Servers))
	for _, server := range doc.Servers {
		rawURL := server.URL
		var schemes []string
		if i := strings.Index(rawURL, "://"); i >= 0 {
			schemes = permutePart(rawURL[:i], server)
			// Parse using one concrete scheme; the full list is matched by mux.
			rawURL = schemes[0] + rawURL[i:]
		}
		u, err := url.Parse(bEncode(rawURL))
		if err != nil {
			return nil, err
		}
		targets = append(targets, target{
			schemes: schemes, // TODO: https://github.com/gorilla/mux/issues/624
			host:    bDecode(u.Host), //u.Hostname()?
			base:    strings.TrimSuffix(bDecode(u.EscapedPath()), "/"),
			server:  server,
		})
	}
	if len(targets) == 0 {
		// No servers declared: a zero target adds no base, host or scheme constraint.
		targets = append(targets, target{})
	}

	muxRouter := mux.NewRouter().UseEncodedPath()
	router := &Router{}
	for _, path := range orderedPaths(doc.Paths) {
		pathItem := doc.Paths[path]

		ops := pathItem.Operations()
		methods := make([]string, 0, len(ops))
		for method := range ops {
			methods = append(methods, method)
		}
		sort.Strings(methods)

		for _, t := range targets {
			muxRoute := muxRouter.Path(t.base + path).Methods(methods...)
			if len(t.schemes) != 0 {
				muxRoute.Schemes(t.schemes...)
			}
			if t.host != "" {
				muxRoute.Host(t.host)
			}
			if err := muxRoute.GetError(); err != nil {
				return nil, err
			}
			router.muxes = append(router.muxes, muxRoute)
			router.routes = append(router.routes, &routers.Route{
				Spec:     doc,
				Server:   t.server,
				Path:     path,
				PathItem: pathItem,
				// Method and Operation stay zero here; FindRoute fills them per request.
			})
		}
	}
	return router, nil
}
// FindRoute returns the first registered route that matches req, a copy whose
// Method and Operation are set from req.Method, along with the template
// variables captured by gorilla/mux (mux.RouteMatch.Vars).
//
// Routes are tried in registration order. If a route reports
// mux.ErrMethodMismatch before any route matches, routers.ErrMethodNotAllowed
// is returned immediately; if nothing matches, routers.ErrPathNotFound is.
func (r *Router) FindRoute(req *http.Request) (*routers.Route, map[string]string, error) {
	for i, muxRoute := range r.muxes {
		var match mux.RouteMatch
		if !muxRoute.Match(req, &match) {
			if match.MatchErr == mux.ErrMethodMismatch {
				return nil, nil, routers.ErrMethodNotAllowed
			}
			continue
		}

		// Copy the stored template so per-request fields never leak between calls.
		route := *r.routes[i]
		route.Method = req.Method
		route.Operation = route.Spec.Paths[route.Path].GetOperation(req.Method)
		return &route, match.Vars, nil
	}
	return nil, nil, routers.ErrPathNotFound
}
// orderedPaths returns the keys of paths in the order routes should be tried.
//
// https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#pathsObject
// When matching URLs, concrete (non-templated) paths would be matched
// before their templated counterparts.
// Heuristic used: fewer template variables (counted as "}" characters) come
// first; ties are broken by descending lexicographic order, which e.g. places
// "/books/{id}.json" ahead of "/books/{id}".
func orderedPaths(paths map[string]*openapi3.PathItem) []string {
	keys := make([]string, 0, len(paths))
	for p := range paths {
		keys = append(keys, p)
	}
	sort.Slice(keys, func(i, j int) bool {
		vi, vj := strings.Count(keys[i], "}"), strings.Count(keys[j], "}")
		if vi != vj {
			return vi < vj
		}
		return keys[i] > keys[j]
	})
	return keys
}
// Magic strings that temporarily replace "{" and "}" so net/url.Parse()
// accepts templated server URLs. They only use characters url.Parse allows in
// both hosts and paths.
//
// Each sentinel is a run of one filler character terminated by a different
// letter, so it has no border (no proper prefix that is also a suffix). That
// is what makes decoding unambiguous: a literal '-' or '_' adjacent to a brace
// in the original URL can no longer merge with the sentinel. With plain runs
// of 50 identical characters, "api-{region}" used to decode to "api{-region}".
var blURL, brURL = strings.Repeat("-", 49) + "L", strings.Repeat("_", 49) + "R"

// bEncode replaces every "{" in s with blURL and every "}" with brURL.
func bEncode(s string) string {
	s = strings.ReplaceAll(s, "{", blURL)
	return strings.ReplaceAll(s, "}", brURL)
}

// bDecode reverses bEncode, turning blURL back into "{" and brURL into "}".
func bDecode(s string) string {
	s = strings.ReplaceAll(s, blURL, "{")
	return strings.ReplaceAll(s, brURL, "}")
}
func permutePart(part0 string, srv *openapi3.Server) []string {
type mapAndSlice struct {
m map[string]struct{}
s []string
}
var2val := make(map[string]mapAndSlice)
max := 0
for name0, v := range srv.Variables {
name := "{" + name0 + "}"
if !strings.Contains(part0, name) {
continue
}
m := map[string]struct{}{v.Default: {}}
for _, value := range v.Enum {
m[value] = struct{}{}
}
if l := len(m); l > max {
max = l
}
s := make([]string, 0, len(m))
for value := range m {
s = append(s, value)
}
var2val[name] = mapAndSlice{m: m, s: s}
}
if len(var2val) == 0 {
return []string{part0}
}
partsMap := make(map[string]struct{}, max*len(var2val))
for i := 0; i < max; i++ {
part := part0
for name, mas := range var2val {
part = strings.Replace(part, name, mas.s[i%len(mas.s)], -1)
}
partsMap[part] = struct{}{}
}
parts := make([]string, 0, len(partsMap))
for part := range partsMap {
parts = append(parts, part)
}
sort.Strings(parts)
return parts
}