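1. Function to get the client's real IP address
Usually we read the RemoteAddr field of http.Request to get the client's address. But when the app runs behind a Kubernetes ingress, a load balancer, or another reverse proxy, RemoteAddr holds the proxy's address instead of the client's.
Use the function below to take the original client address from the X-Forwarded-For header, falling back to RemoteAddr when the header is absent. Keep in mind that the left-most X-Forwarded-For entry comes from the client side and can be forged, which would let a client dodge the limiter by sending a different value on each request. In production, configure the proxy to overwrite the header, or trust only the entries your own proxies append.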
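```go
import (
	"net"
	"net/http"
	"strings"
)

// GetClientIP returns the client's IP address. Behind a Kubernetes ingress or
// another reverse proxy, r.RemoteAddr holds the proxy's address, so the
// X-Forwarded-For header ("client, proxy1, proxy2") is checked first.
func GetClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// The left-most entry is the originating client; it is client-controlled.
		parts := strings.Split(forwarded, ",")
		return strings.TrimSpace(parts[0])
	}
	// Fall back to RemoteAddr, which is normally "host:port".
	ip, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr // no port present; use the value as-is
	}
	return ip
}
```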
2. Getting the specific limiter for the received IP
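Use a map named visitors that stores one pointer to a rate.Limiter (package "golang.org/x/time/rate") per client IP.
A sync.Mutex guards the map, because HTTP handlers run concurrently and Go maps are not safe for concurrent writes.
The function below checks whether a limiter already exists for the received IP; if not, it creates one that allows an average of 1 request per second with a burst of 3. The limiter is a token bucket: it holds at most 3 tokens, each request spends one, and tokens are refilled at 1 per second. A new client can therefore make 3 requests at once before being held to 1 request per second.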
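```go
import (
	"sync"

	"golang.org/x/time/rate"
)

// visitors maps each client IP to its own limiter.
// Entries are never removed, so the map grows with every distinct IP.
var visitors = make(map[string]*rate.Limiter)

// mu guards visitors against concurrent access from handlers.
var mu sync.Mutex

// getLimiter returns the limiter for ip, creating it on first use.
func getLimiter(ip string) *rate.Limiter {
	mu.Lock()
	defer mu.Unlock()
	limiter, exists := visitors[ip]
	if !exists {
		limiter = rate.NewLimiter(1, 3) // 1 req/s on average, bursts of up to 3
		visitors[ip] = limiter
	}
	return limiter
}
```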
3. Middleware function
The middleware gets the client IP, looks up its limiter, and checks whether the new request is allowed; if not, it responds with "429 Too Many Requests".
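```go
import "net/http"

// IPBasedRateLimiter rejects requests from clients that exceed their per-IP rate.
func IPBasedRateLimiter(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ip := GetClientIP(r)
		limiter := getLimiter(ip)
		// Allow spends one token if available; otherwise reject with 429.
		if !limiter.Allow() {
			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}
```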