-
Notifications
You must be signed in to change notification settings - Fork 61
/
ratelimiter.go
89 lines (78 loc) 路 2.21 KB
/
ratelimiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// Package ratelimiter provides basic rate limiting functionality as a wish middleware.
package ratelimiter
import (
"errors"
"log"
"net"
"github.com/charmbracelet/wish"
"github.com/gliderlabs/ssh"
lru "github.com/hashicorp/golang-lru/v2"
"golang.org/x/time/rate"
)
// ErrRateLimitExceeded is the error reported when a connection is denied
// because its rate limit has been exceeded.
var ErrRateLimitExceeded = errors.New("rate limit exceeded, please try again later")
// RateLimiter implementations should check if a given session is allowed to
// proceed or not, returning an error if it isn't.
// It's up to the implementation to decide what identifies a session, as well
// as the implementation details of these limits.
type RateLimiter interface {
// Allow returns nil when the session may proceed and a non-nil error when
// it must be rejected.
Allow(s ssh.Session) error
}
// Middleware builds a wish.Middleware that consults limiter before every
// session.
//
// If limiter.Allow returns an error, that error is handed to wish.Fatal and
// the wrapped handler is never invoked; otherwise the session is passed
// through to the wrapped handler unchanged.
func Middleware(limiter RateLimiter) wish.Middleware {
	return func(next ssh.Handler) ssh.Handler {
		guarded := func(sess ssh.Session) {
			err := limiter.Allow(sess)
			if err == nil {
				// Within limits: continue down the middleware chain.
				next(sess)
				return
			}
			wish.Fatal(sess, err)
		}
		return guarded
	}
}
// NewRateLimiter builds a RateLimiter whose per-key limiters allow events up
// to rate r with bursts of at most burst tokens, and which remembers at most
// maxEntries keys at a time.
//
// The limiters live in an LRU cache of *rate.Limiter keyed by the remote IP
// address. A maxEntries value of zero or less is treated as 1.
func NewRateLimiter(r rate.Limit, burst int, maxEntries int) RateLimiter {
	size := maxEntries
	if size < 1 {
		size = 1
	}

	// The error is deliberately ignored: it can only occur for a size <= 0,
	// which the clamp above rules out.
	store, _ := lru.New[string, *rate.Limiter](size)

	rl := new(limiters)
	rl.cache = store
	rl.rate = r
	rl.burst = burst
	return rl
}
// limiters is the RateLimiter implementation returned by NewRateLimiter. It
// keeps one *rate.Limiter per key in an LRU cache.
type limiters struct {
// cache maps a key derived from the session's remote address to the
// limiter tracking that key.
cache *lru.Cache[string, *rate.Limiter]
// rate is the rate.Limit passed to every newly created limiter.
rate rate.Limit
// burst is the burst size passed to every newly created limiter.
burst int
}
// Allow reports whether session s may proceed under the rate limit.
//
// Sessions are grouped by remote address. For a *net.TCPAddr only the IP is
// used as the key (the port is dropped), so all connections from one IP share
// a limiter; any other address type is keyed by its String() form.
//
// It returns nil when the session is allowed and ErrRateLimitExceeded when
// the limiter for that key has no token available. Every decision is logged.
func (r *limiters) Allow(s ssh.Session) error {
	var key string
	switch addr := s.RemoteAddr().(type) {
	case *net.TCPAddr:
		key = addr.IP.String()
	default:
		key = addr.String()
	}

	limiter, ok := r.cache.Get(key)
	if !ok {
		// First sighting of this key. Insert with PeekOrAdd rather than a
		// plain Add so that, if a concurrent session for the same key got
		// there first, we share its limiter instead of overwriting it and
		// losing the token it already consumed.
		fresh := rate.NewLimiter(r.rate, r.burst)
		if existing, found, _ := r.cache.PeekOrAdd(key, fresh); found {
			limiter = existing
		} else {
			limiter = fresh
		}
	}

	allowed := limiter.Allow()
	log.Printf("rate limiter key: %q, allowed? %v", key, allowed)
	if !allowed {
		return ErrRateLimitExceeded
	}
	return nil
}