Skip to content

Commit 071a686

Browse files
authored
Merge pull request #15 from vardius/hotfix/middleware-by-path
Fix middleware for wildcard routes
2 parents 0d328e6 + 7eb5db7 commit 071a686

23 files changed

+943
-590
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,8 @@
1414
.glide/
1515

1616
.vscode
17+
.idea
1718

18-
vendor/
19+
vendor/
20+
21+
.history/

README.md

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,23 @@ import (
4242
"github.com/vardius/gorouter/v4/context"
4343
)
4444

45-
func Index(w http.ResponseWriter, r *http.Request) {
46-
fmt.Fprint(w, "Welcome!\n")
45+
func index(w http.ResponseWriter, _ *http.Request) {
46+
if _, err := fmt.Fprint(w, "Welcome!\n"); err != nil {
47+
panic(err)
48+
}
4749
}
4850

49-
func Hello(w http.ResponseWriter, r *http.Request) {
51+
func hello(w http.ResponseWriter, r *http.Request) {
5052
params, _ := context.Parameters(r.Context())
51-
fmt.Fprintf(w, "hello, %s!\n", params.Value("name"))
53+
if _, err := fmt.Fprintf(w, "hello, %s!\n", params.Value("name")); err != nil {
54+
panic(err)
55+
}
5256
}
5357

5458
func main() {
5559
router := gorouter.New()
56-
router.GET("/", http.HandlerFunc(Index))
57-
router.GET("/hello/{name}", http.HandlerFunc(Hello))
60+
router.GET("/", http.HandlerFunc(index))
61+
router.GET("/hello/{name}", http.HandlerFunc(hello))
5862

5963
log.Fatal(http.ListenAndServe(":8080", router))
6064
}
@@ -71,19 +75,19 @@ import (
7175
"github.com/vardius/gorouter/v4"
7276
)
7377

74-
func Index(ctx *fasthttp.RequestCtx) {
78+
func index(_ *fasthttp.RequestCtx) {
7579
fmt.Print("Welcome!\n")
7680
}
7781

78-
func Hello(ctx *fasthttp.RequestCtx) {
82+
func hello(ctx *fasthttp.RequestCtx) {
7983
params := ctx.UserValue("params").(context.Params)
8084
fmt.Printf("Hello, %s!\n", params.Value("name"))
8185
}
8286

8387
func main() {
8488
router := gorouter.NewFastHTTPRouter()
85-
router.GET("/", Index)
86-
router.GET("/hello/{name}", Hello)
89+
router.GET("/", index)
90+
router.GET("/hello/{name}", hello)
8791

8892
log.Fatal(fasthttp.ListenAndServe(":8080", router.HandleFastHTTP))
8993
}

doc.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Package gorouter provide request router with middleware
33
44
Router
55
6-
The router determines how to handle that request.
6+
The router determines how to handle http request.
77
GoRouter uses a routing tree. Once one branch of the tree matches, only routes inside that branch are considered,
88
not any routes after that branch. When instantiating router, the root node of tree is created.
99
@@ -31,7 +31,7 @@ A full route definition contain up to three parts:
3131
2. The URL path route. This is matched against the URL passed to the router,
3232
and can contain named wildcard placeholders *(e.g. {placeholder})* to match dynamic parts in the URL.
3333
34-
3. `http.HandleFunc`, which tells the router to handle matched requests to the router with handler.
34+
3. `http.HandlerFunc`, which tells the router to handle matched requests to the router with handler.
3535
3636
Take the following example:
3737

example_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func ExampleMiddlewareFunc_second() {
114114
router := gorouter.New()
115115
router.GET("/hello/{name}", http.HandlerFunc(hello))
116116

117-
// apply middleware to route and all it children
117+
// apply middleware to route and all its children
118118
// can pass as many as you want
119119
router.USE("GET", "/hello/{name}", logger)
120120

@@ -206,7 +206,7 @@ func ExampleFastHTTPMiddlewareFunc_second() {
206206
router := gorouter.NewFastHTTPRouter()
207207
router.GET("/hello/{name}", hello)
208208

209-
// apply middleware to route and all it children
209+
// apply middleware to route and all its children
210210
// can pass as many as you want
211211
router.USE("GET", "/hello/{name}", logger)
212212

fasthttp.go

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,29 @@ import (
66
"github.com/valyala/fasthttp"
77
"github.com/vardius/gorouter/v4/middleware"
88
"github.com/vardius/gorouter/v4/mux"
9-
pathutils "github.com/vardius/gorouter/v4/path"
109
)
1110

1211
// NewFastHTTPRouter creates new Router instance, returns pointer
1312
func NewFastHTTPRouter(fs ...FastHTTPMiddlewareFunc) FastHTTPRouter {
13+
globalMiddleware := transformFastHTTPMiddlewareFunc(fs...)
1414
return &fastHTTPRouter{
15-
routes: mux.NewTree(),
16-
middleware: transformFastHTTPMiddlewareFunc(fs...),
15+
tree: mux.NewTree(),
16+
globalMiddleware: globalMiddleware,
17+
middlewareCounter: uint(len(globalMiddleware)),
1718
}
1819
}
1920

2021
type fastHTTPRouter struct {
21-
routes mux.Tree
22-
middleware middleware.Middleware
23-
fileServer fasthttp.RequestHandler
24-
notFound fasthttp.RequestHandler
25-
notAllowed fasthttp.RequestHandler
22+
tree mux.Tree
23+
globalMiddleware middleware.Collection
24+
fileServer fasthttp.RequestHandler
25+
notFound fasthttp.RequestHandler
26+
notAllowed fasthttp.RequestHandler
27+
middlewareCounter uint
2628
}
2729

2830
func (r *fastHTTPRouter) PrettyPrint() string {
29-
return r.routes.PrettyPrint()
31+
return r.tree.PrettyPrint()
3032
}
3133

3234
func (r *fastHTTPRouter) POST(p string, f fasthttp.RequestHandler) {
@@ -65,17 +67,20 @@ func (r *fastHTTPRouter) TRACE(p string, f fasthttp.RequestHandler) {
6567
r.Handle(http.MethodTrace, p, f)
6668
}
6769

68-
func (r *fastHTTPRouter) USE(method, p string, fs ...FastHTTPMiddlewareFunc) {
70+
func (r *fastHTTPRouter) USE(method, path string, fs ...FastHTTPMiddlewareFunc) {
6971
m := transformFastHTTPMiddlewareFunc(fs...)
72+
for i, mf := range m {
73+
m[i] = middleware.WithPriority(mf, r.middlewareCounter)
74+
}
7075

71-
addMiddleware(r.routes, method, p, m)
76+
r.tree = r.tree.WithMiddleware(method+path, m, 0)
77+
r.middlewareCounter += uint(len(m))
7278
}
7379

7480
func (r *fastHTTPRouter) Handle(method, path string, h fasthttp.RequestHandler) {
7581
route := newRoute(h)
76-
route.PrependMiddleware(r.middleware)
7782

78-
r.routes = r.routes.WithRoute(method+path, route, 0)
83+
r.tree = r.tree.WithRoute(method+path, route, 0)
7984
}
8085

8186
func (r *fastHTTPRouter) Mount(path string, h fasthttp.RequestHandler) {
@@ -91,15 +96,14 @@ func (r *fastHTTPRouter) Mount(path string, h fasthttp.RequestHandler) {
9196
http.MethodTrace,
9297
} {
9398
route := newRoute(h)
94-
route.PrependMiddleware(r.middleware)
9599

96-
r.routes = r.routes.WithSubrouter(method+path, route, 0)
100+
r.tree = r.tree.WithSubrouter(method+path, route, 0)
97101
}
98102
}
99103

100104
func (r *fastHTTPRouter) Compile() {
101-
for i, methodNode := range r.routes {
102-
r.routes[i].WithChildren(methodNode.Tree().Compile())
105+
for i, methodNode := range r.tree {
106+
r.tree[i].WithChildren(methodNode.Tree().Compile())
103107
}
104108
}
105109

@@ -121,32 +125,38 @@ func (r *fastHTTPRouter) ServeFiles(root string, stripSlashes int) {
121125

122126
func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) {
123127
method := string(ctx.Method())
124-
pathAsString := string(ctx.Path())
125-
path := pathutils.TrimSlash(pathAsString)
126-
127-
if root := r.routes.Find(method); root != nil {
128-
if node, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil {
129-
if len(params) > 0 {
130-
ctx.SetUserValue("params", params)
128+
path := string(ctx.Path())
129+
130+
if route, params, subPath := r.tree.MatchRoute(method + path); route != nil {
131+
var h fasthttp.RequestHandler
132+
if r.middlewareCounter > 0 {
133+
allMiddleware := r.globalMiddleware
134+
if treeMiddleware := r.tree.MatchMiddleware(method + path); len(treeMiddleware) > 0 {
135+
allMiddleware = allMiddleware.Merge(treeMiddleware.Sort())
131136
}
132137

133-
if subPath != "" {
134-
ctx.URI().SetPathBytes(fasthttp.NewPathPrefixStripper(len("/" + subPath))(ctx))
135-
}
138+
computedHandler := allMiddleware.Compose(route.Handler())
136139

137-
node.Route().Handler().(fasthttp.RequestHandler)(ctx)
138-
return
140+
h = computedHandler.(fasthttp.RequestHandler)
141+
} else {
142+
h = route.Handler().(fasthttp.RequestHandler)
139143
}
140144

141-
if pathAsString == "/" && root.Route() != nil {
142-
root.Route().Handler().(fasthttp.RequestHandler)(ctx)
143-
return
145+
if len(params) > 0 {
146+
ctx.SetUserValue("params", params)
144147
}
148+
149+
if subPath != "" {
150+
ctx.URI().SetPathBytes(fasthttp.NewPathPrefixStripper(len("/" + subPath))(ctx))
151+
}
152+
153+
h(ctx)
154+
return
145155
}
146156

147157
// Handle OPTIONS
148158
if method == http.MethodOptions {
149-
if allow := allowed(r.routes, method, path); len(allow) > 0 {
159+
if allow := allowed(r.tree, method, path); len(allow) > 0 {
150160
ctx.Response.Header.Set("Allow", allow)
151161
return
152162
}
@@ -156,7 +166,7 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) {
156166
return
157167
} else {
158168
// Handle 405
159-
if allow := allowed(r.routes, method, path); len(allow) > 0 {
169+
if allow := allowed(r.tree, method, path); len(allow) > 0 {
160170
ctx.Response.Header.Set("Allow", allow)
161171
r.serveNotAllowed(ctx)
162172
return
@@ -183,12 +193,12 @@ func (r *fastHTTPRouter) serveNotAllowed(ctx *fasthttp.RequestCtx) {
183193
}
184194
}
185195

186-
func transformFastHTTPMiddlewareFunc(fs ...FastHTTPMiddlewareFunc) middleware.Middleware {
187-
m := make(middleware.Middleware, len(fs))
196+
func transformFastHTTPMiddlewareFunc(fs ...FastHTTPMiddlewareFunc) middleware.Collection {
197+
m := make(middleware.Collection, len(fs))
188198

189199
for i, f := range fs {
190-
m[i] = func(mf FastHTTPMiddlewareFunc) middleware.MiddlewareFunc {
191-
return func(h interface{}) interface{} {
200+
m[i] = func(mf FastHTTPMiddlewareFunc) middleware.WrapperFunc {
201+
return func(h middleware.Handler) middleware.Handler {
192202
return mf(h.(fasthttp.RequestHandler))
193203
}
194204
}(f) // f is a reference to function so we have to wrap if with that callback

0 commit comments

Comments
 (0)