/
cachehandler.go
149 lines (135 loc) · 3.39 KB
/
cachehandler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Package cachehandler provides net/http middleware that caches HTTP responses.
// Inspired by go-chi/stampede. https://github.com/go-chi/stampede
// The HTTP response is stored in the cache and calls to handlers with the same key are merged.
package cachehandler
import (
	"bytes"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/felixge/httpsnoop"
	expirablecache "github.com/go-pkgz/expirable-cache"
	"golang.org/x/sync/singleflight"
)
// BasicKeyFunc returns a KeyFunc that keys every request by its HTTP method
// immediately followed by its full URL string. It accepts every request.
func BasicKeyFunc() KeyFunc {
	keyOf := func(_ http.ResponseWriter, req *http.Request) (string, bool) {
		key := req.Method + req.URL.String()
		return key, true
	}
	return keyOf
}

// KeyFunc computes the cache key for a request. When its boolean result is
// false, the Middleware returns right away without calling the next handler.
// Because a KeyFunc receives the ResponseWriter, it can write its own
// response in that case.
type KeyFunc func(w http.ResponseWriter, r *http.Request) (string, bool)
// response is a recorded HTTP response that is stored in the cache and
// replayed to later requests with the same key.
type response struct {
	statusCode int         // status code sent by the handler (200 when none was set explicitly)
	header     http.Header // response headers
	body       []byte      // response body bytes
}
// Middleware is response-caching net/http middleware. Create it with
// NewMiddleware and install it around a handler with Wrap.
//
// It holds a sync.Pool and a singleflight.Group, so it must not be copied
// after first use; share it by pointer.
type Middleware struct {
	// Configuration.
	keyFn KeyFunc       // derives the cache key for each request
	ttl   time.Duration // lifetime passed to the cache for each stored response

	// Runtime state.
	cache expirablecache.Cache // stored responses, by key
	group singleflight.Group   // merges concurrent misses for the same key
	pool  sync.Pool            // reusable *bytes.Buffer values for recording bodies
}
// CacheStats reports counters of the underlying response cache.
type CacheStats struct {
	// Hits and Misses describe cache effectiveness.
	Hits   int
	Misses int
	// Added and Evicted count records added to and evicted from the cache.
	Added   int
	Evicted int
}
// Stats returns a snapshot of the counters kept by the underlying cache:
// hits, misses, and the number of added and evicted records.
func (m *Middleware) Stats() CacheStats {
	cs := m.cache.Stat()
	var out CacheStats
	out.Hits, out.Misses = cs.Hits, cs.Misses
	out.Added, out.Evicted = cs.Added, cs.Evicted
	return out
}
// Wrap returns a handler that serves responses from the cache and fills the
// cache from next on a miss.
//
// If m.keyFn reports false, the request is dropped without calling next. On a
// cache hit, the stored status, headers and body are replayed to w. On a miss,
// concurrent requests with the same key are merged through a singleflight
// group. Exactly one of them (the leader) runs next against its own
// ResponseWriter while the response is recorded. The leader stores the
// recording in the cache with m.ttl, and it is replayed to every request
// that waited on the same key.
func (m *Middleware) Wrap(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		key, ok := m.keyFn(w, r)
		if !ok {
			return
		}
		if v, found := m.cache.Get(key); found {
			// An entry of an unexpected type is treated as a miss.
			if resp, isResp := v.(*response); isResp {
				resp.writeTo(w)
				return
			}
		}

		// singleflight runs fn synchronously in the leader's goroutine only, so
		// leader stays false for callers that merely waited on the shared result.
		leader := false
		v, _, _ := m.group.Do(key, func() (interface{}, error) {
			leader = true
			resp := m.record(w, r, next)
			// Only the leader stores the response, so waiters can no longer
			// overwrite it with an empty entry.
			m.cache.Set(key, resp, m.ttl)
			return resp, nil
		})
		if leader {
			// The leader's response has already been streamed to w by next.
			return
		}
		if resp, isResp := v.(*response); isResp {
			resp.writeTo(w)
		}
	})
}

// record runs next against w and returns a detached copy of the response it
// produced; everything next writes still reaches w unchanged.
func (m *Middleware) record(w http.ResponseWriter, r *http.Request, next http.Handler) *response {
	buf := m.pool.Get().(*bytes.Buffer)
	// Pooled buffers still hold whatever a previous request wrote into them.
	buf.Reset()
	defer m.pool.Put(buf)

	var (
		header      http.Header
		status      int
		wroteHeader bool
	)
	// snapshot freezes status and headers the first time they are sent.
	// Later edits to w.Header() are not transmitted, so they must not be cached,
	// and cloning keeps the cache from sharing the writer's live map.
	snapshot := func(code int) {
		if wroteHeader {
			return
		}
		wroteHeader = true
		status = code
		header = w.Header().Clone()
	}

	next.ServeHTTP(httpsnoop.Wrap(w, httpsnoop.Hooks{
		WriteHeader: func(whf httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
			return func(code int) {
				snapshot(code)
				whf(code)
			}
		},
		Write: func(wf httpsnoop.WriteFunc) httpsnoop.WriteFunc {
			return func(b []byte) (int, error) {
				// A Write without a prior WriteHeader implies 200 OK.
				snapshot(http.StatusOK)
				n, err := wf(b)
				buf.Write(b)
				return n, err
			}
		},
		ReadFrom: func(rff httpsnoop.ReadFromFunc) httpsnoop.ReadFromFunc {
			return func(src io.Reader) (int64, error) {
				// io.Copy / http.ServeContent reach the writer through ReadFrom,
				// bypassing the Write hook; tee the stream so it is cached too.
				// This gives up a possible sendfile fast path in exchange for a
				// complete cached body.
				snapshot(http.StatusOK)
				return rff(io.TeeReader(src, buf))
			}
		},
	}), r)

	// A handler that wrote nothing still produced an implicit 200 response.
	snapshot(http.StatusOK)
	return &response{
		header:     header,
		statusCode: status,
		// Copy out of the pooled buffer: it is reused as soon as it is Put back.
		body: append([]byte(nil), buf.Bytes()...),
	}
}

// writeTo replays a cached response to w. Each cached header replaces any
// value already set for that key, and every value of a multi-valued header
// (e.g. Set-Cookie) is preserved.
func (resp *response) writeTo(w http.ResponseWriter) {
	dst := w.Header()
	for k, values := range resp.header {
		for i, v := range values {
			if i == 0 {
				dst.Set(k, v)
			} else {
				dst.Add(k, v)
			}
		}
	}
	w.WriteHeader(resp.statusCode)
	// A write error means the client went away; there is nothing left to do.
	_, _ = w.Write(resp.body)
}
// NewMiddleware returns net/http middleware that caches HTTP responses.
//
// max is passed to the cache as its key limit, and ttl is the lifetime of
// each cached response. keyFn derives the cache key for a request. If keyFn
// is nil, BasicKeyFunc() is used, rather than panicking on the first request.
func NewMiddleware(max int, ttl time.Duration, keyFn KeyFunc) *Middleware {
	if keyFn == nil {
		keyFn = BasicKeyFunc()
	}
	cache, err := expirablecache.NewCache(expirablecache.MaxKeys(max), expirablecache.TTL(ttl))
	if err != nil {
		// NOTE(review): presumed unreachable — the MaxKeys/TTL options are not
		// expected to fail; confirm against go-pkgz/expirable-cache.
		panic(err)
	}
	return &Middleware{
		ttl:   ttl,
		keyFn: keyFn,
		cache: cache,
		// Buffers used to record response bodies on cache misses.
		pool: sync.Pool{
			New: func() interface{} {
				return new(bytes.Buffer)
			},
		},
		// group is left at its zero value, which is ready to use.
	}
}