DEV Community

Francis kinyuru


Rate limiting with Bucket4j, Redis, and Postgres in Spring Boot

Redis-backed rate limiting using Bucket4j, Redis, and PostgreSQL in Spring Boot

Key Objectives:

  1. What specific API will be limited? This could be a single API endpoint, or a group of related endpoints.
  2. How will the username be used to identify requests? This could be done by including the username in the request header or in the request body, or by using a token-based authentication system.
  3. What is the custom TPS for each user? This is the maximum number of requests that each user will be allowed to make per second.
  4. What will happen if a user exceeds their TPS limit? This could result in the request being rejected, or the user being temporarily banned from making requests. In this article, the request is rejected with HTTP 429 (Too Many Requests).

Why Bucket4j

There are many reasons why you might use Bucket4j for rate limiting. Here are a few of the most important ones:

  • It is a mature and well-tested library. Bucket4j has been around for many years and has been used by a large number of projects. This means that it is a reliable and stable library that you can be confident in using.
  • It is easy to use. Bucket4j has a simple and straightforward API that makes it easy to implement rate limiting in your code. You don't need to be an expert in rate limiting to get started with Bucket4j.
  • It is flexible. Bucket4j allows you to configure rate limits in a variety of ways. You can specify limits per second, minute, hour, or day, and combine several limits on one bucket. The bucket's capacity also sets the burst size: how many requests can be made at once before the refill rate takes over.
  • It is efficient. Bucket4j uses the token bucket algorithm, which keeps only a small amount of state per bucket (essentially the available tokens and the time of the last refill) and does a constant amount of work per request. This keeps its overhead on your application low.
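To make the token bucket idea concrete, here is a minimal sketch in plain Java. It is not Bucket4j's implementation; the class name SimpleTokenBucket and its injectable clock are hypothetical, chosen only for illustration. Each bucket holds two numbers (available tokens and the time up to which tokens have been credited), and every request does the same small amount of work.

```java
import java.util.function.LongSupplier;

/**
 * Hypothetical illustration of the token bucket algorithm; this is NOT Bucket4j's code.
 * State per bucket: tokens available and the time up to which tokens were credited.
 */
final class SimpleTokenBucket {
    private final long capacity;        // maximum tokens, i.e. the largest allowed burst
    private final long nanosPerToken;   // time needed to earn one token
    private final LongSupplier clock;   // nanosecond time source, injectable for tests
    private long available;             // tokens currently in the bucket
    private long lastRefill;            // clock value up to which tokens have been credited

    // Assumes 1 <= tokensPerSecond <= 1_000_000_000; if it does not divide 1e9 evenly,
    // integer division makes the refill slightly faster than requested.
    SimpleTokenBucket(long capacity, long tokensPerSecond, LongSupplier clock) {
        this.capacity = capacity;
        this.nanosPerToken = 1_000_000_000L / tokensPerSecond;
        this.clock = clock;
        this.available = capacity;      // a new bucket starts full
        this.lastRefill = clock.getAsLong();
    }

    /** Takes one token if possible; false means the request should be rejected (HTTP 429). */
    synchronized boolean tryConsume() {
        long now = clock.getAsLong();
        long earned = (now - lastRefill) / nanosPerToken;  // whole tokens earned since last refill
        if (earned > 0) {
            // Add the earned tokens without exceeding capacity (inner min avoids overflow).
            available = Math.min(capacity, available + Math.min(earned, capacity));
            lastRefill += earned * nanosPerToken;           // keep the unspent fraction of time
        }
        if (available > 0) {
            available--;
            return true;
        }
        return false;
    }
}
```

Here capacity plays the role of a Bucket4j bandwidth's capacity (the burst size) and tokensPerSecond the refill rate. This sketch refills gradually, which is closer in spirit to Bucket4j's greedy refill than to the interval refill used later in this article.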

Overall, Bucket4j is a great choice for rate limiting in Java. It is a mature, well-tested, easy-to-use, flexible, and efficient library.

Here are some additional benefits of using Bucket4j:

  • It can help protect your API from abuse and some denial-of-service traffic. By limiting the number of requests each client can make, you make it harder for a single client to overwhelm your server. The application still has to receive each request before rejecting it, so this complements, rather than replaces, network-level DDoS protection.
  • It can ensure fair resource allocation. By limiting the number of requests that each client can make, you can ensure that all clients have fair access to your API's resources.
  • It can keep your API responsive under load. Rejecting excess requests early reduces the work your server does for them, which helps keep response times stable for requests that are within their limits.

If you are looking for a rate limiting library for your Java application, I highly recommend Bucket4j. It is a great choice for limiting abusive traffic, ensuring fair resource allocation, and keeping your API responsive.

Let's now get to the actual implementation in Spring Boot.

Below are the dependencies required for the implementation.

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>

<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
    <optional>true</optional>
</dependency>
<!-- https://mvnrepository.com/artifact/org.redisson/redisson-spring-boot-starter -->
<dependency>
    <groupId>org.redisson</groupId>
    <artifactId>redisson-spring-boot-starter</artifactId>
    <version>3.23.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.giffing.bucket4j.spring.boot.starter/bucket4j-spring-boot-starter -->
<dependency>
    <groupId>com.giffing.bucket4j.spring.boot.starter</groupId>
    <artifactId>bucket4j-spring-boot-starter</artifactId>
    <version>0.9.1</version>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-validation</artifactId>
</dependency>
<dependency>
    <groupId>org.postgresql</groupId>
    <artifactId>postgresql</artifactId>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-jpa</artifactId>
</dependency>

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>

The project is Maven-based and uses Java 17. The code uses jakarta.* imports, so it targets Spring Boot 3.

Configuration classes in the config package

  1. RedisConfig
package Rateimiting.config;

import com.giffing.bucket4j.spring.boot.starter.config.cache.SyncCacheResolver;
import com.giffing.bucket4j.spring.boot.starter.config.cache.jcache.JCacheCacheResolver;
import io.github.bucket4j.distributed.proxy.ProxyManager;
import io.github.bucket4j.grid.jcache.JCacheProxyManager;
import org.redisson.config.Config;
import org.redisson.jcache.configuration.RedissonConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;

import javax.cache.CacheManager;
import javax.cache.Caching;

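@Configuration
public class RedisConfig {

    // Redisson client configuration: a single Redis server on localhost
    @Bean
    public Config config() {
        Config config = new Config();
        config.useSingleServer().setAddress("redis://localhost:6379");
        return config;
    }

    // JCache (JSR-107) CacheManager with a cache named "cache", backed by Redis through Redisson
    @Bean
    public CacheManager cacheManager(Config config) {
        CacheManager manager = Caching.getCachingProvider().getCacheManager();
        manager.createCache("cache", RedissonConfiguration.fromConfig(config));
        return manager;
    }

    // Bucket4j ProxyManager that keeps each bucket's state in the "cache" above, keyed by String
    @Bean
    ProxyManager<String> proxyManager(CacheManager cacheManager) {
        return new JCacheProxyManager<>(cacheManager.getCache("cache"));
    }

    // Cache resolver for bucket4j-spring-boot-starter, preferred over any other candidate
    @Bean
    @Primary
    public SyncCacheResolver bucket4jCacheResolver(CacheManager cacheManager) {
        return new JCacheCacheResolver(cacheManager);
    }
}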


Config bean

This bean creates the Redisson configuration object. It points Redisson at a single Redis server at redis://localhost:6379.

CacheManager bean

This bean creates a JCache (JSR-107) CacheManager and registers a cache named "cache" that Redisson backs with Redis. Bucket4j stores the state of the rate-limit buckets in this cache.

ProxyManager bean

This bean creates a Bucket4j ProxyManager on top of that cache. A ProxyManager hands out buckets by key, and each bucket's token count lives in Redis rather than in the JVM, so every application instance connected to the same Redis enforces the same limits.

SyncCacheResolver Bean

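This bean creates a JCacheCacheResolver over the same cache manager. The bucket4j-spring-boot-starter uses a cache resolver to find the cache for the rate limits it manages itself (for example, filters configured in application properties); the service in this article builds its buckets through the ProxyManager instead.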

@Primary annotation:

This annotation marks this resolver as the one Spring injects when more than one SyncCacheResolver bean is available.
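Why keep bucket state in Redis instead of in application memory? The plain-Java sketch below is a stand-in, not Bucket4j or Redisson code: a shared ConcurrentHashMap plays the role of the Redis-backed cache, and two hypothetical AppInstance objects play two replicas of the service. Refill is left out so the outcome does not depend on timing. Because both replicas resolve buckets from the same store by key, a user's limit holds across instances, while each key still gets its own bucket.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

/** Hypothetical stand-in for the shared Redis-backed cache: key -> remaining tokens (no refill). */
final class SharedBucketStore {
    private final Map<String, AtomicLong> buckets = new ConcurrentHashMap<>();

    /** Atomically takes one token from the bucket for this key, creating it full on first use. */
    boolean tryConsume(String key, long capacity) {
        AtomicLong tokens = buckets.computeIfAbsent(key, k -> new AtomicLong(capacity));
        // Compare-and-set loop: decrement only while tokens remain, so concurrent callers cannot overdraw.
        while (true) {
            long current = tokens.get();
            if (current <= 0) {
                return false;
            }
            if (tokens.compareAndSet(current, current - 1)) {
                return true;
            }
        }
    }
}

/** Hypothetical application replica; every replica points at the same store. */
final class AppInstance {
    private final SharedBucketStore store;

    AppInstance(SharedBucketStore store) {
        this.store = store;
    }

    boolean handle(String username, long tps) {
        return store.tryConsume(username, tps);
    }
}
```

With per-instance, in-memory buckets, a user limited to 2 requests could make 2 on each replica; a shared store, which is what the Redis-backed ProxyManager provides, closes that gap.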

  2. RateLimitConfig
package Rateimiting.config;

import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.BucketConfiguration;
import io.github.bucket4j.Refill;
import io.github.bucket4j.distributed.proxy.ProxyManager;
import org.springframework.context.annotation.Configuration;

import java.time.Duration;
import java.util.function.Supplier;

@Configuration
public class RateLimitConfig {
    // Redis-backed bucket factory defined in RedisConfig
    private final ProxyManager<String> buckets;

    public RateLimitConfig(ProxyManager<String> buckets) {
        this.buckets = buckets;
    }

    // Returns the bucket stored under key, creating it with a limit of tps requests per second if absent
    public Bucket resolveBucket(String key, int tps) {
        Supplier<BucketConfiguration> configurationSupplier = getConfigSupplier(tps);
        return buckets.builder().build(key, configurationSupplier);
    }

    // Capacity of tps tokens, refilled with tps tokens at once after every full second
    private Supplier<BucketConfiguration> getConfigSupplier(int tps) {
        Refill refill = Refill.intervally(tps, Duration.ofSeconds(1));
        Bandwidth limit = Bandwidth.classic(tps, refill);
        return () -> BucketConfiguration.builder()
                .addLimit(limit)
                .build();
    }
}


resolveBucket() method:

This method takes two parameters: the bucket key and the number of requests per second (TPS) the bucket should allow. It gets a Supplier<BucketConfiguration> from getConfigSupplier() and passes it, together with the key, to the ProxyManager's builder. The result is a Bucket whose state lives in Redis under that key, so every call with the same key shares the same tokens. The configuration is supplied lazily: it is only used when no bucket exists yet for the key, so changing a user's TPS later does not change a bucket that already exists in Redis.
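The lazy-configuration behaviour can be illustrated with a small stand-in. The sketch below is not Bucket4j code: LimitConfig and LazyBucketRegistry are hypothetical names, and ConcurrentHashMap.computeIfAbsent models "use the supplier only if the key is absent", which is the behaviour assumed here for the ProxyManager builder.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

/** Hypothetical configuration holder; BucketConfiguration plays this role in the article. */
record LimitConfig(int tps) {}

/** Hypothetical stand-in for a ProxyManager: one configuration per key, created lazily. */
final class LazyBucketRegistry {
    private final Map<String, LimitConfig> store = new ConcurrentHashMap<>();
    final AtomicInteger supplierCalls = new AtomicInteger();  // how often a supplier actually ran

    LimitConfig resolve(String key, Supplier<LimitConfig> configSupplier) {
        // The supplier runs only if the key is absent; an existing entry always wins.
        return store.computeIfAbsent(key, k -> {
            supplierCalls.incrementAndGet();
            return configSupplier.get();
        });
    }
}
```

Resolving "alice" first with a 10 TPS supplier and then with a 50 TPS supplier returns the 10 TPS configuration both times, and the second supplier never runs. That mirrors why an updated TPS value in the database does not reach a bucket that already exists.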

getConfigSupplier() method:

This method builds the BucketConfiguration. Bandwidth.classic(tps, refill) gives the bucket a capacity of tps tokens, and Refill.intervally(tps, Duration.ofSeconds(1)) adds tps tokens in one step each time a full second has passed (unlike Refill.greedy, which adds tokens gradually). Each request consumes one token, so a user can make at most tps requests per one-second refill interval.
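A worked example of this configuration with tps = 3: the bucket starts with 3 tokens (the default for a classic bandwidth), so a 4th request in the first second is rejected. Because the refill is interval-based, nothing comes back at 999 ms, and all 3 tokens return at 1,000 ms. The plain-Java model below reproduces that timing approximately; it is a simplified illustration, not Bucket4j code, and IntervalRefillModel is a hypothetical name.

```java
/**
 * Hypothetical model (not Bucket4j code) of a bucket with capacity = tps and
 * Refill.intervally(tps, Duration.ofSeconds(1)), driven by explicit timestamps in milliseconds.
 */
final class IntervalRefillModel {
    private final long tps;
    private long available;
    private long periodStartMillis;   // start of the current one-second refill period

    IntervalRefillModel(long tps, long startMillis) {
        this.tps = tps;
        this.available = tps;         // the bucket starts full
        this.periodStartMillis = startMillis;
    }

    boolean tryConsume(long nowMillis) {
        long completedPeriods = (nowMillis - periodStartMillis) / 1000;
        if (completedPeriods > 0) {
            // Each completed second adds tps tokens in one step; since capacity is also tps,
            // a single completed second already refills the bucket completely.
            available = tps;
            periodStartMillis += completedPeriods * 1000;
        }
        if (available > 0) {
            available--;
            return true;
        }
        return false;
    }
}
```

One consequence of interval refill: a client that spends its last tokens just before a period ends and its fresh tokens just after can briefly send up to 2 × tps requests within one wall-clock second, while each refill interval on its own never exceeds tps.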

In my model package

model->request

package Rateimiting.model.request;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class Request {
    private String username;
    private String message;
}


This class represents the JSON request body expected by the POST endpoint.

model ApiResponse

package Rateimiting.model;

import lombok.*;

@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
public class ApiResponse {
    private String message;
    private String responseCode;
    private String status;
}


This class represents the response body returned by the API.

model TpsDb

package Rateimiting.model;

import jakarta.persistence.*;
import lombok.Data;

import java.io.Serializable;
import java.sql.Timestamp;

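@Data
@Entity
// One TPS limit per (username, path); findByUsernameAndPath expects at most one match
@Table(name = "tbl_tps", uniqueConstraints = @UniqueConstraint(columnNames = {"username", "path"}))
public class TpsDb implements Serializable {
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private long id;
    private String username;  // user the limit applies to
    private String path;      // endpoint path, e.g. /v1/rate
    private int tps;          // allowed requests per second
}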


This entity maps to the tbl_tps table, which stores the TPS limit for each username and path.

Repository

package Rateimiting.repository;

import Rateimiting.model.TpsDb;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;

@Repository
public interface TpsDbRepository extends JpaRepository<TpsDb, Long> {
    TpsDb findByUsernameAndPath(String username, String path);
}


The derived query method findByUsernameAndPath returns the TPS settings stored for a given username and path; Spring Data builds the query from the method name.

Service

package Rateimiting.service;

import Rateimiting.config.RateLimitConfig;
import Rateimiting.model.ApiResponse;
import Rateimiting.model.TpsDb;
import Rateimiting.model.request.Request;
import Rateimiting.repository.TpsDbRepository;
import io.github.bucket4j.Bucket;
import org.redisson.api.RMapCache;
import org.redisson.api.RedissonClient;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;

@Service
public class RateLimitService {
    private static final int DEFAULT_TPS = 10; // used when no row exists in tbl_tps

    private final RedissonClient redissonClient;
    private final TpsDbRepository tpsD;
    private final RateLimitConfig rateLimitConfig;

    public RateLimitService(RedissonClient redissonClient, TpsDbRepository tpsD, RateLimitConfig rateLimitConfig) {
        this.redissonClient = redissonClient;
        this.tpsD = tpsD;
        this.rateLimitConfig = rateLimitConfig;
    }

    public ResponseEntity<?> addInfo(Request request, String path) {
        String username = request.getUsername();
        String key = username + "-" + path;

        // Check if the TpsDb is cached in Redis
        RMapCache<String, TpsDb> cache = redissonClient.getMapCache("tpsDbCache");
        TpsDb tpsDb = cache.get(key);

        if (tpsDb == null) {
            // Cache miss: load from Postgres (arguments in the repository method's order)
            tpsDb = tpsD.findByUsernameAndPath(username, path);

            if (tpsDb == null) {
                tpsDb = new TpsDb();
                tpsDb.setUsername(username);
                tpsDb.setPath(path);
                tpsDb.setTps(DEFAULT_TPS); // Default TPS value if not found in the database
            }
            cache.put(key, tpsDb);
        }

        // One bucket per username and path, matching how limits are stored in tbl_tps
        Bucket bucket = rateLimitConfig.resolveBucket(key, tpsDb.getTps());
        if (bucket.tryConsume(1)) {
            return ResponseEntity.status(200).body(new ApiResponse("Request Success for user " + username, "4000", "success"));
        } else {
            return ResponseEntity.status(429).body(new ApiResponse("Request failed for user " + username, "4003", "failed"));
        }
    }
}


addInfo() method:

This method takes the request body and the request path and returns a response indicating whether the request was allowed. It first looks for the user's TpsDb entry in a Redis map cache (tpsDbCache), keyed by username and path. On a cache miss it queries Postgres, and if no row exists it falls back to a default of 10 TPS; either way the result is stored in the cache so later requests skip the database. It then resolves a Bucket configured with that TPS value and tries to consume one token. If a token is available, it returns HTTP 200 with a success message; otherwise it returns HTTP 429 (Too Many Requests) with an error message.
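The lookup order in addInfo() is a cache-aside pattern: cache first, then the database, then a default. The single-threaded sketch below models only that lookup; TpsRow and TpsLookup are hypothetical names, and two HashMaps stand in for the Redis map cache and the tbl_tps table (neither is Redisson or Spring Data code).

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical stand-in for a row of tbl_tps. */
record TpsRow(String username, String path, int tps) {}

/** Single-threaded sketch of the cache-aside TPS lookup: cache, then database, then default. */
final class TpsLookup {
    static final int DEFAULT_TPS = 10;
    private final Map<String, TpsRow> cache = new HashMap<>();  // stands in for RMapCache "tpsDbCache"
    private final Map<String, TpsRow> database;                 // stands in for tbl_tps
    int databaseLookups = 0;                                    // how often we fell through to the "database"

    TpsLookup(Map<String, TpsRow> database) {
        this.database = database;
    }

    static String key(String username, String path) {
        return username + "-" + path;  // same key format the service uses for its cache
    }

    int tpsFor(String username, String path) {
        String cacheKey = key(username, path);
        TpsRow row = cache.get(cacheKey);
        if (row == null) {                        // cache miss
            databaseLookups++;
            row = database.get(cacheKey);
            if (row == null) {                    // no row: fall back to the default limit
                row = new TpsRow(username, path, DEFAULT_TPS);
            }
            cache.put(cacheKey, row);             // later calls are served from the cache
        }
        return row.tps();
    }
}
```

Note that the default is cached too. Since the cache.put in the service sets no expiry, a row added to tbl_tps later is not picked up until that cache entry is removed. Redisson's RMapCache also accepts a per-entry time-to-live (put(key, value, ttl, unit)), which would bound how stale a cached limit can get.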

Controller

Finally, the controller exposes a POST endpoint for testing the rate limiter.

package Rateimiting.controller;

import Rateimiting.model.request.Request;
import Rateimiting.service.RateLimitService;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.HandlerMapping;

@RestController
@RequestMapping("/v1")
public class RateLimitController {
    // Spring injects a proxy that resolves to the current HTTP request
    @Autowired
    private HttpServletRequest requests;

    private final RateLimitService rateLimitService;

    @Autowired
    public RateLimitController(RateLimitService rateLimitService) {
        this.rateLimitService = rateLimitService;
    }

    // POST /v1/rate with a JSON body such as {"username": "...", "message": "..."}
    @PostMapping("/rate")
    public ResponseEntity<?> addInfo(@RequestBody Request request) {
        String uri = requests.getRequestURI(); // full request URI (not used further in this example)
        // Path matched by the handler mapping; used with the username to look up the TPS limit
        String path = (String) requests.getAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE);
        return rateLimitService.addInfo(request, path);
    }
}


@PostMapping("/rate") annotation:

Together with the class-level @RequestMapping prefix, this annotation maps POST requests for /v1/rate to the addInfo() method. This is the endpoint being rate limited.

@RequestBody Request request parameter:

This parameter tells Spring to deserialize the JSON request body into a Request object (username and message).

String uri = requests.getRequestURI() statement:

This statement reads the full URI of the request. It is not used further in this example; the service works with the path described next.

String path = (String) requests.getAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE) statement:

This statement reads a request attribute that Spring MVC's handler mapping sets to the path within the handler mapping (here, typically /v1/rate). The service combines this path with the username to look up the TPS limit.

return rateLimitService.addInfo(request, path); statement:

This statement calls RateLimitService.addInfo(), which looks up the user's TPS limit, tries to consume one token from the bucket, and returns either HTTP 200 or HTTP 429.
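To see the limiter in action, you can fire a quick burst of requests at the endpoint. The client below is a hypothetical smoke test using the JDK's built-in java.net.http.HttpClient; it assumes the app runs locally on Spring Boot's default port 8080 and that the made-up user "alice" has no row in tbl_tps, so the default of 10 TPS applies.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical smoke test: sends a burst of POSTs to /v1/rate and tallies the status codes. */
final class RateLimitSmokeTest {

    /** Counts how many responses had each HTTP status code, sorted by code. */
    static Map<Integer, Integer> tally(List<Integer> statusCodes) {
        Map<Integer, Integer> counts = new TreeMap<>();
        for (int code : statusCodes) {
            counts.merge(code, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();
        // Assumed: the application listens on localhost:8080 (Spring Boot's default port).
        URI uri = URI.create("http://localhost:8080/v1/rate");
        String body = "{\"username\":\"alice\",\"message\":\"hello\"}";  // hypothetical user

        List<Integer> codes = new ArrayList<>();
        for (int i = 0; i < 15; i++) {
            HttpRequest request = HttpRequest.newBuilder(uri)
                    .header("Content-Type", "application/json")
                    .POST(HttpRequest.BodyPublishers.ofString(body))
                    .build();
            codes.add(client.send(request, HttpResponse.BodyHandlers.discarding()).statusCode());
        }
        System.out.println(tally(codes));
    }
}
```

If all 15 requests land within the same one-second refill interval, you should see roughly {200=10, 429=5}; a burst that straddles a refill boundary lets a few more through.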

application.properties

spring.sql.init.platform=postgresql
spring.datasource.url=jdbc:postgresql://localhost:5432/dbname
spring.datasource.username=username
spring.datasource.password=password
spring.datasource.driver-class-name=org.postgresql.Driver
spring.jpa.database-platform=org.hibernate.dialect.PostgreSQLDialect
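# Keep idle connections alive (needed in production). Spring Boot's default pool is HikariCP,
# which does not read testWhileIdle; keepalive-time (in milliseconds) serves that purpose.
spring.datasource.hikari.keepalive-time=300000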
spring.jpa.hibernate.ddl-auto=update


You can access the code on my git repo here
