DEV Community

Raul Paes Silva


DevPill #9 - IP Based Rate Limiter Middleware for native Go REST API

1. Function to get the real client IP address

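Usually we read the RemoteAddr field of http.Request to get the client's address. But when the app runs behind a reverse proxy, a cloud load balancer, or a Kubernetes ingress, RemoteAddr holds the proxy's address instead of the client's. Proxies usually pass the original client address in the X-Forwarded-For header, so the function below checks that header first and falls back to RemoteAddr.
Keep in mind that clients can send their own X-Forwarded-For header, so only trust it when every request reaches the app through a proxy you control.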

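import (
    "net"
    "net/http"
    "strings"
)

func GetClientIP(r *http.Request) string {
    // Behind a Kubernetes ingress or a cloud load balancer, r.RemoteAddr holds the
    // proxy's address, so read the client IP from X-Forwarded-For first.
    // The header has the form "client, proxy1, proxy2"; the first entry is the client.
    forwarded := r.Header.Get("X-Forwarded-For")
    if forwarded != "" {
        parts := strings.Split(forwarded, ",")
        return strings.TrimSpace(parts[0])
    }

    // Fallback to RemoteAddr ("host:port"); if it carries no port, use it as-is.
    ip, _, err := net.SplitHostPort(r.RemoteAddr)
    if err != nil {
        return r.RemoteAddr
    }
    return ip
}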

2. Getting the specific limiter for the received IP

Keep a map of visitors that stores each new IP as a key pointing to its own *rate.Limiter (package "golang.org/x/time/rate"). Because HTTP handlers run concurrently, the map is guarded by a sync.Mutex to avoid data races.
The function below checks whether a limiter already exists for the received IP. If not, it creates one that allows 1 request per second on average with a burst of 3. A new limiter starts with 3 tokens, so a new client can make up to 3 requests back-to-back; after that, tokens refill at 1 per second.

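import (
    "sync"

    "golang.org/x/time/rate"
)

// visitors maps each client IP to its own limiter.
// Entries are never removed, so the map grows with every distinct IP seen.
var visitors = make(map[string]*rate.Limiter)

// mu guards visitors, since handlers run concurrently.
var mu sync.Mutex

func getLimiter(ip string) *rate.Limiter {
    mu.Lock()
    defer mu.Unlock()

    limiter, exists := visitors[ip]
    if !exists {
        limiter = rate.NewLimiter(1, 3) // 1 req/s, burst of 3
        visitors[ip] = limiter
    }

    return limiter
}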

3. Middleware function

The middleware gets the client IP, fetches its limiter, and checks whether the new request is allowed. If it is not, it responds with "429 Too Many Requests"; otherwise it passes the request on to the next handler.

func IPBasedRateLimiter(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        ip := GetClientIP(r)

        // Each IP gets its own limiter, created on first use.
        limiter := getLimiter(ip)

        // Allow consumes one token if available; otherwise reject the request.
        if !limiter.Allow() {
            http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
            return
        }

        next.ServeHTTP(w, r)
    })
}
