DEV Community

Adam - The Developer

Sharding Databases with Spring Boot: Patterns, Pitfalls, and Failure Modes

Sharding is one of those topics that sounds intimidating until you actually understand what's happening under the hood. If you're a senior Java developer encountering sharding for the first time, or if you need a refresher on implementing it with Spring Boot, this guide covers everything from basic concepts to production-ready implementations.

What is Database Sharding?

Database sharding is a horizontal partitioning strategy where you split your data across multiple database instances (shards). Instead of having one massive database trying to handle all your traffic and data, you distribute the load across multiple smaller databases.

Think of it like this: instead of one gigantic library holding every book ever written, you have multiple libraries where books are distributed based on some criteria (maybe by author's last name, subject, or publication date).

Why Shard?

Before diving into implementation, let's talk about when you actually need sharding:

Vertical scaling hits a wall. You can only add so much RAM, CPU, and faster disks to a single database server before you hit physical or cost limitations.

Query performance degrades. Even with proper indexing, queries slow down as your dataset grows. A query on a 10TB table is inherently slower than the same query on a 100GB table.

Write throughput becomes a bottleneck. A single database can only handle so many writes per second. When your application needs to process millions of writes per minute, one database won't cut it.

High availability requirements. Sharding can improve availability because a failure in one shard doesn't take down your entire system.

Core Sharding Concepts

The Shard Key

The shard key is the field you use to determine which shard holds a particular piece of data. This is absolutely critical to get right because changing it later is extremely painful.

Common shard key choices:

  • User ID
  • Tenant ID (for multi-tenant applications)
  • Geographic region
  • Timestamp ranges (for time-series data)

The shard key should have these properties:

  • High cardinality: Many unique values to distribute data evenly
  • Query friendly: Your most common queries should be able to target a single shard
  • Stable: The value shouldn't change frequently (or at all)
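Before committing to a candidate shard key, it's worth simulating how evenly it spreads across shards. A minimal sketch (the sequential-ID key space and the shard count are illustrative assumptions, not part of any real system):

```java
import java.util.Arrays;

public class ShardKeyDistributionCheck {

    // Map a key to a shard index; floorMod keeps the result in [0, numShards)
    // even when the hash code is negative
    static int shardFor(long key, int numShards) {
        return Math.floorMod(Long.hashCode(key), numShards);
    }

    public static void main(String[] args) {
        int numShards = 4;
        long[] counts = new long[numShards];

        // Simulate 100,000 sequential user IDs
        for (long userId = 1; userId <= 100_000; userId++) {
            counts[shardFor(userId, numShards)]++;
        }

        // A heavily skewed histogram here means the key would create hotspots
        System.out.println(Arrays.toString(counts));
        // prints: [25000, 25000, 25000, 25000]
    }
}
```

Sequential IDs happen to distribute perfectly under a modulo; a real check would replay production key samples, where skew (e.g. one giant tenant) shows up immediately.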

Sharding Strategies

There are several ways to distribute data across shards:

1. Range-Based Sharding

You divide data based on ranges of the shard key.

Shard 1: user_id 1 to 1,000,000
Shard 2: user_id 1,000,001 to 2,000,000
Shard 3: user_id 2,000,001 to 3,000,000

Pros: Simple to implement, range queries are efficient
Cons: Can lead to hotspots if data isn't uniformly distributed
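The ranges above map naturally onto a sorted map: store each range's lower bound and resolve keys with `floorEntry`. A minimal sketch (the shard names and boundaries mirror the example above, purely for illustration):

```java
import java.util.Map;
import java.util.TreeMap;

public class RangeShardResolver {

    // Lower bound of each range -> shard name
    private final TreeMap<Long, String> ranges = new TreeMap<>();

    public RangeShardResolver() {
        ranges.put(1L, "shard_1");          // user_id 1 to 1,000,000
        ranges.put(1_000_001L, "shard_2");  // user_id 1,000,001 to 2,000,000
        ranges.put(2_000_001L, "shard_3");  // user_id 2,000,001 and up
    }

    public String resolve(long userId) {
        // floorEntry finds the range with the greatest lower bound <= userId
        Map.Entry<Long, String> entry = ranges.floorEntry(userId);
        if (entry == null) {
            throw new IllegalArgumentException("No shard covers user_id " + userId);
        }
        return entry.getValue();
    }
}
```

`new RangeShardResolver().resolve(1_500_000L)` returns `"shard_2"`; range scans stay efficient because contiguous IDs resolve to the same shard.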

2. Hash-Based Sharding

You apply a hash function to the shard key and use the result to determine the shard.

public int getShard(Long userId, int numShards) {
    // floorMod keeps the result in [0, numShards) even for negative hash codes;
    // Math.abs(Integer.MIN_VALUE) is still negative, so Math.abs(...) % n can fail
    return Math.floorMod(userId.hashCode(), numShards);
}

Pros: Excellent data distribution
Cons: Range queries require hitting all shards

3. Directory-Based Sharding

You maintain a lookup table that maps shard keys to specific shards.

user_id -> shard mapping stored in a lookup service
1234 -> shard_2
5678 -> shard_1
9012 -> shard_3

Pros: Flexible, can rebalance without changing keys
Cons: Lookup table becomes a single point of failure, adds latency
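The idea can be sketched with a plain map standing in for the lookup service (in production the mapping would live in a replicated store such as a database table or a cache, precisely to avoid the single point of failure mentioned above):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class DirectoryShardResolver {

    // In-memory stand-in for the external lookup service
    private final Map<Long, String> directory = new ConcurrentHashMap<>();

    public void assign(long userId, String shardId) {
        directory.put(userId, shardId);
    }

    // Rebalancing is just rewriting directory entries; the keys never change
    public void move(long userId, String newShardId) {
        directory.put(userId, newShardId);
    }

    public String resolve(long userId) {
        String shardId = directory.get(userId);
        if (shardId == null) {
            throw new IllegalStateException("No shard mapping for user " + userId);
        }
        return shardId;
    }
}
```

The `move` method is the payoff of this strategy: migrating a user is a data copy plus one directory update, with no hash function to re-derive.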

4. Geographic Sharding

Data is sharded based on geographic location.

Users in North America -> shard_us
Users in Europe -> shard_eu
Users in Asia -> shard_asia

Pros: Reduces latency for geo-distributed users, helps with data residency compliance
Cons: Can lead to uneven distribution
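Geographic routing is usually a small static mapping from region to shard. A sketch using an enum (the region names and shard IDs are illustrative):

```java
import java.util.EnumMap;
import java.util.Map;

public class GeoShardResolver {

    public enum Region { NORTH_AMERICA, EUROPE, ASIA }

    // EnumMap gives a compact, type-safe region -> shard table
    private final Map<Region, String> regionToShard = new EnumMap<>(Region.class);

    public GeoShardResolver() {
        regionToShard.put(Region.NORTH_AMERICA, "shard_us");
        regionToShard.put(Region.EUROPE, "shard_eu");
        regionToShard.put(Region.ASIA, "shard_asia");
    }

    public String resolve(Region region) {
        return regionToShard.get(region);
    }
}
```

The uneven-distribution con is visible here: if most users are in one region, that region's shard carries most of the load, so geographic shards are often themselves sub-sharded by another key.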

Hash Function Deep Dive

Let's get into the details of how hash-based sharding actually works in Java.

Choosing a Hash Function

You need a hash function that:

  • Distributes keys uniformly across shards
  • Is fast to compute
  • Has minimal collisions
  • Is deterministic (same input always gives same output)

Common choices:

MD5: Widely available and provides good distribution (though slower than non-cryptographic hashes)

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;

public class ShardingUtils {

    public static int md5HashShard(String key, int numShards) {
        try {
            MessageDigest md = MessageDigest.getInstance("MD5");
            byte[] messageDigest = md.digest(key.getBytes(StandardCharsets.UTF_8));
            BigInteger bigInt = new BigInteger(1, messageDigest);
            // mod on the nonnegative BigInteger sidesteps the
            // Math.abs(Integer.MIN_VALUE) edge case
            return bigInt.mod(BigInteger.valueOf(numShards)).intValue();
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("MD5 algorithm not found", e);
        }
    }
}

MurmurHash: Faster than MD5, excellent distribution

import com.google.common.hash.Hashing;
import java.nio.charset.StandardCharsets;

public class ShardingUtils {

    public static int murmurHashShard(String key, int numShards) {
        int hash = Hashing.murmur3_32_fixed()
            .hashString(key, StandardCharsets.UTF_8)
            .asInt();
        // floorMod keeps the result in [0, numShards) even for negative hashes
        return Math.floorMod(hash, numShards);
    }
}

CRC32: Very fast, good enough for most use cases

import java.util.zip.CRC32;
import java.nio.charset.StandardCharsets;

public class ShardingUtils {

    public static int crc32HashShard(String key, int numShards) {
        CRC32 crc = new CRC32();
        crc.update(key.getBytes(StandardCharsets.UTF_8));
        // getValue() returns the unsigned 32-bit checksum in a long,
        // so it is never negative and needs no Math.abs
        return (int) (crc.getValue() % numShards);
    }
}

The Modulo Problem

The simple modulo approach has a major flaw: when you add or remove shards, almost all keys get reassigned to different shards.

// With 3 shards
hash(123) % 3 = 0  // Goes to shard 0

// Add a shard, now 4 shards
hash(123) % 4 = 3  // Now goes to shard 3!

This means massive data migration when scaling.
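You can quantify the churn: going from 3 to 4 shards, a key keeps its shard only when `k % 3 == k % 4`, which holds for just 3 residues out of every 12, so roughly 75% of keys move. A quick check (the hash and key range are illustrative):

```java
public class ModuloReshardDemo {

    // Same modulo scheme as above, with floorMod for negative-hash safety
    static int shardFor(long key, int numShards) {
        return Math.floorMod(Long.hashCode(key), numShards);
    }

    public static void main(String[] args) {
        long total = 10_000;
        long moved = 0;

        // Count keys whose shard assignment changes when a 4th shard is added
        for (long key = 0; key < total; key++) {
            if (shardFor(key, 3) != shardFor(key, 4)) {
                moved++;
            }
        }

        System.out.println("Moved: " + moved + " of " + total);
        // prints: Moved: 7498 of 10000
    }
}
```

Almost three quarters of the data would have to be physically copied between databases for a single scaling step, which is what motivates consistent hashing below.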

Consistent Hashing to the Rescue

Consistent hashing solves the resharding problem by minimizing the number of keys that need to move when shards are added or removed.

Here's how it works conceptually:

  1. Imagine a ring (circle) with values from 0 to 2^32
  2. Each shard is assigned multiple points on this ring (virtual nodes)
  3. Each key is hashed to a point on the ring
  4. The key belongs to the first shard found moving clockwise from the key's position

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;

public class ConsistentHash<T> {

    private final int virtualNodes;
    private final TreeMap<Long, T> ring;
    private final MessageDigest md;

    public ConsistentHash(int virtualNodes) {
        this.virtualNodes = virtualNodes;
        this.ring = new TreeMap<>();
        try {
            this.md = MessageDigest.getInstance("MD5");
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("MD5 algorithm not found", e);
        }
    }

    public void addNode(T node) {
        for (int i = 0; i < virtualNodes; i++) {
            String virtualKey = node.toString() + ":" + i;
            long hash = hash(virtualKey);
            ring.put(hash, node);
        }
    }

    public void removeNode(T node) {
        for (int i = 0; i < virtualNodes; i++) {
            String virtualKey = node.toString() + ":" + i;
            long hash = hash(virtualKey);
            ring.remove(hash);
        }
    }

    public T getNode(String key) {
        if (ring.isEmpty()) {
            return null;
        }

        long hash = hash(key);

        // Find first node with hash >= key hash
        Map.Entry<Long, T> entry = ring.ceilingEntry(hash);

        // If not found, wrap around to first node
        if (entry == null) {
            entry = ring.firstEntry();
        }

        return entry.getValue();
    }

    private long hash(String key) {
        // MessageDigest is stateful and not thread-safe; guard the shared instance
        synchronized (md) {
            md.reset();
            md.update(key.getBytes());
            byte[] digest = md.digest();

            // Take first 8 bytes to create a long
            long hash = 0;
            for (int i = 0; i < 8; i++) {
                hash = (hash << 8) | (digest[i] & 0xFF);
            }
            return hash;
        }
    }

    public int size() {
        return ring.size() / virtualNodes;
    }
}

// Usage
ConsistentHash<String> ch = new ConsistentHash<>(150);
ch.addNode("shard_1");
ch.addNode("shard_2");
ch.addNode("shard_3");

String shard = ch.getNode("user_12345");
System.out.println("Key belongs to: " + shard);

With consistent hashing, adding or removing a shard only moves the keys owned by the affected shard, approximately 1/N of the total (where N is the number of shards).
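To see this end to end, the sketch below builds two rings (before and after adding a fourth shard) and counts how many of 10,000 keys change owners. With 150 virtual nodes per shard the moved fraction lands near one quarter, and every moved key lands on the new shard. This is a self-contained re-implementation for the demo, not the `ConsistentHash` class above:

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Map;
import java.util.TreeMap;

public class ConsistentHashMigrationDemo {

    // Same MD5-to-long scheme as the ring implementation
    static long hash(String key) {
        try {
            byte[] d = MessageDigest.getInstance("MD5")
                .digest(key.getBytes(StandardCharsets.UTF_8));
            long h = 0;
            for (int i = 0; i < 8; i++) {
                h = (h << 8) | (d[i] & 0xFF);
            }
            return h;
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    // Build a ring with 150 virtual nodes per shard
    static TreeMap<Long, String> buildRing(String... shards) {
        TreeMap<Long, String> ring = new TreeMap<>();
        for (String shard : shards) {
            for (int i = 0; i < 150; i++) {
                ring.put(hash(shard + ":" + i), shard);
            }
        }
        return ring;
    }

    static String locate(TreeMap<Long, String> ring, String key) {
        Map.Entry<Long, String> e = ring.ceilingEntry(hash(key));
        return (e != null ? e : ring.firstEntry()).getValue();
    }

    static int countMoved(int totalKeys) {
        TreeMap<Long, String> before = buildRing("shard_1", "shard_2", "shard_3");
        TreeMap<Long, String> after =
            buildRing("shard_1", "shard_2", "shard_3", "shard_4");

        int moved = 0;
        for (int i = 0; i < totalKeys; i++) {
            String key = "user_" + i;
            if (!locate(before, key).equals(locate(after, key))) {
                moved++;
            }
        }
        return moved;
    }

    public static void main(String[] args) {
        int total = 10_000;
        int moved = countMoved(total);
        // Roughly a quarter of the keys move, all onto shard_4; none shuffle
        // between the three existing shards
        System.out.printf("Moved %d of %d keys (%.1f%%)%n",
            moved, total, 100.0 * moved / total);
    }
}
```

Contrast this with the ~75% churn of plain modulo: new vnodes only claim the ring segments immediately preceding them, so existing shards never trade keys with each other.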

Spring Boot Implementation Patterns

Setting Up Multiple DataSources

First, let's configure multiple datasources in Spring Boot:

# application.yml
# (With DataSourceBuilder + @ConfigurationProperties binding to HikariDataSource,
# pool settings bind directly under each shard prefix, not under a hikari: key)
spring:
  datasource:
    shard0:
      jdbc-url: jdbc:postgresql://db1.example.com:5432/shard_0
      username: app_user
      password: secret
      driver-class-name: org.postgresql.Driver
      maximum-pool-size: 10
      minimum-idle: 2
    shard1:
      jdbc-url: jdbc:postgresql://db2.example.com:5432/shard_1
      username: app_user
      password: secret
      driver-class-name: org.postgresql.Driver
      maximum-pool-size: 10
      minimum-idle: 2
    shard2:
      jdbc-url: jdbc:postgresql://db3.example.com:5432/shard_2
      username: app_user
      password: secret
      driver-class-name: org.postgresql.Driver
      maximum-pool-size: 10
      minimum-idle: 2

DataSource Configuration

import com.zaxxer.hikari.HikariDataSource;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.jdbc.DataSourceBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import javax.sql.DataSource;
import java.util.HashMap;
import java.util.Map;

@Configuration
public class DataSourceConfig {

    @Bean
    @ConfigurationProperties("spring.datasource.shard0")
    public DataSource shard0DataSource() {
        return DataSourceBuilder.create()
            .type(HikariDataSource.class)
            .build();
    }

    @Bean
    @ConfigurationProperties("spring.datasource.shard1")
    public DataSource shard1DataSource() {
        return DataSourceBuilder.create()
            .type(HikariDataSource.class)
            .build();
    }

    @Bean
    @ConfigurationProperties("spring.datasource.shard2")
    public DataSource shard2DataSource() {
        return DataSourceBuilder.create()
            .type(HikariDataSource.class)
            .build();
    }

    @Bean
    public Map<String, DataSource> shardDataSources(
            DataSource shard0DataSource,
            DataSource shard1DataSource,
            DataSource shard2DataSource) {

        Map<String, DataSource> dataSources = new HashMap<>();
        dataSources.put("shard_0", shard0DataSource);
        dataSources.put("shard_1", shard1DataSource);
        dataSources.put("shard_2", shard2DataSource);
        return dataSources;
    }
}

Shard Resolver Service

import org.springframework.stereotype.Service;

@Service
public class ShardResolver {

    private final ConsistentHash<String> consistentHash;

    public ShardResolver() {
        this.consistentHash = new ConsistentHash<>(150);
        consistentHash.addNode("shard_0");
        consistentHash.addNode("shard_1");
        consistentHash.addNode("shard_2");
    }

    public String resolveShardId(Long userId) {
        return consistentHash.getNode(userId.toString());
    }

    public String resolveShardId(String key) {
        return consistentHash.getNode(key);
    }
}

Sharded JDBC Template

import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;

import javax.sql.DataSource;
import java.util.HashMap;
import java.util.Map;

@Component
public class ShardedJdbcTemplate {

    private final Map<String, JdbcTemplate> shardTemplates;
    private final ShardResolver shardResolver;

    public ShardedJdbcTemplate(
            Map<String, DataSource> shardDataSources,
            ShardResolver shardResolver) {

        this.shardResolver = shardResolver;
        this.shardTemplates = new HashMap<>();

        shardDataSources.forEach((shardId, dataSource) -> {
            this.shardTemplates.put(shardId, new JdbcTemplate(dataSource));
        });
    }

    public JdbcTemplate getTemplate(Long userId) {
        String shardId = shardResolver.resolveShardId(userId);
        return shardTemplates.get(shardId);
    }

    public JdbcTemplate getTemplate(String shardId) {
        return shardTemplates.get(shardId);
    }

    public Map<String, JdbcTemplate> getAllTemplates() {
        return new HashMap<>(shardTemplates);
    }
}

Query Patterns and Challenges

Single Shard Queries

When your query includes the shard key, you can route directly to one shard.

@Service
public class UserService {

    private final ShardedJdbcTemplate shardedJdbcTemplate;

    public UserService(ShardedJdbcTemplate shardedJdbcTemplate) {
        this.shardedJdbcTemplate = shardedJdbcTemplate;
    }

    public User getUserById(Long userId) {
        JdbcTemplate template = shardedJdbcTemplate.getTemplate(userId);

        return template.queryForObject(
            "SELECT id, email, name, created_at FROM users WHERE id = ?",
            new Object[]{userId},
            (rs, rowNum) -> new User(
                rs.getLong("id"),
                rs.getString("email"),
                rs.getString("name"),
                rs.getTimestamp("created_at")
            )
        );
    }

    public void createUser(User user) {
        JdbcTemplate template = shardedJdbcTemplate.getTemplate(user.getId());

        template.update(
            "INSERT INTO users (id, email, name, created_at) VALUES (?, ?, ?, ?)",
            user.getId(),
            user.getEmail(),
            user.getName(),
            new Timestamp(System.currentTimeMillis())
        );
    }
}

Cross-Shard Queries (Scatter-Gather)

When your query doesn't include the shard key, you need to query all shards and merge results.

import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;

@Service
public class UserService {

    private final ShardedJdbcTemplate shardedJdbcTemplate;
    private final ExecutorService executorService;

    public UserService(ShardedJdbcTemplate shardedJdbcTemplate) {
        this.shardedJdbcTemplate = shardedJdbcTemplate;
        this.executorService = Executors.newFixedThreadPool(10);
    }

    public List<User> findUsersByEmail(String email) {
        Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
        List<Future<List<User>>> futures = new ArrayList<>();

        // Query all shards in parallel
        for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
            Future<List<User>> future = executorService.submit(() -> {
                return entry.getValue().query(
                    "SELECT id, email, name, created_at FROM users WHERE email = ?",
                    new Object[]{email},
                    (rs, rowNum) -> new User(
                        rs.getLong("id"),
                        rs.getString("email"),
                        rs.getString("name"),
                        rs.getTimestamp("created_at")
                    )
                );
            });
            futures.add(future);
        }

        // Collect results from all shards
        List<User> results = new ArrayList<>();
        for (Future<List<User>> future : futures) {
            try {
                results.addAll(future.get(5, TimeUnit.SECONDS));
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers can observe the interruption
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted during scatter-gather query", e);
            } catch (ExecutionException | TimeoutException e) {
                throw new RuntimeException("Error executing scatter-gather query", e);
            }
        }

        return results;
    }
}

Joins Across Shards

This is where sharding gets painful. Joins across shards are expensive and complex.

Solution 1: Denormalize
Store redundant data to avoid cross-shard joins.

-- Instead of joining users and orders across shards
-- Store user info in the orders table
CREATE TABLE orders (
    id BIGINT,
    user_id BIGINT,
    user_name VARCHAR(255),  -- Denormalized
    user_email VARCHAR(255),  -- Denormalized
    total DECIMAL(10,2)
)

Solution 2: Application Level Joins
Fetch from multiple shards and join in memory.

public class OrderWithUser {
    private Order order;
    private User user;

    // getters and setters
}

@Service
public class OrderService {

    private final ShardedJdbcTemplate shardedJdbcTemplate;
    private final UserService userService;

    public OrderService(
            ShardedJdbcTemplate shardedJdbcTemplate,
            UserService userService) {
        this.shardedJdbcTemplate = shardedJdbcTemplate;
        this.userService = userService;
    }

    public OrderWithUser getOrderWithUser(Long orderId, Long userId) {
        // Get order from its shard
        JdbcTemplate orderTemplate = shardedJdbcTemplate.getTemplate(orderId);
        Order order = orderTemplate.queryForObject(
            "SELECT * FROM orders WHERE id = ?",
            new Object[]{orderId},
            (rs, rowNum) -> mapOrder(rs)
        );

        // Get user from their shard
        User user = userService.getUserById(userId);

        OrderWithUser result = new OrderWithUser();
        result.setOrder(order);
        result.setUser(user);
        return result;
    }
}

Solution 3: Shard by Relationship
Co-locate related data on the same shard.

// Users and their orders always on the same shard
// Use userId as shard key for both tables
@Service
public class OrderService {

    private final ShardedJdbcTemplate shardedJdbcTemplate;

    public OrderService(ShardedJdbcTemplate shardedJdbcTemplate) {
        this.shardedJdbcTemplate = shardedJdbcTemplate;
    }

    public List<OrderWithUser> getUserOrders(Long userId) {
        JdbcTemplate template = shardedJdbcTemplate.getTemplate(userId);

        // Both tables on the same shard, joins work normally
        return template.query(
            "SELECT u.*, o.* FROM users u " +
            "JOIN orders o ON u.id = o.user_id " +
            "WHERE u.id = ?",
            new Object[]{userId},
            (rs, rowNum) -> mapOrderWithUser(rs)
        );
    }
}

Transactions Across Shards

Distributed transactions are complex and slow. ACID guarantees across shards typically require two-phase commit (2PC).

Two-Phase Commit with Spring

import org.springframework.stereotype.Service;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;

import java.util.*;

@Service
public class DistributedTransactionService {

    private final Map<String, PlatformTransactionManager> transactionManagers;
    private final ShardedJdbcTemplate shardedJdbcTemplate;

    public DistributedTransactionService(
            Map<String, PlatformTransactionManager> transactionManagers,
            ShardedJdbcTemplate shardedJdbcTemplate) {
        this.transactionManagers = transactionManagers;
        this.shardedJdbcTemplate = shardedJdbcTemplate;
    }

    public void executeTwoPhaseCommit(Map<String, Runnable> shardOperations) {
        Map<String, TransactionStatus> transactions = new HashMap<>();

        try {
            // Phase 1: open a transaction per shard and execute its operation.
            // (A simplification: unlike true XA 2PC there is no durable
            // prepare/vote step, so a crash between commits can still leave
            // shards inconsistent.)
            for (Map.Entry<String, Runnable> entry : shardOperations.entrySet()) {
                String shardId = entry.getKey();
                PlatformTransactionManager tm = transactionManagers.get(shardId);

                TransactionStatus status = tm.getTransaction(
                    new DefaultTransactionDefinition()
                );
                transactions.put(shardId, status);

                try {
                    entry.getValue().run();
                } catch (Exception e) {
                    throw new RuntimeException("Failed to execute on shard: " + shardId, e);
                }
            }

            // Phase 2: Commit all
            for (Map.Entry<String, TransactionStatus> entry : transactions.entrySet()) {
                String shardId = entry.getKey();
                PlatformTransactionManager tm = transactionManagers.get(shardId);
                tm.commit(entry.getValue());
            }

        } catch (Exception e) {
            // Rollback all prepared transactions
            for (Map.Entry<String, TransactionStatus> entry : transactions.entrySet()) {
                try {
                    String shardId = entry.getKey();
                    PlatformTransactionManager tm = transactionManagers.get(shardId);
                    tm.rollback(entry.getValue());
                } catch (Exception rollbackException) {
                    // Log rollback failure
                }
            }
            throw new RuntimeException("Distributed transaction failed", e);
        }
    }
}

Problems with 2PC:

  • Slow (multiple network round trips)
  • Coordinator failure can leave system in limbo
  • Locks held for long duration

Better Approach: Saga Pattern

import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;

import java.math.BigDecimal;
import java.util.*;

@Service
public class SagaTransactionService {

    private final ShardedJdbcTemplate shardedJdbcTemplate;
    private final SagaLogRepository sagaLogRepository;

    public SagaTransactionService(
            ShardedJdbcTemplate shardedJdbcTemplate,
            SagaLogRepository sagaLogRepository) {
        this.shardedJdbcTemplate = shardedJdbcTemplate;
        this.sagaLogRepository = sagaLogRepository;
    }

    public void transferMoney(Long fromUserId, Long toUserId, BigDecimal amount) {
        String sagaId = UUID.randomUUID().toString();

        try {
            // Step 1: Debit from user
            JdbcTemplate fromTemplate = shardedJdbcTemplate.getTemplate(fromUserId);
            fromTemplate.update(
                "UPDATE accounts SET balance = balance - ? WHERE user_id = ?",
                amount, fromUserId
            );
            sagaLogRepository.logStep(sagaId, "debit_complete", fromUserId);

            // Step 2: Credit to user
            JdbcTemplate toTemplate = shardedJdbcTemplate.getTemplate(toUserId);
            toTemplate.update(
                "UPDATE accounts SET balance = balance + ? WHERE user_id = ?",
                amount, toUserId
            );
            sagaLogRepository.logStep(sagaId, "credit_complete", toUserId);

            sagaLogRepository.markComplete(sagaId);

        } catch (Exception e) {
            compensateTransfer(sagaId, fromUserId, toUserId, amount);
            throw new RuntimeException("Transfer failed, compensation executed", e);
        }
    }

    private void compensateTransfer(
            String sagaId, 
            Long fromUserId, 
            Long toUserId, 
            BigDecimal amount) {

        List<SagaStep> completedSteps = sagaLogRepository.getCompletedSteps(sagaId);

        for (SagaStep step : completedSteps) {
            if ("debit_complete".equals(step.getStepName())) {
                // Reverse the debit
                JdbcTemplate template = shardedJdbcTemplate.getTemplate(fromUserId);
                template.update(
                    "UPDATE accounts SET balance = balance + ? WHERE user_id = ?",
                    amount, fromUserId
                );
            }
        }

        sagaLogRepository.markCompensated(sagaId);
    }
}

Monitoring and Observability with Spring Boot Actuator

import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tags;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

@Component
public class ShardMetricsCollector {

    private final ShardedJdbcTemplate shardedJdbcTemplate;
    private final MeterRegistry meterRegistry;

    // Micrometer gauges sample a live object, so keep one AtomicLong per shard
    // and update it on each run instead of re-registering the gauge
    private final Map<String, AtomicLong> rowCountGauges = new ConcurrentHashMap<>();
    private final Map<String, AtomicLong> tableSizeGauges = new ConcurrentHashMap<>();

    public ShardMetricsCollector(
            ShardedJdbcTemplate shardedJdbcTemplate,
            MeterRegistry meterRegistry) {
        this.shardedJdbcTemplate = shardedJdbcTemplate;
        this.meterRegistry = meterRegistry;
    }

    // Requires @EnableScheduling on a configuration class
    @Scheduled(fixedRate = 60000) // Every minute
    public void collectShardMetrics() {
        Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();

        for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
            String shardId = entry.getKey();
            JdbcTemplate template = entry.getValue();

            // Collect row count
            Long rowCount = template.queryForObject(
                "SELECT COUNT(*) FROM users",
                Long.class
            );
            rowCountGauges.computeIfAbsent(shardId, id ->
                meterRegistry.gauge("shard.row.count", Tags.of("shard", id), new AtomicLong())
            ).set(rowCount != null ? rowCount : 0);

            // Collect table size (PostgreSQL specific)
            Long tableSize = template.queryForObject(
                "SELECT pg_total_relation_size('users')",
                Long.class
            );
            tableSizeGauges.computeIfAbsent(shardId, id ->
                meterRegistry.gauge("shard.table.size.bytes", Tags.of("shard", id), new AtomicLong())
            ).set(tableSize != null ? tableSize : 0);
        }
    }

    public Map<String, ShardMetrics> getShardMetrics() {
        Map<String, ShardMetrics> metrics = new HashMap<>();
        Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();

        for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
            String shardId = entry.getKey();
            JdbcTemplate template = entry.getValue();

            ShardMetrics shardMetrics = new ShardMetrics();
            shardMetrics.setShardId(shardId);
            shardMetrics.setRowCount(
                template.queryForObject("SELECT COUNT(*) FROM users", Long.class)
            );

            metrics.put(shardId, shardMetrics);
        }

        return metrics;
    }
}

@Data
class ShardMetrics {
    private String shardId;
    private Long rowCount;
    private Long tableSizeBytes;
    private Double queriesPerSecond;
    private Double avgQueryTimeMs;
}

Health Check Endpoint

import org.springframework.boot.actuate.health.Health;
import org.springframework.boot.actuate.health.HealthIndicator;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.Map;

@Component
public class ShardHealthIndicator implements HealthIndicator {

    private final ShardedJdbcTemplate shardedJdbcTemplate;

    public ShardHealthIndicator(ShardedJdbcTemplate shardedJdbcTemplate) {
        this.shardedJdbcTemplate = shardedJdbcTemplate;
    }

    @Override
    public Health health() {
        Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
        Map<String, String> shardStatus = new HashMap<>();
        boolean allHealthy = true;

        for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
            String shardId = entry.getKey();
            try {
                entry.getValue().queryForObject("SELECT 1", Integer.class);
                shardStatus.put(shardId, "UP");
            } catch (Exception e) {
                shardStatus.put(shardId, "DOWN");
                allHealthy = false;
            }
        }

        if (allHealthy) {
            return Health.up().withDetails(shardStatus).build();
        } else {
            return Health.down().withDetails(shardStatus).build();
        }
    }
}

Rebalancing and Resharding

Eventually, you'll need to add more shards or rebalance data.

Adding New Shards

@Service
public class ShardManagementService {

    private final ConsistentHash<String> consistentHash;
    private final Map<String, DataSource> shardDataSources;
    private final ShardedJdbcTemplate shardedJdbcTemplate;

    public ShardManagementService(
            ConsistentHash<String> consistentHash,
            Map<String, DataSource> shardDataSources,
            ShardedJdbcTemplate shardedJdbcTemplate) {
        this.consistentHash = consistentHash;
        this.shardDataSources = shardDataSources;
        this.shardedJdbcTemplate = shardedJdbcTemplate;
    }

    public void addShard(String newShardId, DataSource newDataSource) {
        // Add to consistent hash ring
        consistentHash.addNode(newShardId);

        // Add to datasources
        shardDataSources.put(newShardId, newDataSource);

        // Migrate affected keys
        migrateKeysToNewShard(newShardId);
    }

    private void migrateKeysToNewShard(String newShardId) {
        Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();

        for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
            String oldShardId = entry.getKey();

            if (oldShardId.equals(newShardId)) {
                continue;
            }

            JdbcTemplate oldTemplate = entry.getValue();

            // Get all users from old shard
            // (Production code should stream or page through rows in batches
            // instead of loading the whole table into memory)
            List<User> users = oldTemplate.query(
                "SELECT id, email, name, created_at FROM users",
                (rs, rowNum) -> new User(
                    rs.getLong("id"),
                    rs.getString("email"),
                    rs.getString("name"),
                    rs.getTimestamp("created_at")
                )
            );

            JdbcTemplate newTemplate = shardedJdbcTemplate.getTemplate(newShardId);

            for (User user : users) {
                // Check if user should be on new shard
                String targetShard = consistentHash.getNode(user.getId().toString());

                if (targetShard.equals(newShardId)) {
                    // Move to new shard
                    newTemplate.update(
                        "INSERT INTO users (id, email, name, created_at) VALUES (?, ?, ?, ?)",
                        user.getId(),
                        user.getEmail(),
                        user.getName(),
                        user.getCreatedAt()
                    );

                    // Delete from old shard
                    oldTemplate.update(
                        "DELETE FROM users WHERE id = ?",
                        user.getId()
                    );
                }
            }
        }
    }
}

Zero-Downtime Migration with Dual Writes

@Service
public class DualWriteService {

    private static final org.slf4j.Logger logger =
        org.slf4j.LoggerFactory.getLogger(DualWriteService.class);

    private final ShardedJdbcTemplate shardedJdbcTemplate;
    // ConcurrentHashSet is not a JDK class; use ConcurrentHashMap.newKeySet()
    private final Set<String> dualWriteShards = ConcurrentHashMap.newKeySet();

    public DualWriteService(ShardedJdbcTemplate shardedJdbcTemplate) {
        this.shardedJdbcTemplate = shardedJdbcTemplate;
    }

    public void enableDualWrite(String oldShardId, String newShardId) {
        dualWriteShards.add(oldShardId + ":" + newShardId);
    }

    public void createUser(User user) {
        String primaryShard = resolvePrimaryShard(user.getId());
        JdbcTemplate primaryTemplate = shardedJdbcTemplate.getTemplate(primaryShard);

        // Write to primary shard
        primaryTemplate.update(
            "INSERT INTO users (id, email, name, created_at) VALUES (?, ?, ?, ?)",
            user.getId(), user.getEmail(), user.getName(), user.getCreatedAt()
        );

        // Check if dual write is enabled
        String dualWriteKey = findDualWriteConfig(primaryShard);
        if (dualWriteKey != null) {
            String secondaryShard = extractSecondaryShard(dualWriteKey);
            JdbcTemplate secondaryTemplate = shardedJdbcTemplate.getTemplate(secondaryShard);

            // Async write to secondary shard
            CompletableFuture.runAsync(() -> {
                try {
                    secondaryTemplate.update(
                        "INSERT INTO users (id, email, name, created_at) VALUES (?, ?, ?, ?)",
                        user.getId(), user.getEmail(), user.getName(), user.getCreatedAt()
                    );
                } catch (Exception e) {
                    // Log error but don't fail the primary write
                    logger.error("Dual write failed for user: " + user.getId(), e);
                }
            });
        }
    }
}
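Dual writes keep new data flowing to both shards, but before cutting reads over you still need to verify that the backfill and the dual writes actually converged. One way to sketch this (the `MigrationVerifier` class and the id-to-checksum fingerprint shape below are illustrative, not from any library) is to compute a per-row fingerprint on each side, for instance a hash over the row's mutable columns, and diff the two maps:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Diffs per-row fingerprints (id -> checksum over the row's columns) taken
// from the old and new shard, reporting rows that are missing or drifted.
public class MigrationVerifier {

    public static List<Long> findMismatches(Map<Long, String> oldShard,
                                            Map<Long, String> newShard) {
        List<Long> mismatched = new ArrayList<>();
        for (Map.Entry<Long, String> row : oldShard.entrySet()) {
            // Missing on the new shard, or present with different contents
            if (!row.getValue().equals(newShard.get(row.getKey()))) {
                mismatched.add(row.getKey());
            }
        }
        for (Long id : newShard.keySet()) {
            // Rows that exist only on the new shard
            if (!oldShard.containsKey(id)) {
                mismatched.add(id);
            }
        }
        return mismatched;
    }
}
```

Run a check like this repeatedly during the dual-write window and cut reads over only once the mismatch list stays empty across several consecutive passes.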

Real World Example: Complete User Service with Spring Boot

// Domain Model
@Data
@AllArgsConstructor
@NoArgsConstructor
public class User {
    private Long id;
    private String email;
    private String name;
    private Timestamp createdAt;
}

// Repository Interface
public interface UserRepository {
    void save(User user);
    Optional<User> findById(Long id);
    List<User> findByEmail(String email);
    boolean update(User user);
    void delete(Long id);
    Map<String, Long> countPerShard();
}

// Sharded Repository Implementation
@Repository
public class ShardedUserRepository implements UserRepository {

    private final ShardedJdbcTemplate shardedJdbcTemplate;
    private final ExecutorService executorService;

    public ShardedUserRepository(ShardedJdbcTemplate shardedJdbcTemplate) {
        this.shardedJdbcTemplate = shardedJdbcTemplate;
        this.executorService = Executors.newFixedThreadPool(10);
    }

    @PreDestroy
    public void shutdown() {
        // Release the scatter-gather threads when the context shuts down
        executorService.shutdown();
    }

    @Override
    public void save(User user) {
        JdbcTemplate template = shardedJdbcTemplate.getTemplate(user.getId());

        template.update(
            "INSERT INTO users (id, email, name, created_at) VALUES (?, ?, ?, ?)",
            user.getId(),
            user.getEmail(),
            user.getName(),
            new Timestamp(System.currentTimeMillis())
        );
    }

    @Override
    public Optional<User> findById(Long id) {
        JdbcTemplate template = shardedJdbcTemplate.getTemplate(id);

        try {
            User user = template.queryForObject(
                "SELECT id, email, name, created_at FROM users WHERE id = ?",
                (rs, rowNum) -> new User(
                    rs.getLong("id"),
                    rs.getString("email"),
                    rs.getString("name"),
                    rs.getTimestamp("created_at")
                ),
                id
            );
            return Optional.ofNullable(user);
        } catch (EmptyResultDataAccessException e) {
            return Optional.empty();
        }
    }

    @Override
    public List<User> findByEmail(String email) {
        Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
        List<Future<List<User>>> futures = new ArrayList<>();

        // Query all shards in parallel
        for (JdbcTemplate template : templates.values()) {
            Future<List<User>> future = executorService.submit(() ->
                template.query(
                    "SELECT id, email, name, created_at FROM users WHERE email = ?",
                    (rs, rowNum) -> new User(
                        rs.getLong("id"),
                        rs.getString("email"),
                        rs.getString("name"),
                        rs.getTimestamp("created_at")
                    ),
                    email
                )
            );
            futures.add(future);
        }

        // Collect results
        List<User> results = new ArrayList<>();
        for (Future<List<User>> future : futures) {
            try {
                results.addAll(future.get(5, TimeUnit.SECONDS));
            } catch (Exception e) {
                throw new RuntimeException("Error in scatter gather query", e);
            }
        }

        return results;
    }

    @Override
    public boolean update(User user) {
        JdbcTemplate template = shardedJdbcTemplate.getTemplate(user.getId());

        int rowsAffected = template.update(
            "UPDATE users SET email = ?, name = ? WHERE id = ?",
            user.getEmail(),
            user.getName(),
            user.getId()
        );

        return rowsAffected > 0;
    }

    @Override
    public void delete(Long id) {
        JdbcTemplate template = shardedJdbcTemplate.getTemplate(id);
        template.update("DELETE FROM users WHERE id = ?", id);
    }

    @Override
    public Map<String, Long> countPerShard() {
        Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
        Map<String, Long> counts = new HashMap<>();

        for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
            Long count = entry.getValue().queryForObject(
                "SELECT COUNT(*) FROM users",
                Long.class
            );
            counts.put(entry.getKey(), count != null ? count : 0L);
        }

        return counts;
    }
}

// Service Layer
@Service
public class UserService {

    private final UserRepository userRepository;

    public UserService(UserRepository userRepository) {
        this.userRepository = userRepository;
    }

    public void createUser(Long id, String email, String name) {
        User user = new User(id, email, name, null);
        userRepository.save(user);
    }

    public User getUser(Long id) {
        return userRepository.findById(id)
            .orElseThrow(() -> new UserNotFoundException("User not found: " + id));
    }

    public List<User> findByEmail(String email) {
        return userRepository.findByEmail(email);
    }

    public void updateUserName(Long id, String newName) {
        User user = getUser(id);
        user.setName(newName);

        if (!userRepository.update(user)) {
            throw new RuntimeException("Failed to update user: " + id);
        }
    }

    public Map<String, Long> getShardDistribution() {
        return userRepository.countPerShard();
    }
}

// REST Controller
@RestController
@RequestMapping("/api/users")
public class UserController {

    private final UserService userService;

    public UserController(UserService userService) {
        this.userService = userService;
    }

    @PostMapping
    public ResponseEntity<Void> createUser(@RequestBody CreateUserRequest request) {
        userService.createUser(request.getId(), request.getEmail(), request.getName());
        return ResponseEntity.status(HttpStatus.CREATED).build();
    }

    @GetMapping("/{id}")
    public ResponseEntity<User> getUser(@PathVariable Long id) {
        User user = userService.getUser(id);
        return ResponseEntity.ok(user);
    }

    @GetMapping("/search")
    public ResponseEntity<List<User>> searchByEmail(@RequestParam String email) {
        List<User> users = userService.findByEmail(email);
        return ResponseEntity.ok(users);
    }

    @PutMapping("/{id}")
    public ResponseEntity<Void> updateUser(
            @PathVariable Long id,
            @RequestBody UpdateUserRequest request) {
        userService.updateUserName(id, request.getName());
        return ResponseEntity.ok().build();
    }

    @GetMapping("/metrics/distribution")
    public ResponseEntity<Map<String, Long>> getDistribution() {
        Map<String, Long> distribution = userService.getShardDistribution();
        return ResponseEntity.ok(distribution);
    }
}

// DTOs
@Data
class CreateUserRequest {
    private Long id;
    private String email;
    private String name;
}

@Data
class UpdateUserRequest {
    private String name;
}

Common Pitfalls and Solutions

Pitfall 1: Wrong Shard Key Choice

Problem: Choosing a low cardinality key or one that creates hotspots.

// Bad: Status field as shard key
int shard = hash(orderStatus) % numShards;
// Most orders are "completed", creating a hot shard

Solution: Choose high cardinality keys that distribute evenly.

// Good: User ID or Order ID
int shard = hash(userId) % numShards;
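One Java-specific trap with the modulo approach: `hashCode()` can be negative, and Java's `%` operator keeps the dividend's sign, so `hash(userId) % numShards` can return a negative "shard index". A small sketch (`ShardRouter` is an illustrative name, not a library class) using `Math.floorMod`, which always yields a result in `[0, numShards)`:

```java
public final class ShardRouter {

    private ShardRouter() {}

    // Broken for negative hashes: Java's % keeps the dividend's sign,
    // so this can return a negative index that matches no shard.
    static int naiveShard(String key, int numShards) {
        return key.hashCode() % numShards;
    }

    // Safe: Math.floorMod always returns a value in [0, numShards).
    static int safeShard(String key, int numShards) {
        return Math.floorMod(key.hashCode(), numShards);
    }
}
```

For example, `Math.floorMod(-7, 4)` is `1`, while `-7 % 4` is `-3`, which would index a shard that doesn't exist.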

Pitfall 2: Not Planning for Growth

Problem: Hardcoding the number of shards makes resharding painful, because a plain modulo remaps most keys the moment the shard count changes.

// Bad: Hardcoded modulo
int shard = hash(userId) % 4;  // What happens when you need 8 shards?

Solution: Use consistent hashing from day one.
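To make that concrete, here is a minimal consistent-hash ring. This is a sketch only: production systems usually pick a stronger hash than CRC32 (MurmurHash is common) and tune the virtual-node count, and `ConsistentHashRing` is an illustrative name, not a library class. Each shard is placed on the ring many times (virtual nodes); a key maps to the first node clockwise from its hash. Going from 4 to 5 shards with `% 5` remaps roughly 80% of keys, while the ring only remaps the keys that now land on the new shard's virtual nodes, roughly a fifth of them.

```java
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.TreeMap;
import java.util.zip.CRC32;

// Minimal consistent-hash ring: keys map to the first virtual node at or
// after their hash position, so adding a shard only remaps a small fraction
// of keys instead of nearly all of them.
public class ConsistentHashRing {

    private final TreeMap<Long, String> ring = new TreeMap<>();
    private final int virtualNodesPerShard;

    public ConsistentHashRing(int virtualNodesPerShard) {
        this.virtualNodesPerShard = virtualNodesPerShard;
    }

    public void addShard(String shardId) {
        for (int i = 0; i < virtualNodesPerShard; i++) {
            ring.put(hash(shardId + "#" + i), shardId);
        }
    }

    public void removeShard(String shardId) {
        for (int i = 0; i < virtualNodesPerShard; i++) {
            ring.remove(hash(shardId + "#" + i));
        }
    }

    public String resolve(String key) {
        if (ring.isEmpty()) {
            throw new IllegalStateException("No shards registered");
        }
        // First virtual node clockwise from the key's position; wrap to the
        // start of the ring if the key hashes past the last node.
        Map.Entry<Long, String> entry = ring.ceilingEntry(hash(key));
        return entry != null ? entry.getValue() : ring.firstEntry().getValue();
    }

    private long hash(String value) {
        CRC32 crc = new CRC32();
        crc.update(value.getBytes(StandardCharsets.UTF_8));
        return crc.getValue();
    }
}
```

Routing then becomes `ring.resolve(String.valueOf(userId))` instead of a modulo, and adding capacity is a call to `addShard` plus a migration of only the keys whose owner changed.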

Pitfall 3: Connection Pool Exhaustion

Problem: Connection pools multiply with shard count. Each shard has its own pool, so 16 shards with a 50 connection pool each means 800 open database connections. Left at defaults, this exhausts database server limits quickly.

// Good: Configure Hikari properly per shard
spring.datasource.shard0.hikari.maximum-pool-size=10
spring.datasource.shard0.hikari.minimum-idle=2
spring.datasource.shard0.hikari.connection-timeout=30000
spring.datasource.shard0.hikari.idle-timeout=600000
spring.datasource.shard0.hikari.max-lifetime=1800000

Pitfall 4: Not Handling Shard Failures

Problem: No circuit breaker or retry logic for failing shards.

Solution: Use Resilience4j

@Configuration
public class ResilienceConfig {

    @Bean
    public CircuitBreakerRegistry circuitBreakerRegistry() {
        CircuitBreakerConfig config = CircuitBreakerConfig.custom()
            .failureRateThreshold(50)
            .waitDurationInOpenState(Duration.ofMillis(1000))
            .slidingWindowSize(10)
            .build();

        return CircuitBreakerRegistry.of(config);
    }
}

@Service
public class ResilientUserRepository {

    private final ShardedUserRepository delegate;
    private final CircuitBreakerRegistry circuitBreakerRegistry;

    public ResilientUserRepository(ShardedUserRepository delegate,
                                   CircuitBreakerRegistry circuitBreakerRegistry) {
        this.delegate = delegate;
        this.circuitBreakerRegistry = circuitBreakerRegistry;
    }

    public Optional<User> findById(Long id) {
        // One breaker per shard, so a failing shard trips only its own circuit
        String shardId = resolveShardId(id);
        CircuitBreaker circuitBreaker = circuitBreakerRegistry
            .circuitBreaker("shard_" + shardId);

        return circuitBreaker.executeSupplier(() -> delegate.findById(id));
    }
}

Pitfall 5: No Backup Strategy Per Shard

Problem: Treating all shards as one unit for backups.

Solution: Each shard needs its own backup strategy.

@Service
public class ShardBackupService {

    private static final Logger logger = LoggerFactory.getLogger(ShardBackupService.class);

    private final Map<String, DataSource> shardDataSources;
    private final ExecutorService executorService;

    public void backupAllShards() {
        List<Future<String>> futures = new ArrayList<>();

        for (Map.Entry<String, DataSource> entry : shardDataSources.entrySet()) {
            String shardId = entry.getKey();

            Future<String> future = executorService.submit(() -> {
                // Colons from ISO timestamps are illegal in filenames on some systems
                String timestamp = LocalDateTime.now()
                    .format(DateTimeFormatter.ofPattern("yyyyMMdd_HHmmss"));
                String backupFile = String.format("backup_%s_%s.sql", shardId, timestamp);

                // Execute pg_dump or similar
                ProcessBuilder pb = new ProcessBuilder(
                    "pg_dump",
                    "--host=localhost",
                    "--dbname=" + shardId,
                    "--file=" + backupFile
                );
                Process process = pb.start();
                if (process.waitFor() != 0) {
                    throw new IllegalStateException("pg_dump failed for shard " + shardId);
                }

                return backupFile;
            });

            futures.add(future);
        }

        // Wait for all backups to complete
        for (Future<String> future : futures) {
            try {
                String backupFile = future.get();
                logger.info("Backup completed: " + backupFile);
            } catch (Exception e) {
                logger.error("Backup failed", e);
            }
        }
    }
}

Conclusion

Database sharding with Spring Boot is a powerful technique for scaling beyond what a single database can handle, but it comes with significant complexity. The key takeaways:

  • Choose your shard key carefully based on your query patterns
  • Use consistent hashing to make resharding easier
  • Leverage Spring Boot's multi datasource capabilities
  • Design your schema to minimize cross shard queries
  • Monitor shard balance and performance with Actuator
  • Have a clear migration and rollback strategy
  • Accept eventual consistency where possible to avoid distributed transactions
  • Use connection pooling wisely per shard

Don't reach for sharding before you actually need it. When you do, understanding these fundamentals will help you build a scalable, maintainable Spring Boot application.

Remember: the best sharding strategy is one that aligns with your specific access patterns and business requirements. There's no one size fits all solution.

Now go build something that scales!
