DEV Community

IamKhan
IamKhan

Posted on

Exploring Rate Limiting in Go

Let's explore rate limiting together in Go. If you are here, you probably already know what rate limiting is and why it is important, but spare me half a minute to do a brief intro for the guys at the back.

Rate limiting is a technique used to restrict the number of requests a client can make to a server within a specific time window.

Why do I need a rate limiter, you ask? To prevent DDoS attacks. What is a DDoS attack, you ask? DDoS is an acronym for Distributed Denial of Service. When your server is suddenly inundated with an absurd amount of internet traffic, you are probably experiencing a DDoS attack.

Now that we have gotten definitions out of the way, let's talk about some common rate limiting algorithms. There are quite a number of rate limiting algorithms with varying degrees of complexity, resource usage and trade offs. I'll be exploring 3 of the more common rate limiting strategies:

  • Fixed window counter
  • Sliding window counter
  • Token bucket algorithm

In this expedition, I'll be focusing on the fixed window counter strategy.

Fixed Window Counter

This is the simplest and most common algorithm used for rate limiting. It essentially acts like a timer for incoming traffic: requests are counted within a fixed time window, and the count starts over once the window ends. It is an excellent choice for low-risk APIs where slight traffic spikes aren't a problem. A drawback of this algorithm is the burstiness that can occur right around a window reset - which in plain English means an attacker can make up to 2 times the allowed number of requests by sending one batch just before the window resets and another immediately after. If your API allows 10 requests per minute, an attacker can wait for the last second of that minute, make 10 requests, and then make another 10 immediately after the counter resets - that's 20 requests in 2 seconds.

Let's write a go implementation of this strategy

Start a Go project and create your entry file. In your entry file, create an HTTP server with an example handler

package main

import (
    "fmt"
    "log"
    "net/http"
)

// DEFAULT_PORT is the TCP port the example HTTP server listens on.
const DEFAULT_PORT = 3000

// hello answers every request with a fixed JSON greeting and a 200 status.
func hello(w http.ResponseWriter, r *http.Request) {
    const payload = `{"message": "hello"}`

    headers := w.Header()
    headers.Set("Content-Type", "application/json")

    w.WriteHeader(http.StatusOK)
    w.Write([]byte(payload))
}

// main registers the hello handler on the default mux and serves it on
// DEFAULT_PORT. ListenAndServe only returns on failure, in which case the
// error is logged and the process exits.
func main() {
    mux := http.DefaultServeMux
    // hello is a plain function, not an http.Handler, so it has to be
    // registered with HandleFunc; passing it to Handle does not compile.
    mux.HandleFunc("/", hello)

    serv := &http.Server{
        Addr:    fmt.Sprintf(":%d", DEFAULT_PORT),
        Handler: mux,
    }

    if err := serv.ListenAndServe(); err != nil {
        log.Fatal(err)
    }
}
Enter fullscreen mode Exit fullscreen mode

We should now be able to run our server, visit http://localhost:3000 and get a {"message": "hello"} JSON response in our client. Right now, we can send an unrestricted number of HTTP requests to the server. Let's add rate limiting to fix that.

In the same directory create a ratelimit package. In your limiter.go file let's write our rate limiting logic

package ratelimit

import (
    "fmt"
    "encoding/json"
    "net"
    "net/http"
    "strings"
    "sync"
    "time"
)

// Counter holds the fixed-window state kept for a single client key.
type Counter struct {
    // value is the number of requests counted in the current window.
    value int
    // windowEnd is the instant at which the current window expires.
    windowEnd time.Time
}

// RateLimiter is a fixed-window request limiter that tracks one Counter per
// client key.
type RateLimiter struct {
    // maxRequest is the number of requests allowed per key in one window.
    maxRequest int
    // windowSize is the length of each counting window.
    windowSize time.Duration

    // mu is held by allow and cleanup while they touch counters.
    mu sync.Mutex
    // counters maps a client key (the request IP in this package) to its
    // current window state.
    counters map[string]*Counter
}

// newRateLimiter builds a limiter that admits maxRequest requests per key in
// each windowSize interval.
//
// It also starts a background goroutine that purges expired counters once a
// minute. Nothing in this function ever stops that goroutine, so it runs for
// as long as the process does.
func newRateLimiter(maxRequest int, windowSize time.Duration) *RateLimiter {
    rl := new(RateLimiter)
    rl.maxRequest = maxRequest
    rl.windowSize = windowSize
    rl.counters = map[string]*Counter{}

    // Periodic sweep so keys that stop sending requests do not stay in the
    // map forever.
    go func() {
        ticker := time.NewTicker(time.Minute)
        defer ticker.Stop()

        for {
            <-ticker.C
            rl.cleanup()
        }
    }()

    return rl
}

// cleanup deletes every counter whose window has already ended. It holds the
// limiter's mutex for the whole sweep.
func (r *RateLimiter) cleanup() {
    r.mu.Lock()
    defer r.mu.Unlock()

    cutoff := time.Now()
    for ip, counter := range r.counters {
        // Keep counters whose window is still open.
        if !counter.windowEnd.Before(cutoff) {
            continue
        }
        delete(r.counters, ip)
    }
}

// LimiterInfo describes the outcome of a single rate-limit check.
type LimiterInfo struct {
    // Allowed reports whether the request was admitted.
    Allowed bool
    // Remaining is the number of requests still available in the current window.
    Remaining int
    // Reset is the time at which the current window ends.
    Reset time.Time
}

// allow records one request for key and reports whether it is within the
// limit.
//
// A window for a key starts at that key's first request (or its first request
// after the previous window expired) and lasts windowSize. The returned bool
// duplicates LimiterInfo.Allowed; LimiterInfo also carries the number of
// requests left in the window and the time the window ends.
//
// The first request of a window is always admitted, so the limiter only
// behaves as configured when maxRequest is at least 1.
func (r *RateLimiter) allow(key string) (bool, LimiterInfo) {
    r.mu.Lock()
    defer r.mu.Unlock()
    now := time.Now()

    c, ok := r.counters[key]

    // No counter yet, or the previous window has expired: open a new window
    // and count this request as its first.
    if !ok || now.After(c.windowEnd) {
        c = &Counter{
            value:     1,
            windowEnd: now.Add(r.windowSize),
        }
        r.counters[key] = c

        return true, LimiterInfo{
            Allowed:   true,
            Remaining: r.maxRequest - 1,
            Reset:     c.windowEnd,
        }
    }

    // Window still open and quota not yet used up.
    if c.value < r.maxRequest {
        c.value++

        return true, LimiterInfo{
            Allowed: true,
            // Requests left = limit minus requests used (the operands were
            // previously swapped, which produced zero or negative values).
            Remaining: r.maxRequest - c.value,
            Reset:     c.windowEnd,
        }
    }

    // Quota exhausted for this window.
    return false, LimiterInfo{
        Allowed:   false,
        Remaining: 0,
        Reset:     c.windowEnd,
    }
}

// getRequestIP returns the client address used as the rate-limit key.
//
// It prefers the first entry of X-Forwarded-For, then X-Real-IP, and finally
// the host part of r.RemoteAddr (or RemoteAddr unchanged when it is not in
// host:port form). Header values are trimmed so that differences in
// whitespace cannot produce distinct keys for the same address.
//
// NOTE(security): both headers are supplied by the client unless a trusted
// proxy overwrites them, so a caller can pick an arbitrary key by forging
// them. Only rely on them when the server sits behind such a proxy.
func getRequestIP(r *http.Request) string {
    if ips := r.Header.Get("X-Forwarded-For"); ips != "" {
        // The list is "client, proxy1, proxy2"; entries may be padded with
        // spaces around the commas.
        parts := strings.Split(ips, ",")
        if first := strings.TrimSpace(parts[0]); first != "" {
            return first
        }
    }

    if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
        return ip
    }

    ip, _, err := net.SplitHostPort(r.RemoteAddr)
    if err != nil {
        // RemoteAddr has no port (or is otherwise unparsable); use it as is.
        return r.RemoteAddr
    }

    return ip
}

// New returns middleware that applies fixed-window rate limiting per client
// IP.
//
// maxRequest is the number of requests each client may make per window and
// windowSize is the window length. A limiter is created each time the
// returned function wraps a handler, so separately wrapped handlers do not
// share counters.
//
// Every response carries X-RateLimit-Limit, X-RateLimit-Remaining and
// X-RateLimit-Reset (Unix seconds). Requests over the limit are answered with
// 429 Too Many Requests, a Retry-After header and a JSON body, and are not
// passed to the wrapped handler.
func New(maxRequest int, windowSize time.Duration) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        // create limiter in closure
        limiter := newRateLimiter(maxRequest, windowSize)

        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            // the client IP is the rate-limit key
            key := getRequestIP(r)

            allowed, info := limiter.allow(key)

            // set rate limiter headers
            header := w.Header()
            header.Set("X-RateLimit-Limit", fmt.Sprint(limiter.maxRequest))
            header.Set("X-RateLimit-Remaining", fmt.Sprint(info.Remaining))
            header.Set("X-RateLimit-Reset", fmt.Sprint(info.Reset.Unix()))

            if allowed {
                next.ServeHTTP(w, r)
                return
            }

            // Retry-After takes a non-negative whole number of seconds.
            // Round up so the client is never told to retry before the
            // window has actually reset, and clamp in case the window
            // expired between the check above and now.
            retryAfter := int64((time.Until(info.Reset) + time.Second - 1) / time.Second)
            if retryAfter < 0 {
                retryAfter = 0
            }
            header.Set("Retry-After", fmt.Sprint(retryAfter))
            header.Set("Content-Type", "application/json")

            w.WriteHeader(http.StatusTooManyRequests)

            // The status line is already sent, so an encode/write failure
            // (typically a disconnected client) cannot be reported to the
            // caller; it is deliberately ignored.
            _ = json.NewEncoder(w).Encode(map[string]string{
                "detail":      "rate limit exceeded",
                "retry_after": info.Reset.Format(time.RFC1123),
            })
        })
    }
}
Enter fullscreen mode Exit fullscreen mode

Now let's adjust our main.go file to use the rate limiter

import (
    // imports remain the same
    "time"

    ratelimit "path-to-ratelimit-package"
)

// main serves the hello handler on DEFAULT_PORT behind the rate-limiting
// middleware and exits with a logged error if the server stops.
func main() {
    // Each client IP gets at most 2 requests per one-minute window.
    withLimit := ratelimit.New(2, time.Minute)

    router := http.DefaultServeMux
    router.Handle("/", withLimit(http.HandlerFunc(hello)))

    server := &http.Server{
        Addr:    fmt.Sprintf(":%d", DEFAULT_PORT),
        Handler: router,
    }

    err := server.ListenAndServe()
    if err != nil {
        log.Fatal(err)
    }
}
Enter fullscreen mode Exit fullscreen mode

If you try making a request to the appropriate endpoint more than 2 times in one minute, you'll get a 429 error status code.

Top comments (0)