Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow different param names in different methods with same path scheme #2209

Merged
merged 5 commits into from Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
197 changes: 100 additions & 97 deletions router.go
Expand Up @@ -19,30 +19,34 @@ type (
prefix string
parent *node
staticChildren children
ppath string
pnames []string
methodHandler *methodHandler
methods *routeMethods
paramChild *node
anyChild *node
paramsCount int
// isLeaf indicates that node does not have child routes
isLeaf bool
// isHandler indicates that node has at least one handler registered to it
isHandler bool
}
kind uint8
children []*node
methodHandler struct {
connect HandlerFunc
delete HandlerFunc
get HandlerFunc
head HandlerFunc
options HandlerFunc
patch HandlerFunc
post HandlerFunc
propfind HandlerFunc
put HandlerFunc
trace HandlerFunc
report HandlerFunc
kind uint8
children []*node
routeMethod struct {
ppath string
pnames []string
handler HandlerFunc
}
routeMethods struct {
connect *routeMethod
delete *routeMethod
get *routeMethod
head *routeMethod
options *routeMethod
patch *routeMethod
post *routeMethod
propfind *routeMethod
put *routeMethod
trace *routeMethod
report *routeMethod
allowHeader string
}
)
Expand All @@ -56,7 +60,7 @@ const (
anyLabel = byte('*')
)

func (m *methodHandler) isHandler() bool {
func (m *routeMethods) isHandler() bool {
return m.connect != nil ||
m.delete != nil ||
m.get != nil ||
Expand All @@ -70,7 +74,7 @@ func (m *methodHandler) isHandler() bool {
m.report != nil
}

func (m *methodHandler) updateAllowHeader() {
func (m *routeMethods) updateAllowHeader() {
buf := new(bytes.Buffer)
buf.WriteString(http.MethodOptions)

Expand Down Expand Up @@ -119,7 +123,7 @@ func (m *methodHandler) updateAllowHeader() {
func NewRouter(e *Echo) *Router {
return &Router{
tree: &node{
methodHandler: new(methodHandler),
methods: new(routeMethods),
},
routes: map[string]*Route{},
echo: e,
Expand Down Expand Up @@ -153,7 +157,7 @@ func (r *Router) Add(method, path string, h HandlerFunc) {
}
j := i + 1

r.insert(method, path[:i], nil, staticKind, "", nil)
r.insert(method, path[:i], staticKind, routeMethod{})
for ; i < lcpIndex && path[i] != '/'; i++ {
}

Expand All @@ -163,23 +167,23 @@ func (r *Router) Add(method, path string, h HandlerFunc) {

if i == lcpIndex {
// path node is last fragment of route path. ie. `/users/:id`
r.insert(method, path[:i], h, paramKind, ppath, pnames)
r.insert(method, path[:i], paramKind, routeMethod{ppath, pnames, h})
} else {
r.insert(method, path[:i], nil, paramKind, "", nil)
r.insert(method, path[:i], paramKind, routeMethod{})
}
} else if path[i] == '*' {
r.insert(method, path[:i], nil, staticKind, "", nil)
r.insert(method, path[:i], staticKind, routeMethod{})
pnames = append(pnames, "*")
r.insert(method, path[:i+1], h, anyKind, ppath, pnames)
r.insert(method, path[:i+1], anyKind, routeMethod{ppath, pnames, h})
}
}

r.insert(method, path, h, staticKind, ppath, pnames)
r.insert(method, path, staticKind, routeMethod{ppath, pnames, h})
}

func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string) {
func (r *Router) insert(method, path string, t kind, rm routeMethod) {
// Adjust max param
paramLen := len(pnames)
paramLen := len(rm.pnames)
if *r.echo.maxParam < paramLen {
*r.echo.maxParam = paramLen
}
Expand Down Expand Up @@ -207,11 +211,10 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
// At root node
currentNode.label = search[0]
currentNode.prefix = search
if h != nil {
if rm.handler != nil {
currentNode.kind = t
currentNode.addHandler(method, h)
currentNode.ppath = ppath
currentNode.pnames = pnames
currentNode.addMethod(method, &rm)
currentNode.paramsCount = len(rm.pnames)
}
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
} else if lcpLen < prefixLen {
Expand All @@ -221,9 +224,8 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.prefix[lcpLen:],
currentNode,
currentNode.staticChildren,
currentNode.methodHandler,
currentNode.ppath,
currentNode.pnames,
currentNode.methods,
currentNode.paramsCount,
currentNode.paramChild,
currentNode.anyChild,
)
Expand All @@ -243,9 +245,8 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.label = currentNode.prefix[0]
currentNode.prefix = currentNode.prefix[:lcpLen]
currentNode.staticChildren = nil
currentNode.methodHandler = new(methodHandler)
currentNode.ppath = ""
currentNode.pnames = nil
currentNode.methods = new(routeMethods)
currentNode.paramsCount = 0
currentNode.paramChild = nil
currentNode.anyChild = nil
currentNode.isLeaf = false
Expand All @@ -257,13 +258,17 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
if lcpLen == searchLen {
// At parent node
currentNode.kind = t
currentNode.addHandler(method, h)
currentNode.ppath = ppath
currentNode.pnames = pnames
if rm.handler != nil {
currentNode.addMethod(method, &rm)
currentNode.paramsCount = len(rm.pnames)
}
} else {
// Create child node
n = newNode(t, search[lcpLen:], currentNode, nil, new(methodHandler), ppath, pnames, nil, nil)
n.addHandler(method, h)
n = newNode(t, search[lcpLen:], currentNode, nil, new(routeMethods), 0, nil, nil)
if rm.handler != nil {
n.addMethod(method, &rm)
n.paramsCount = len(rm.pnames)
}
// Only Static children could reach here
currentNode.addStaticChild(n)
}
Expand All @@ -277,8 +282,12 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
continue
}
// Create child node
n := newNode(t, search, currentNode, nil, new(methodHandler), ppath, pnames, nil, nil)
n.addHandler(method, h)
n := newNode(t, search, currentNode, nil, new(routeMethods), 0, nil, nil)
if rm.handler != nil {
n.addMethod(method, &rm)
n.paramsCount = len(rm.pnames)
}

switch t {
case staticKind:
currentNode.addStaticChild(n)
Expand All @@ -290,28 +299,24 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
} else {
// Node already exists
if h != nil {
currentNode.addHandler(method, h)
currentNode.ppath = ppath
if len(currentNode.pnames) == 0 { // Issue #729
currentNode.pnames = pnames
}
if rm.handler != nil {
currentNode.addMethod(method, &rm)
currentNode.paramsCount = len(rm.pnames)
}
}
return
}
}

func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node {
func newNode(t kind, pre string, p *node, sc children, mh *routeMethods, paramsCount int, paramChildren, anyChildren *node) *node {
return &node{
kind: t,
label: pre[0],
prefix: pre,
parent: p,
staticChildren: sc,
ppath: ppath,
pnames: pnames,
methodHandler: mh,
methods: mh,
paramsCount: paramsCount,
paramChild: paramChildren,
anyChild: anyChildren,
isLeaf: sc == nil && paramChildren == nil && anyChildren == nil,
Expand Down Expand Up @@ -345,64 +350,60 @@ func (n *node) findChildWithLabel(l byte) *node {
return nil
}

func (n *node) addHandler(method string, h HandlerFunc) {
func (n *node) addMethod(method string, h *routeMethod) {
switch method {
case http.MethodConnect:
n.methodHandler.connect = h
n.methods.connect = h
case http.MethodDelete:
n.methodHandler.delete = h
n.methods.delete = h
case http.MethodGet:
n.methodHandler.get = h
n.methods.get = h
case http.MethodHead:
n.methodHandler.head = h
n.methods.head = h
case http.MethodOptions:
n.methodHandler.options = h
n.methods.options = h
case http.MethodPatch:
n.methodHandler.patch = h
n.methods.patch = h
case http.MethodPost:
n.methodHandler.post = h
n.methods.post = h
case PROPFIND:
n.methodHandler.propfind = h
n.methods.propfind = h
case http.MethodPut:
n.methodHandler.put = h
n.methods.put = h
case http.MethodTrace:
n.methodHandler.trace = h
n.methods.trace = h
case REPORT:
n.methodHandler.report = h
n.methods.report = h
}

n.methodHandler.updateAllowHeader()
if h != nil {
n.isHandler = true
} else {
n.isHandler = n.methodHandler.isHandler()
}
n.methods.updateAllowHeader()
n.isHandler = true
}

func (n *node) findHandler(method string) HandlerFunc {
func (n *node) findMethod(method string) *routeMethod {
switch method {
case http.MethodConnect:
return n.methodHandler.connect
return n.methods.connect
case http.MethodDelete:
return n.methodHandler.delete
return n.methods.delete
case http.MethodGet:
return n.methodHandler.get
return n.methods.get
case http.MethodHead:
return n.methodHandler.head
return n.methods.head
case http.MethodOptions:
return n.methodHandler.options
return n.methods.options
case http.MethodPatch:
return n.methodHandler.patch
return n.methods.patch
case http.MethodPost:
return n.methodHandler.post
return n.methods.post
case PROPFIND:
return n.methodHandler.propfind
return n.methods.propfind
case http.MethodPut:
return n.methodHandler.put
return n.methods.put
case http.MethodTrace:
return n.methodHandler.trace
return n.methods.trace
case REPORT:
return n.methodHandler.report
return n.methods.report
default:
return nil
}
Expand Down Expand Up @@ -433,7 +434,7 @@ func (r *Router) Find(method, path string, c Context) {

var (
previousBestMatchNode *node
matchedHandler HandlerFunc
matchedRouteMethod *routeMethod
// search stores the remaining path to check for match. By each iteration we move from start of path to end of the path
// and search value gets shorter and shorter.
search = path
Expand Down Expand Up @@ -529,8 +530,8 @@ func (r *Router) Find(method, path string, c Context) {
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if h := currentNode.findHandler(method); h != nil {
matchedHandler = h
if h := currentNode.findMethod(method); h != nil {
matchedRouteMethod = h
break
}
}
Expand Down Expand Up @@ -569,7 +570,8 @@ func (r *Router) Find(method, path string, c Context) {
if child := currentNode.anyChild; child != nil {
// If any node is found, use remaining path for paramValues
currentNode = child
paramValues[len(currentNode.pnames)-1] = search
paramValues[currentNode.paramsCount-1] = search

// update indexes/search in case we need to backtrack when no handler match is found
paramIndex++
searchIndex += +len(search)
Expand All @@ -580,8 +582,8 @@ func (r *Router) Find(method, path string, c Context) {
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if h := currentNode.findHandler(method); h != nil {
matchedHandler = h
if h := currentNode.findMethod(method); h != nil {
matchedRouteMethod = h
break
}
}
Expand All @@ -604,22 +606,23 @@ func (r *Router) Find(method, path string, c Context) {
return // nothing matched at all
}

if matchedHandler != nil {
ctx.handler = matchedHandler
if matchedRouteMethod != nil {
ctx.handler = matchedRouteMethod.handler
ctx.path = matchedRouteMethod.ppath
ctx.pnames = matchedRouteMethod.pnames
} else {
// use previous match as basis. although we have no matching handler we have path match.
// so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404)
currentNode = previousBestMatchNode

ctx.path = path
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment below. This is incorrect as path must be path of the route not path of the request.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

ctx.handler = NotFoundHandler
if currentNode.isHandler {
ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader)
ctx.Set(ContextKeyHeaderAllow, currentNode.methods.allowHeader)
ctx.handler = MethodNotAllowedHandler
if method == http.MethodOptions {
ctx.handler = optionsMethodHandler(currentNode.methodHandler.allowHeader)
ctx.handler = optionsMethodHandler(currentNode.methods.allowHeader)
}
}
}
ctx.path = currentNode.ppath
ctx.pnames = currentNode.pnames
aldas marked this conversation as resolved.
Show resolved Hide resolved
}