Let's explore rate limiting together in Go. If you are here, you probably already know what rate limiting is and why it is important, but spare me half a minute for a brief intro for the guys at the back.
Rate limiting is a technique used to restrict the number of requests a client can make to a server within a specific time window.
Why do I need a rate limiter, you ask? To help protect your server against abuse such as DDoS attacks. What is a DDoS attack, you ask? DDoS is an acronym for Distributed Denial of Service. When your server is suddenly inundated with an absurd amount of internet traffic coming from many sources at once, you're probably experiencing a DDoS attack.
Now that we have gotten definitions out of the way, let's talk about some common rate limiting algorithms. There are quite a number of rate limiting algorithms, with varying degrees of complexity, resource usage and trade-offs. I'll be exploring 3 of the more common rate limiting strategies:
- Fixed window counter
- Sliding window counter
- Token bucket algorithm
In this expedition, I'll be focusing on the fixed window counter strategy.
Fixed Window Counter
This is the simplest and most common algorithm used for rate limiting. It essentially acts like a timer for incoming traffic: time is split into fixed windows (say, one minute), each client gets a counter for the current window, every request increments that counter, and requests are rejected once the counter reaches the limit. When the window ends, the counter resets to zero. It is an excellent choice for low-risk APIs where slight traffic spikes aren't a problem.
A drawback of this algorithm is the burst that can occur around a window reset. In plain English, an attacker can make up to twice the number of requests allowed in one window within a very short span: a full window's worth just before the reset and another full window's worth immediately after it. If your API allows 10 requests per minute, an attacker can wait for the last second of that minute, make 10 requests, then make another 10 immediately after the counter resets. That's 20 requests in about 2 seconds.
Let's write a Go implementation of this strategy.
Start a Go project and create your entry file, main.go. In it, create an HTTP server with an example handler:
```go
package main

import (
	"fmt"
	"log"
	"net/http"
)

const DEFAULT_PORT = 3000

// hello is an example handler that responds with a small JSON payload.
func hello(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	w.Write([]byte(`{"message": "hello"}`))
}

func main() {
	mux := http.NewServeMux()
	// hello is a plain function, so register it with HandleFunc
	// (mux.Handle expects a value that implements http.Handler).
	mux.HandleFunc("/", hello)

	serv := &http.Server{
		Addr:    fmt.Sprintf(":%d", DEFAULT_PORT),
		Handler: mux,
	}

	if err := serv.ListenAndServe(); err != nil {
		log.Fatal(err)
	}
}
```
We should now be able to run our server, visit http://localhost:3000 and get a {"message": "hello"} JSON response in our client. Right now, though, the server accepts an unlimited number of HTTP requests. Let's add rate limiting to fix that.
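In the same directory, create a ratelimit package, and in its limiter.go file, write the rate limiting logic. The package keeps one counter per client IP in a map guarded by a mutex, periodically deletes expired counters, and exposes the limiter as HTTP middleware through New: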
```go
package ratelimit

import (
	"encoding/json"
	"fmt"
	"math"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)

// Counter tracks the requests made by one client in its current window.
type Counter struct {
	value     int       // number of requests seen in this window
	windowEnd time.Time // when the current window expires
}

// RateLimiter holds one fixed window counter per client key (IP address).
type RateLimiter struct {
	windowSize time.Duration       // window duration
	counters   map[string]*Counter // counter information per IP
	mu         sync.Mutex          // guards counters
	maxRequest int                 // maximum number of requests per window
}

func newRateLimiter(maxRequest int, windowSize time.Duration) *RateLimiter {
	limiter := &RateLimiter{
		maxRequest: maxRequest,
		windowSize: windowSize,
		counters:   make(map[string]*Counter),
	}

	// Start a goroutine that removes expired counters every minute so the
	// map does not grow without bound. It runs for the life of the process.
	go func() {
		t := time.NewTicker(time.Minute)
		defer t.Stop()
		for range t.C {
			limiter.cleanup()
		}
	}()

	return limiter
}

// cleanup deletes counters whose window has already expired.
func (r *RateLimiter) cleanup() {
	r.mu.Lock()
	defer r.mu.Unlock()

	now := time.Now()
	for key, c := range r.counters {
		if now.After(c.windowEnd) {
			delete(r.counters, key)
		}
	}
}

// LimiterInfo describes the outcome of a rate limit check.
type LimiterInfo struct {
	Allowed   bool
	Remaining int
	Reset     time.Time
}

// allow records a request for key and reports whether it fits in the window.
func (r *RateLimiter) allow(key string) (bool, LimiterInfo) {
	r.mu.Lock()
	defer r.mu.Unlock()

	now := time.Now()

	c, ok := r.counters[key]
	// No counter yet, or the old window has expired: start a new window.
	if !ok || now.After(c.windowEnd) {
		c = &Counter{
			value:     1,
			windowEnd: now.Add(r.windowSize),
		}
		r.counters[key] = c
		return true, LimiterInfo{
			Allowed:   true,
			Remaining: r.maxRequest - 1,
			Reset:     c.windowEnd,
		}
	}

	// Still inside the window: allow the request if the limit isn't reached yet.
	if c.value < r.maxRequest {
		c.value++
		return true, LimiterInfo{
			Allowed:   true,
			Remaining: r.maxRequest - c.value,
			Reset:     c.windowEnd,
		}
	}

	return false, LimiterInfo{
		Allowed:   false,
		Remaining: 0,
		Reset:     c.windowEnd,
	}
}

// getRequestIP returns the client IP. X-Forwarded-For and X-Real-IP can be
// set by any client, so only trust them behind a proxy you control.
func getRequestIP(r *http.Request) string {
	if ips := r.Header.Get("X-Forwarded-For"); ips != "" {
		// The first comma-separated entry is the original client.
		parts := strings.Split(ips, ",")
		return strings.TrimSpace(parts[0])
	}
	if ip := r.Header.Get("X-Real-IP"); ip != "" {
		return ip
	}
	ip, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return ip
}

// New returns middleware that allows maxRequest requests per client per windowSize.
func New(maxRequest int, windowSize time.Duration) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		// One limiter per wrapped handler, captured by the closure below.
		limiter := newRateLimiter(maxRequest, windowSize)

		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			key := getRequestIP(r)
			allowed, info := limiter.allow(key)

			// Report the rate limit state on every response.
			w.Header().Set("X-RateLimit-Limit", fmt.Sprint(limiter.maxRequest))
			w.Header().Set("X-RateLimit-Remaining", fmt.Sprint(info.Remaining))
			w.Header().Set("X-RateLimit-Reset", fmt.Sprint(info.Reset.Unix()))

			if !allowed {
				// Round up so clients are not told to retry before the window resets.
				retryAfter := math.Ceil(time.Until(info.Reset).Seconds())
				w.Header().Set("Retry-After", fmt.Sprintf("%.0f", retryAfter))
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusTooManyRequests)
				json.NewEncoder(w).Encode(map[string]string{
					"detail":      "rate limit exceeded",
					"retry_after": info.Reset.Format(time.RFC1123),
				})
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
```
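The allow method holds r.mu for the entire check-and-increment, which is what keeps concurrent requests from the same IP from slipping past the limit. The sketch below uses a hypothetical, stripped-down counterMap type (not part of the ratelimit package, with window expiry left out to keep the focus on locking) and fires 100 concurrent requests for one key; because the comparison and the increment happen under a single lock, exactly 2 get through.

```go
package main

import (
	"fmt"
	"sync"
)

// counterMap is a hypothetical miniature of RateLimiter: per-key counts
// guarded by a mutex, without window expiry.
type counterMap struct {
	mu     sync.Mutex
	counts map[string]int
	limit  int
}

// allow checks and increments the count for key in one critical section.
func (m *counterMap) allow(key string) bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.counts[key] >= m.limit {
		return false
	}
	m.counts[key]++
	return true
}

func main() {
	m := &counterMap{counts: make(map[string]int), limit: 2}

	var wg sync.WaitGroup
	var mu sync.Mutex // guards allowed
	allowed := 0

	// 100 goroutines hit the same key at once, like simultaneous requests from one IP.
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if m.allow("203.0.113.7") {
				mu.Lock()
				allowed++
				mu.Unlock()
			}
		}()
	}
	wg.Wait()

	fmt.Println(allowed) // always 2
}
```

If the lock were released between reading the count and incrementing it, two goroutines could both read a count of 1 and both be allowed, letting more than the limit through.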
Now let's adjust our main.go file to use the rate limiter:
```go
import (
	// imports remain the same, plus:
	"time"

	ratelimit "path-to-ratelimit-package"
)

func main() {
	mux := http.NewServeMux()

	// rate limiter allows 2 requests per minute per client IP
	limiter := ratelimit.New(2, time.Minute)
	mux.Handle("/", limiter(http.HandlerFunc(hello)))

	serv := &http.Server{
		Addr:    fmt.Sprintf(":%d", DEFAULT_PORT),
		Handler: mux,
	}

	if err := serv.ListenAndServe(); err != nil {
		log.Fatal(err)
	}
}
```
If you make more than 2 requests to the endpoint within one minute, you'll get a 429 Too Many Requests response, along with a Retry-After header telling you how many seconds to wait.