-
Notifications
You must be signed in to change notification settings - Fork 0
/
mux.go
126 lines (100 loc) · 4.44 KB
/
mux.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package main
import (
"compress/gzip"
"log"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/NYTimes/gziphandler"
"github.com/didip/tollbooth"
"github.com/didip/tollbooth/limiter"
"github.com/rs/cors"
"github.com/sirupsen/logrus"
"github.com/thisendout/apollo"
)
// buildServeMux builds the gateway's request multiplexer. For every entry in
// settings.APIConfigs it registers a handler on api.Prefix that runs rootChain
// extended, in this order, with:
//   - an optional per-API rate limiter (when api.RateLimitPerSecond != 0),
//   - the access-logging middleware,
//   - optional CORS handling (when api.CORS != nil),
//   - optional JWT authentication (when api.Auth),
//   - optional gzip compression (when api.Gzip),
//   - optional removal of api.Prefix from the request path (when api.StripPrefix),
//
// and finally a reverse proxy to api.TargetURL.
//
// Configuration errors (a target URL that cannot be parsed, or that lacks a
// scheme or host) are reported through logger.Fatalf and abort startup.
func buildServeMux(rootChain apollo.Chain, settings GatewaySettings) *http.ServeMux {
	mux := http.NewServeMux()
	for _, api := range settings.APIConfigs {
		// Copy values that middleware closures read. apollo.Wrap may defer calling
		// the wrapped function until a request arrives, and before Go 1.22 all
		// iterations share a single `api` variable, so reading api.Prefix inside a
		// closure would observe the last API's prefix for every route.
		prefix := api.Prefix

		customChain := rootChain

		// Optional rate limiter.
		if api.RateLimitPerSecond != 0 {
			rl := api.RateLimitPerSecond
			// Allow rl requests/second per key; every token bucket expires one hour
			// after it was initially set.
			lmt := tollbooth.NewLimiter(rl, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Hour})
			// The limiter sits before the access logger in the chain, so emit our own
			// log line when a request is rejected.
			lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) {
				logger.WithFields(logrus.Fields{
					"remote_ip":      r.RemoteAddr,
					"url":            r.URL.String(),
					"max_per_second": rl,
				}).Info("rate limit exceeded")
			})
			// Override the default IP lookup order so that requests relayed by
			// CloudFlare (or another CDN) edge servers are not all attributed to the
			// same edge address: RemoteAddr is consulted only when neither forwarding
			// header is present.
			// NOTE(review): X-Forwarded-For is client-controlled unless a trusted
			// proxy overwrites it; a client talking to the gateway directly can vary
			// it to evade the limit.
			lmt.SetIPLookups([]string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"})
			// Include the HTTP method in the rate-limit key so CORS preflight OPTIONS
			// requests do not trigger a rate limit for the subsequent real request.
			lmt.SetMethods([]string{
				http.MethodGet,
				http.MethodHead,
				http.MethodPost,
				http.MethodPut,
				http.MethodPatch,
				http.MethodDelete,
				http.MethodConnect,
				http.MethodOptions,
				http.MethodTrace,
			})
			// Adapt the limiter to the func(http.Handler) http.Handler shape that apollo.Wrap expects.
			wrappedRateLimiter := func(next http.Handler) http.Handler { return tollbooth.LimitHandler(lmt, next) }
			customChain = customChain.Append(apollo.Wrap(wrappedRateLimiter))
		}

		// Access logging middleware.
		customChain = customChain.Append(apollo.Wrap(NewLoggerMiddleware(logger)))

		// Optional handling of CORS-related (including preflight) requests.
		if api.CORS != nil {
			corsHandler := cors.New(api.CORS.ToConfig())
			// TODO: default logger is way too verbose. Leave it to our own access Logger.
			// // set a logger for the corsHandler.
			// // Derive this logger from logrus. We consider all these log entries 'INFO'.
			// corsLogger := logger.WriterLevel(logrus.InfoLevel)
			// corsHandler.Log = log.New(corsLogger, "cors_preflight--", 0)
			customChain = customChain.Append(apollo.Wrap(corsHandler.Handler))
		}

		// Optional JWT authentication.
		if api.Auth {
			customChain = customChain.Append(apollo.Wrap(NewJWTMiddleware(settings.requestAuthenticator)))
		}

		// Optional gzip compression of responses.
		if api.Gzip {
			gz := gziphandler.MustNewGzipLevelHandler(gzip.DefaultCompression)
			customChain = customChain.Append(apollo.Wrap(gz))
		}

		// Optional removal of the route prefix before proxying. apollo.Wrap injects
		// plain http.Handler middleware; the context skips over it and is passed
		// unharmed to the next context-aware handler in the chain.
		if api.StripPrefix {
			customChain = customChain.Append(apollo.Wrap(func(next http.Handler) http.Handler {
				return http.StripPrefix(prefix, next)
			}))
		}

		// Parse and validate the upstream target.
		target, err := url.Parse(api.TargetURL)
		if err != nil {
			logger.WithError(err).Fatalf("could not parse url : %v", api.TargetURL)
		}
		// url.Parse accepts inputs such as "localhost:8080" (parsed as scheme
		// "localhost" with no host) or "example.com" (parsed as a bare path). A
		// single-host reverse proxy cannot forward to such a target, so fail at
		// startup rather than on every request.
		if target.Scheme == "" || target.Host == "" {
			logger.Fatalf("target url must include a scheme and host : %v", api.TargetURL)
		}

		reverseProxyFunc := httputil.NewSingleHostReverseProxy(target)
		// Route the proxy's internal error log through logrus at error level.
		revProxyLogger := logger.WriterLevel(logrus.ErrorLevel)
		reverseProxyFunc.ErrorLog = log.New(revProxyLogger, "reverseproxy--", 0)

		// Adapt the reverse proxy to the chain's handler type via reverseProxyHandler
		// (defined elsewhere) and register the full chain on this API's prefix.
		mux.Handle(api.Prefix, customChain.Then(reverseProxyHandler(reverseProxyFunc)))
	}
	return mux
}