Sharding is one of those topics that sounds intimidating until you understand what's happening under the hood. If you're a senior Java developer encountering sharding for the first time, or if you need a refresher on implementing it with Spring Boot, this guide covers everything from basic concepts to production-ready implementations.
Table of Contents
- What is Database Sharding?
- Why Shard?
- Core Sharding Concepts
- Hash Function Deep Dive
- Spring Boot Implementation Patterns
- Query Patterns and Challenges
- Transactions Across Shards
- Monitoring and Observability with Spring Boot Actuator
- Rebalancing and Resharding
- Real World Example: Complete User Service with Spring Boot
- Common Pitfalls and Solutions
- Conclusion
What is Database Sharding?
Database sharding is a horizontal partitioning strategy where you split your data across multiple database instances (shards). Instead of having one massive database trying to handle all your traffic and data, you distribute the load across multiple smaller databases.
Think of it like this: instead of one gigantic library holding every book ever written, you have multiple libraries where books are distributed based on some criteria (maybe by author's last name, subject, or publication date).
Why Shard?
Before diving into implementation, let's talk about when you actually need sharding:
Vertical scaling hits a wall. You can only add so much RAM, CPU, and faster disks to a single database server before you hit physical or cost limitations.
Query performance degrades. Even with proper indexing, queries slow down as your dataset grows. A query on a 10TB table is inherently slower than the same query on a 100GB table.
Write throughput becomes a bottleneck. A single database can only handle so many writes per second. When your application needs to process millions of writes per minute, one database won't cut it.
High availability requirements. Sharding can improve availability because a failure in one shard doesn't take down your entire system.
Core Sharding Concepts
The Shard Key
The shard key is the field you use to determine which shard holds a particular piece of data. This is absolutely critical to get right because changing it later is extremely painful.
Common shard key choices:
- User ID
- Tenant ID (for multi-tenant applications)
- Geographic region
- Timestamp ranges (for time series data)
The shard key should have these properties:
- High cardinality: Many unique values to distribute data evenly
- Query friendly: Your most common queries should be able to target a single shard
- Stable: The value shouldn't change frequently (or at all)
Sharding Strategies
There are several ways to distribute data across shards:
1. Range Based Sharding
You divide data based on ranges of the shard key.
Shard 1: user_id 1 to 1,000,000
Shard 2: user_id 1,000,001 to 2,000,000
Shard 3: user_id 2,000,001 to 3,000,000
Pros: Simple to implement, range queries are efficient
Cons: Can lead to hotspots if data isn't uniformly distributed
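The range lookup itself is easy to sketch with a `TreeMap`: `floorEntry` finds the greatest range start at or below the ID. Shard names and boundaries below are illustrative, mirroring the ranges above:

```java
import java.util.Map;
import java.util.TreeMap;

public class RangeShardResolver {
    // Each entry maps the lower bound of a user_id range to a shard name
    private final TreeMap<Long, String> ranges = new TreeMap<>();

    public RangeShardResolver() {
        ranges.put(1L, "shard_1");          // user_id 1 to 1,000,000
        ranges.put(1_000_001L, "shard_2");  // user_id 1,000,001 to 2,000,000
        ranges.put(2_000_001L, "shard_3");  // user_id 2,000,001 and up
    }

    public String resolve(long userId) {
        // floorEntry returns the entry with the greatest range start <= userId
        Map.Entry<Long, String> entry = ranges.floorEntry(userId);
        if (entry == null) {
            throw new IllegalArgumentException("user_id below first range: " + userId);
        }
        return entry.getValue();
    }
}
```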
2. Hash Based Sharding
You apply a hash function to the shard key and use the result to determine the shard.
public int getShard(Long userId, int numShards) {
    // Math.floorMod avoids the classic Math.abs(hash) % n pitfall:
    // Math.abs(Integer.MIN_VALUE) is still negative
    return Math.floorMod(userId.hashCode(), numShards);
}
Pros: Excellent data distribution
Cons: Range queries require hitting all shards
3. Directory Based Sharding
You maintain a lookup table that maps shard keys to specific shards.
user_id -> shard mapping stored in a lookup service
1234 -> shard_2
5678 -> shard_1
9012 -> shard_3
Pros: Flexible, can rebalance without changing keys
Cons: Lookup table becomes a single point of failure, adds latency
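A minimal sketch of the directory approach, with a plain concurrent map standing in for the lookup service (in production the mapping would live in a replicated store and be cached locally):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class DirectoryShardResolver {
    // Stand-in for an external lookup service
    private final Map<Long, String> directory = new ConcurrentHashMap<>();

    public void assign(Long userId, String shardId) {
        directory.put(userId, shardId);
    }

    public String resolve(Long userId) {
        String shardId = directory.get(userId);
        if (shardId == null) {
            throw new IllegalStateException("No shard mapping for user " + userId);
        }
        return shardId;
    }

    // Rebalancing is just an update to the mapping; no hash function changes
    public void move(Long userId, String newShardId) {
        directory.put(userId, newShardId);
    }
}
```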
4. Geographic Sharding
Data is sharded based on geographic location.
Users in North America -> shard_us
Users in Europe -> shard_eu
Users in Asia -> shard_asia
Pros: Reduces latency for geo distributed users, helps with data residency compliance
Cons: Can lead to uneven distribution
Hash Function Deep Dive
Let's get into the details of how hash based sharding actually works in Java.
Choosing a Hash Function
You need a hash function that:
- Distributes keys uniformly across shards
- Is fast to compute
- Has minimal collisions
- Is deterministic (same input always gives same output)
Common choices:
MD5: Widely available and gives good distribution, though as a cryptographic hash it's slower than purpose-built alternatives like MurmurHash
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;

public class ShardingUtils {

    public static int md5HashShard(String key, int numShards) {
        try {
            MessageDigest md = MessageDigest.getInstance("MD5");
            byte[] messageDigest = md.digest(key.getBytes(StandardCharsets.UTF_8));
            BigInteger bigInt = new BigInteger(1, messageDigest);
            // BigInteger.mod always returns a non-negative result, unlike
            // Math.abs(bigInt.intValue()) % n, which stays negative for Integer.MIN_VALUE
            return bigInt.mod(BigInteger.valueOf(numShards)).intValue();
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("MD5 algorithm not found", e);
        }
    }
}
MurmurHash: Faster than MD5, excellent distribution
import com.google.common.hash.Hashing;
import java.nio.charset.StandardCharsets;
public class ShardingUtils {

    public static int murmurHashShard(String key, int numShards) {
        int hash = Hashing.murmur3_32_fixed()
                .hashString(key, StandardCharsets.UTF_8)
                .asInt();
        // floorMod keeps the result non-negative even when hash is Integer.MIN_VALUE
        return Math.floorMod(hash, numShards);
    }
}
CRC32: Very fast, good enough for most use cases
import java.util.zip.CRC32;
import java.nio.charset.StandardCharsets;
public class ShardingUtils {

    public static int crc32HashShard(String key, int numShards) {
        CRC32 crc = new CRC32();
        crc.update(key.getBytes(StandardCharsets.UTF_8));
        // CRC32.getValue() returns an unsigned 32-bit value in a long,
        // so it is already non-negative and Math.abs is unnecessary
        return (int) (crc.getValue() % numShards);
    }
}
The Modulo Problem
The simple modulo approach has a major flaw: when you add or remove shards, almost all keys get reassigned to different shards.
// With 3 shards
hash(123) % 3 = 0 // Goes to shard 0
// Add a shard, now 4 shards
hash(123) % 4 = 3 // Now goes to shard 3!
This means massive data migration when scaling.
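You can quantify the churn with a standalone count of keys whose assignment changes when going from 3 to 4 shards under plain modulo — roughly three quarters of them move:

```java
public class ModuloRemapDemo {
    // Count keys whose shard assignment changes when the shard count changes
    public static int movedCount(long totalKeys, int oldShards, int newShards) {
        int moved = 0;
        for (long key = 0; key < totalKeys; key++) {
            if (key % oldShards != key % newShards) {
                moved++;
            }
        }
        return moved;
    }

    public static void main(String[] args) {
        long total = 100_000;
        int moved = movedCount(total, 3, 4);
        System.out.printf("%.1f%% of keys change shards going from 3 to 4%n",
                100.0 * moved / total);
    }
}
```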
Consistent Hashing to the Rescue
Consistent hashing solves the resharding problem by minimizing the number of keys that need to move when shards are added or removed.
Here's how it works conceptually:
- Imagine a ring (circle) with values from 0 to 2^32
- Each shard is assigned multiple points on this ring (virtual nodes)
- Each key is hashed to a point on the ring
- The key belongs to the first shard found moving clockwise from the key's position
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;
public class ConsistentHash<T> {
private final int virtualNodes;
private final TreeMap<Long, T> ring;
private final MessageDigest md;
public ConsistentHash(int virtualNodes) {
this.virtualNodes = virtualNodes;
this.ring = new TreeMap<>();
try {
this.md = MessageDigest.getInstance("MD5");
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("MD5 algorithm not found", e);
}
}
public void addNode(T node) {
for (int i = 0; i < virtualNodes; i++) {
String virtualKey = node.toString() + ":" + i;
long hash = hash(virtualKey);
ring.put(hash, node);
}
}
public void removeNode(T node) {
for (int i = 0; i < virtualNodes; i++) {
String virtualKey = node.toString() + ":" + i;
long hash = hash(virtualKey);
ring.remove(hash);
}
}
public T getNode(String key) {
if (ring.isEmpty()) {
return null;
}
long hash = hash(key);
// Find first node with hash >= key hash
Map.Entry<Long, T> entry = ring.ceilingEntry(hash);
// If not found, wrap around to first node
if (entry == null) {
entry = ring.firstEntry();
}
return entry.getValue();
}
    // MessageDigest is stateful and not thread-safe, so this method is synchronized;
    // for heavy concurrency, use a per-call or thread-local digest instead
    private synchronized long hash(String key) {
md.reset();
md.update(key.getBytes());
byte[] digest = md.digest();
// Take first 8 bytes to create a long
long hash = 0;
for (int i = 0; i < 8; i++) {
hash = (hash << 8) | (digest[i] & 0xFF);
}
return hash;
}
public int size() {
return ring.size() / virtualNodes;
}
}
// Usage
ConsistentHash<String> ch = new ConsistentHash<>(150);
ch.addNode("shard_1");
ch.addNode("shard_2");
ch.addNode("shard_3");
String shard = ch.getNode("user_12345");
System.out.println("Key belongs to: " + shard);
With consistent hashing, adding or removing a shard only remaps roughly 1/N of the keys (where N is the number of shards), instead of reshuffling nearly all of them as plain modulo does.
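That claim is easy to check empirically with a compact, self-contained ring. CRC32 stands in for MD5 here to keep the sketch short, and the shard names are illustrative; adding a fourth shard to three should remap only around a quarter of the keys, versus roughly 75% for plain modulo:

```java
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.zip.CRC32;

public class RingMovementDemo {
    static long hash(String s) {
        CRC32 crc = new CRC32();
        crc.update(s.getBytes(StandardCharsets.UTF_8));
        return crc.getValue();
    }

    // Build a ring with the given number of virtual nodes per shard
    static TreeMap<Long, String> buildRing(List<String> shards, int virtualNodes) {
        TreeMap<Long, String> ring = new TreeMap<>();
        for (String shard : shards) {
            for (int i = 0; i < virtualNodes; i++) {
                ring.put(hash(shard + ":" + i), shard);
            }
        }
        return ring;
    }

    static String locate(TreeMap<Long, String> ring, String key) {
        Map.Entry<Long, String> entry = ring.ceilingEntry(hash(key));
        return (entry != null ? entry : ring.firstEntry()).getValue();
    }

    // Fraction of keys that move when a fourth shard joins three existing ones
    public static double movedFraction(int totalKeys) {
        TreeMap<Long, String> before =
                buildRing(List.of("shard_1", "shard_2", "shard_3"), 150);
        TreeMap<Long, String> after =
                buildRing(List.of("shard_1", "shard_2", "shard_3", "shard_4"), 150);
        int moved = 0;
        for (int k = 0; k < totalKeys; k++) {
            String key = "user_" + k;
            if (!locate(before, key).equals(locate(after, key))) {
                moved++;
            }
        }
        return (double) moved / totalKeys;
    }

    public static void main(String[] args) {
        System.out.printf("moved %.1f%% of keys%n", 100 * movedFraction(50_000));
    }
}
```

The moved fraction hovers around 25% rather than hitting it exactly, because with 150 virtual nodes per shard the new shard's share of the ring varies a few percentage points either way.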
Spring Boot Implementation Patterns
Setting Up Multiple DataSources
First, let's configure multiple datasources in Spring Boot:
# application.yml
spring:
datasource:
shard0:
jdbc-url: jdbc:postgresql://db1.example.com:5432/shard_0
username: app_user
password: secret
driver-class-name: org.postgresql.Driver
hikari:
maximum-pool-size: 10
minimum-idle: 2
shard1:
jdbc-url: jdbc:postgresql://db2.example.com:5432/shard_1
username: app_user
password: secret
driver-class-name: org.postgresql.Driver
hikari:
maximum-pool-size: 10
minimum-idle: 2
shard2:
jdbc-url: jdbc:postgresql://db3.example.com:5432/shard_2
username: app_user
password: secret
driver-class-name: org.postgresql.Driver
hikari:
maximum-pool-size: 10
minimum-idle: 2
DataSource Configuration
import com.zaxxer.hikari.HikariDataSource;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.jdbc.DataSourceBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import javax.sql.DataSource;
import java.util.HashMap;
import java.util.Map;
@Configuration
public class DataSourceConfig {
@Bean
@ConfigurationProperties("spring.datasource.shard0")
public DataSource shard0DataSource() {
return DataSourceBuilder.create()
.type(HikariDataSource.class)
.build();
}
@Bean
@ConfigurationProperties("spring.datasource.shard1")
public DataSource shard1DataSource() {
return DataSourceBuilder.create()
.type(HikariDataSource.class)
.build();
}
@Bean
@ConfigurationProperties("spring.datasource.shard2")
public DataSource shard2DataSource() {
return DataSourceBuilder.create()
.type(HikariDataSource.class)
.build();
}
@Bean
public Map<String, DataSource> shardDataSources(
DataSource shard0DataSource,
DataSource shard1DataSource,
DataSource shard2DataSource) {
Map<String, DataSource> dataSources = new HashMap<>();
dataSources.put("shard_0", shard0DataSource);
dataSources.put("shard_1", shard1DataSource);
dataSources.put("shard_2", shard2DataSource);
return dataSources;
}
}
Shard Resolver Service
import org.springframework.stereotype.Service;
@Service
public class ShardResolver {
private final ConsistentHash<String> consistentHash;
public ShardResolver() {
this.consistentHash = new ConsistentHash<>(150);
consistentHash.addNode("shard_0");
consistentHash.addNode("shard_1");
consistentHash.addNode("shard_2");
}
public String resolveShardId(Long userId) {
return consistentHash.getNode(userId.toString());
}
public String resolveShardId(String key) {
return consistentHash.getNode(key);
}
}
Sharded JDBC Template
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;
import javax.sql.DataSource;
import java.util.HashMap;
import java.util.Map;
@Component
public class ShardedJdbcTemplate {
private final Map<String, JdbcTemplate> shardTemplates;
private final ShardResolver shardResolver;
public ShardedJdbcTemplate(
Map<String, DataSource> shardDataSources,
ShardResolver shardResolver) {
this.shardResolver = shardResolver;
this.shardTemplates = new HashMap<>();
shardDataSources.forEach((shardId, dataSource) -> {
this.shardTemplates.put(shardId, new JdbcTemplate(dataSource));
});
}
public JdbcTemplate getTemplate(Long userId) {
String shardId = shardResolver.resolveShardId(userId);
return shardTemplates.get(shardId);
}
public JdbcTemplate getTemplate(String shardId) {
return shardTemplates.get(shardId);
}
public Map<String, JdbcTemplate> getAllTemplates() {
return new HashMap<>(shardTemplates);
}
}
Query Patterns and Challenges
Single Shard Queries
When your query includes the shard key, you can route directly to one shard.
@Service
public class UserService {
private final ShardedJdbcTemplate shardedJdbcTemplate;
public UserService(ShardedJdbcTemplate shardedJdbcTemplate) {
this.shardedJdbcTemplate = shardedJdbcTemplate;
}
public User getUserById(Long userId) {
JdbcTemplate template = shardedJdbcTemplate.getTemplate(userId);
return template.queryForObject(
"SELECT id, email, name, created_at FROM users WHERE id = ?",
new Object[]{userId},
(rs, rowNum) -> new User(
rs.getLong("id"),
rs.getString("email"),
rs.getString("name"),
rs.getTimestamp("created_at")
)
);
}
public void createUser(User user) {
JdbcTemplate template = shardedJdbcTemplate.getTemplate(user.getId());
template.update(
"INSERT INTO users (id, email, name, created_at) VALUES (?, ?, ?, ?)",
user.getId(),
user.getEmail(),
user.getName(),
new Timestamp(System.currentTimeMillis())
);
}
}
Cross Shard Queries (Scatter Gather)
When your query doesn't include the shard key, you need to query all shards and merge results.
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.ArrayList;
import java.util.concurrent.*;
@Service
public class UserService {
private final ShardedJdbcTemplate shardedJdbcTemplate;
private final ExecutorService executorService;
public UserService(ShardedJdbcTemplate shardedJdbcTemplate) {
this.shardedJdbcTemplate = shardedJdbcTemplate;
this.executorService = Executors.newFixedThreadPool(10);
}
public List<User> findUsersByEmail(String email) {
Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
List<Future<List<User>>> futures = new ArrayList<>();
// Query all shards in parallel
for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
Future<List<User>> future = executorService.submit(() -> {
return entry.getValue().query(
"SELECT id, email, name, created_at FROM users WHERE email = ?",
new Object[]{email},
(rs, rowNum) -> new User(
rs.getLong("id"),
rs.getString("email"),
rs.getString("name"),
rs.getTimestamp("created_at")
)
);
});
futures.add(future);
}
// Collect results from all shards
List<User> results = new ArrayList<>();
for (Future<List<User>> future : futures) {
try {
results.addAll(future.get(5, TimeUnit.SECONDS));
} catch (InterruptedException | ExecutionException | TimeoutException e) {
throw new RuntimeException("Error executing scatter gather query", e);
}
}
return results;
}
}
Joins Across Shards
This is where sharding gets painful. Joins across shards are expensive and complex.
Solution 1: Denormalize
Store redundant data to avoid cross shard joins.
-- Instead of joining users and orders across shards
-- Store user info in the orders table
CREATE TABLE orders (
id BIGINT,
user_id BIGINT,
user_name VARCHAR(255), -- Denormalized
user_email VARCHAR(255), -- Denormalized
total DECIMAL(10,2)
)
Solution 2: Application Level Joins
Fetch from multiple shards and join in memory.
public class OrderWithUser {
private Order order;
private User user;
// getters and setters
}
@Service
public class OrderService {
private final ShardedJdbcTemplate shardedJdbcTemplate;
private final UserService userService;
public OrderWithUser getOrderWithUser(Long orderId, Long userId) {
// Get order from its shard
JdbcTemplate orderTemplate = shardedJdbcTemplate.getTemplate(orderId);
Order order = orderTemplate.queryForObject(
"SELECT * FROM orders WHERE id = ?",
new Object[]{orderId},
(rs, rowNum) -> mapOrder(rs)
);
// Get user from their shard
User user = userService.getUserById(userId);
OrderWithUser result = new OrderWithUser();
result.setOrder(order);
result.setUser(user);
return result;
}
}
Solution 3: Shard by Relationship
Co-locate related data on the same shard.
// Users and their orders always on the same shard
// Use userId as shard key for both tables
@Service
public class OrderService {
public List<OrderWithUser> getUserOrders(Long userId) {
JdbcTemplate template = shardedJdbcTemplate.getTemplate(userId);
// Both tables on the same shard, joins work normally
return template.query(
"SELECT u.*, o.* FROM users u " +
"JOIN orders o ON u.id = o.user_id " +
"WHERE u.id = ?",
new Object[]{userId},
(rs, rowNum) -> mapOrderWithUser(rs)
);
}
}
Transactions Across Shards
Distributed transactions are complex and slow. True ACID guarantees across shards require a two-phase commit (2PC) protocol, typically via XA transactions and a JTA transaction manager. The example below is a simplified best-effort approximation built from per-shard local transactions; it is not real 2PC, because a crash between the commit calls can still leave shards inconsistent.
A Best-Effort Two-Phase Commit with Spring
import org.springframework.stereotype.Service;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import java.util.*;
@Service
public class DistributedTransactionService {
private final Map<String, PlatformTransactionManager> transactionManagers;
private final ShardedJdbcTemplate shardedJdbcTemplate;
public DistributedTransactionService(
Map<String, PlatformTransactionManager> transactionManagers,
ShardedJdbcTemplate shardedJdbcTemplate) {
this.transactionManagers = transactionManagers;
this.shardedJdbcTemplate = shardedJdbcTemplate;
}
public void executeTwoPhaseCommit(Map<String, Runnable> shardOperations) {
Map<String, TransactionStatus> transactions = new HashMap<>();
try {
// Phase 1: Prepare (execute all operations)
for (Map.Entry<String, Runnable> entry : shardOperations.entrySet()) {
String shardId = entry.getKey();
PlatformTransactionManager tm = transactionManagers.get(shardId);
TransactionStatus status = tm.getTransaction(
new DefaultTransactionDefinition()
);
transactions.put(shardId, status);
try {
entry.getValue().run();
} catch (Exception e) {
throw new RuntimeException("Failed to execute on shard: " + shardId, e);
}
}
// Phase 2: Commit all
for (Map.Entry<String, TransactionStatus> entry : transactions.entrySet()) {
String shardId = entry.getKey();
PlatformTransactionManager tm = transactionManagers.get(shardId);
tm.commit(entry.getValue());
}
} catch (Exception e) {
// Rollback all prepared transactions
for (Map.Entry<String, TransactionStatus> entry : transactions.entrySet()) {
try {
String shardId = entry.getKey();
PlatformTransactionManager tm = transactionManagers.get(shardId);
tm.rollback(entry.getValue());
} catch (Exception rollbackException) {
// Log rollback failure
}
}
throw new RuntimeException("Distributed transaction failed", e);
}
}
}
Problems with 2PC:
- Slow (multiple network round trips)
- Coordinator failure can leave system in limbo
- Locks held for long duration
Better Approach: Saga Pattern
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import java.math.BigDecimal;
import java.util.*;
@Service
public class SagaTransactionService {
private final ShardedJdbcTemplate shardedJdbcTemplate;
private final SagaLogRepository sagaLogRepository;
public void transferMoney(Long fromUserId, Long toUserId, BigDecimal amount) {
String sagaId = UUID.randomUUID().toString();
try {
// Step 1: Debit from user
JdbcTemplate fromTemplate = shardedJdbcTemplate.getTemplate(fromUserId);
fromTemplate.update(
"UPDATE accounts SET balance = balance - ? WHERE user_id = ?",
amount, fromUserId
);
sagaLogRepository.logStep(sagaId, "debit_complete", fromUserId);
// Step 2: Credit to user
JdbcTemplate toTemplate = shardedJdbcTemplate.getTemplate(toUserId);
toTemplate.update(
"UPDATE accounts SET balance = balance + ? WHERE user_id = ?",
amount, toUserId
);
sagaLogRepository.logStep(sagaId, "credit_complete", toUserId);
sagaLogRepository.markComplete(sagaId);
} catch (Exception e) {
compensateTransfer(sagaId, fromUserId, toUserId, amount);
throw new RuntimeException("Transfer failed, compensation executed", e);
}
}
private void compensateTransfer(
String sagaId,
Long fromUserId,
Long toUserId,
BigDecimal amount) {
List<SagaStep> completedSteps = sagaLogRepository.getCompletedSteps(sagaId);
for (SagaStep step : completedSteps) {
if ("debit_complete".equals(step.getStepName())) {
// Reverse the debit
JdbcTemplate template = shardedJdbcTemplate.getTemplate(fromUserId);
template.update(
"UPDATE accounts SET balance = balance + ? WHERE user_id = ?",
amount, fromUserId
);
}
}
sagaLogRepository.markCompensated(sagaId);
}
}
Monitoring and Observability with Spring Boot Actuator
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tags;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
@Component
public class ShardMetricsCollector {
private final ShardedJdbcTemplate shardedJdbcTemplate;
private final MeterRegistry meterRegistry;
public ShardMetricsCollector(
ShardedJdbcTemplate shardedJdbcTemplate,
MeterRegistry meterRegistry) {
this.shardedJdbcTemplate = shardedJdbcTemplate;
this.meterRegistry = meterRegistry;
}
@Scheduled(fixedRate = 60000) // every minute; requires @EnableScheduling on a configuration class
public void collectShardMetrics() {
Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
String shardId = entry.getKey();
JdbcTemplate template = entry.getValue();
// Collect row count
Long rowCount = template.queryForObject(
"SELECT COUNT(*) FROM users",
Long.class
);
meterRegistry.gauge(
"shard.row.count",
Tags.of("shard", shardId),
rowCount != null ? rowCount : 0
);
// Collect table size
Long tableSize = template.queryForObject(
"SELECT pg_total_relation_size('users')",
Long.class
);
meterRegistry.gauge(
"shard.table.size.bytes",
Tags.of("shard", shardId),
tableSize != null ? tableSize : 0
);
}
}
public Map<String, ShardMetrics> getShardMetrics() {
Map<String, ShardMetrics> metrics = new HashMap<>();
Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
String shardId = entry.getKey();
JdbcTemplate template = entry.getValue();
ShardMetrics shardMetrics = new ShardMetrics();
shardMetrics.setShardId(shardId);
shardMetrics.setRowCount(
template.queryForObject("SELECT COUNT(*) FROM users", Long.class)
);
metrics.put(shardId, shardMetrics);
}
return metrics;
}
}
@Data
class ShardMetrics {
private String shardId;
private Long rowCount;
private Long tableSizeBytes;
private Double queriesPerSecond;
private Double avgQueryTimeMs;
}
Health Check Endpoint
import org.springframework.boot.actuate.health.Health;
import org.springframework.boot.actuate.health.HealthIndicator;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
@Component
public class ShardHealthIndicator implements HealthIndicator {
private final ShardedJdbcTemplate shardedJdbcTemplate;
public ShardHealthIndicator(ShardedJdbcTemplate shardedJdbcTemplate) {
this.shardedJdbcTemplate = shardedJdbcTemplate;
}
@Override
public Health health() {
Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
Map<String, String> shardStatus = new HashMap<>();
boolean allHealthy = true;
for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
String shardId = entry.getKey();
try {
entry.getValue().queryForObject("SELECT 1", Integer.class);
shardStatus.put(shardId, "UP");
} catch (Exception e) {
shardStatus.put(shardId, "DOWN");
allHealthy = false;
}
}
if (allHealthy) {
return Health.up().withDetails(shardStatus).build();
} else {
return Health.down().withDetails(shardStatus).build();
}
}
}
Rebalancing and Resharding
Eventually, you'll need to add more shards or rebalance data.
Adding New Shards
@Service
public class ShardManagementService {
private final ConsistentHash<String> consistentHash;
private final Map<String, DataSource> shardDataSources;
private final ShardedJdbcTemplate shardedJdbcTemplate;
public void addShard(String newShardId, DataSource newDataSource) {
// Add to consistent hash ring
consistentHash.addNode(newShardId);
// Add to datasources
shardDataSources.put(newShardId, newDataSource);
// Migrate affected keys
migrateKeysToNewShard(newShardId);
}
private void migrateKeysToNewShard(String newShardId) {
Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
String oldShardId = entry.getKey();
if (oldShardId.equals(newShardId)) {
continue;
}
JdbcTemplate oldTemplate = entry.getValue();
// Get all users from the old shard (a real migration should stream or batch
// rows instead of loading the whole table into memory)
List<User> users = oldTemplate.query(
"SELECT id, email, name, created_at FROM users",
(rs, rowNum) -> new User(
rs.getLong("id"),
rs.getString("email"),
rs.getString("name"),
rs.getTimestamp("created_at")
)
);
JdbcTemplate newTemplate = shardedJdbcTemplate.getTemplate(newShardId);
for (User user : users) {
// Check if user should be on new shard
String targetShard = consistentHash.getNode(user.getId().toString());
if (targetShard.equals(newShardId)) {
// Move to new shard
newTemplate.update(
"INSERT INTO users (id, email, name, created_at) VALUES (?, ?, ?, ?)",
user.getId(),
user.getEmail(),
user.getName(),
user.getCreatedAt()
);
// Delete from old shard
oldTemplate.update(
"DELETE FROM users WHERE id = ?",
user.getId()
);
}
}
}
}
}
Zero Downtime Migration with Dual Writes
@Service
public class DualWriteService {
    private static final Logger logger = LoggerFactory.getLogger(DualWriteService.class);
    private final ShardedJdbcTemplate shardedJdbcTemplate;
    // ConcurrentHashSet does not exist in the JDK;
    // ConcurrentHashMap.newKeySet() provides an equivalent concurrent Set
    private final Set<String> dualWriteShards = ConcurrentHashMap.newKeySet();
public void enableDualWrite(String oldShardId, String newShardId) {
dualWriteShards.add(oldShardId + ":" + newShardId);
}
public void createUser(User user) {
String primaryShard = resolvePrimaryShard(user.getId());
JdbcTemplate primaryTemplate = shardedJdbcTemplate.getTemplate(primaryShard);
// Write to primary shard
primaryTemplate.update(
"INSERT INTO users (id, email, name, created_at) VALUES (?, ?, ?, ?)",
user.getId(), user.getEmail(), user.getName(), user.getCreatedAt()
);
// Check if dual write is enabled
String dualWriteKey = findDualWriteConfig(primaryShard);
if (dualWriteKey != null) {
String secondaryShard = extractSecondaryShard(dualWriteKey);
JdbcTemplate secondaryTemplate = shardedJdbcTemplate.getTemplate(secondaryShard);
// Async write to secondary shard
CompletableFuture.runAsync(() -> {
try {
secondaryTemplate.update(
"INSERT INTO users (id, email, name, created_at) VALUES (?, ?, ?, ?)",
user.getId(), user.getEmail(), user.getName(), user.getCreatedAt()
);
} catch (Exception e) {
// Log error but don't fail the primary write
logger.error("Dual write failed for user: " + user.getId(), e);
}
});
}
}
    // resolvePrimaryShard, findDualWriteConfig, and extractSecondaryShard are
    // illustrative helpers omitted for brevity
}
Real World Example: Complete User Service with Spring Boot
// Domain Model
@Data
@AllArgsConstructor
@NoArgsConstructor
public class User {
private Long id;
private String email;
private String name;
private Timestamp createdAt;
}
// Repository Interface
public interface UserRepository {
void save(User user);
Optional<User> findById(Long id);
List<User> findByEmail(String email);
boolean update(User user);
void delete(Long id);
Map<String, Long> countPerShard();
}
// Sharded Repository Implementation
@Repository
public class ShardedUserRepository implements UserRepository {
private final ShardedJdbcTemplate shardedJdbcTemplate;
private final ExecutorService executorService;
public ShardedUserRepository(ShardedJdbcTemplate shardedJdbcTemplate) {
this.shardedJdbcTemplate = shardedJdbcTemplate;
this.executorService = Executors.newFixedThreadPool(10);
}
@Override
public void save(User user) {
JdbcTemplate template = shardedJdbcTemplate.getTemplate(user.getId());
template.update(
"INSERT INTO users (id, email, name, created_at) VALUES (?, ?, ?, ?)",
user.getId(),
user.getEmail(),
user.getName(),
new Timestamp(System.currentTimeMillis())
);
}
@Override
public Optional<User> findById(Long id) {
JdbcTemplate template = shardedJdbcTemplate.getTemplate(id);
try {
User user = template.queryForObject(
"SELECT id, email, name, created_at FROM users WHERE id = ?",
new Object[]{id},
(rs, rowNum) -> new User(
rs.getLong("id"),
rs.getString("email"),
rs.getString("name"),
rs.getTimestamp("created_at")
)
);
return Optional.ofNullable(user);
} catch (EmptyResultDataAccessException e) {
return Optional.empty();
}
}
@Override
public List<User> findByEmail(String email) {
Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
List<Future<List<User>>> futures = new ArrayList<>();
// Query all shards in parallel
for (JdbcTemplate template : templates.values()) {
Future<List<User>> future = executorService.submit(() ->
template.query(
"SELECT id, email, name, created_at FROM users WHERE email = ?",
new Object[]{email},
(rs, rowNum) -> new User(
rs.getLong("id"),
rs.getString("email"),
rs.getString("name"),
rs.getTimestamp("created_at")
)
)
);
futures.add(future);
}
// Collect results
List<User> results = new ArrayList<>();
for (Future<List<User>> future : futures) {
try {
results.addAll(future.get(5, TimeUnit.SECONDS));
} catch (Exception e) {
throw new RuntimeException("Error in scatter gather query", e);
}
}
return results;
}
@Override
public boolean update(User user) {
JdbcTemplate template = shardedJdbcTemplate.getTemplate(user.getId());
int rowsAffected = template.update(
"UPDATE users SET email = ?, name = ? WHERE id = ?",
user.getEmail(),
user.getName(),
user.getId()
);
return rowsAffected > 0;
}
@Override
public void delete(Long id) {
JdbcTemplate template = shardedJdbcTemplate.getTemplate(id);
template.update("DELETE FROM users WHERE id = ?", id);
}
@Override
public Map<String, Long> countPerShard() {
Map<String, JdbcTemplate> templates = shardedJdbcTemplate.getAllTemplates();
Map<String, Long> counts = new HashMap<>();
for (Map.Entry<String, JdbcTemplate> entry : templates.entrySet()) {
Long count = entry.getValue().queryForObject(
"SELECT COUNT(*) FROM users",
Long.class
);
counts.put(entry.getKey(), count != null ? count : 0L);
}
return counts;
}
}
// Service Layer
@Service
public class UserService {
private final UserRepository userRepository;
public UserService(UserRepository userRepository) {
this.userRepository = userRepository;
}
public void createUser(Long id, String email, String name) {
User user = new User(id, email, name, null);
userRepository.save(user);
}
public User getUser(Long id) {
return userRepository.findById(id)
.orElseThrow(() -> new UserNotFoundException("User not found: " + id));
}
public List<User> findByEmail(String email) {
return userRepository.findByEmail(email);
}
public void updateUserName(Long id, String newName) {
User user = getUser(id);
user.setName(newName);
if (!userRepository.update(user)) {
throw new RuntimeException("Failed to update user: " + id);
}
}
public Map<String, Long> getShardDistribution() {
return userRepository.countPerShard();
}
}
// REST Controller
@RestController
@RequestMapping("/api/users")
public class UserController {
private final UserService userService;
public UserController(UserService userService) {
this.userService = userService;
}
@PostMapping
public ResponseEntity<Void> createUser(@RequestBody CreateUserRequest request) {
userService.createUser(request.getId(), request.getEmail(), request.getName());
return ResponseEntity.status(HttpStatus.CREATED).build();
}
@GetMapping("/{id}")
public ResponseEntity<User> getUser(@PathVariable Long id) {
User user = userService.getUser(id);
return ResponseEntity.ok(user);
}
@GetMapping("/search")
public ResponseEntity<List<User>> searchByEmail(@RequestParam String email) {
List<User> users = userService.findByEmail(email);
return ResponseEntity.ok(users);
}
@PutMapping("/{id}")
public ResponseEntity<Void> updateUser(
@PathVariable Long id,
@RequestBody UpdateUserRequest request) {
userService.updateUserName(id, request.getName());
return ResponseEntity.ok().build();
}
@GetMapping("/metrics/distribution")
public ResponseEntity<Map<String, Long>> getDistribution() {
Map<String, Long> distribution = userService.getShardDistribution();
return ResponseEntity.ok(distribution);
}
}
// DTOs
@Data
class CreateUserRequest {
private Long id;
private String email;
private String name;
}
@Data
class UpdateUserRequest {
private String name;
}
Common Pitfalls and Solutions
Pitfall 1: Wrong Shard Key Choice
Problem: Choosing a low cardinality key or one that creates hotspots.
// Bad: Status field as shard key
int shard = hash(orderStatus) % numShards;
// Most orders are "completed", creating a hot shard
Solution: Choose high cardinality keys that distribute evenly.
// Good: User ID or Order ID
int shard = hash(userId) % numShards;
Pitfall 2: Not Planning for Growth
Problem: Hardcoding number of shards makes it hard to scale.
// Bad: Hardcoded modulo
int shard = hash(userId) % 4; // What happens when you need 8 shards?
Solution: Use consistent hashing from day one.
Pitfall 3: Connection Pool Exhaustion
Problem: Not properly configuring connection pools per shard.
# Good: Configure Hikari properly per shard
spring.datasource.shard0.hikari.maximum-pool-size=10
spring.datasource.shard0.hikari.minimum-idle=2
spring.datasource.shard0.hikari.connection-timeout=30000
spring.datasource.shard0.hikari.idle-timeout=600000
spring.datasource.shard0.hikari.max-lifetime=1800000
Pitfall 4: Not Handling Shard Failures
Problem: No circuit breaker or retry logic for failing shards.
Solution: Use Resilience4j
@Configuration
public class ResilienceConfig {
@Bean
public CircuitBreakerRegistry circuitBreakerRegistry() {
CircuitBreakerConfig config = CircuitBreakerConfig.custom()
.failureRateThreshold(50)
.waitDurationInOpenState(Duration.ofMillis(1000))
.slidingWindowSize(10)
.build();
return CircuitBreakerRegistry.of(config);
}
}
@Service
public class ResilientUserRepository {
private final ShardedUserRepository delegate;
private final CircuitBreakerRegistry circuitBreakerRegistry;
public Optional<User> findById(Long id) {
String shardId = resolveShardId(id);
CircuitBreaker circuitBreaker = circuitBreakerRegistry
.circuitBreaker("shard_" + shardId);
return circuitBreaker.executeSupplier(() -> delegate.findById(id));
}
}
Pitfall 5: No Backup Strategy Per Shard
Problem: Treating all shards as one unit for backups.
Solution: Each shard needs its own backup strategy.
@Service
public class ShardBackupService {
private final Map<String, DataSource> shardDataSources;
private final ExecutorService executorService;
public void backupAllShards() {
List<Future<String>> futures = new ArrayList<>();
for (Map.Entry<String, DataSource> entry : shardDataSources.entrySet()) {
String shardId = entry.getKey();
Future<String> future = executorService.submit(() -> {
String timestamp = LocalDateTime.now()
.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME);
String backupFile = String.format("backup_%s_%s.sql", shardId, timestamp);
// Execute pg_dump or similar
ProcessBuilder pb = new ProcessBuilder(
"pg_dump",
"--host=localhost",
"--dbname=" + shardId,
"--file=" + backupFile
);
Process process = pb.start();
process.waitFor();
return backupFile;
});
futures.add(future);
}
// Wait for all backups to complete
for (Future<String> future : futures) {
try {
String backupFile = future.get();
logger.info("Backup completed: " + backupFile);
} catch (Exception e) {
logger.error("Backup failed", e);
}
}
}
}
Conclusion
Database sharding with Spring Boot is a powerful technique for scaling beyond what a single database can handle, but it comes with significant complexity. The key takeaways:
- Choose your shard key carefully based on your query patterns
- Use consistent hashing to make resharding easier
- Leverage Spring Boot's multi datasource capabilities
- Design your schema to minimize cross shard queries
- Monitor shard balance and performance with Actuator
- Have a clear migration and rollback strategy
- Accept eventual consistency where possible to avoid distributed transactions
- Use connection pooling wisely per shard
Sharding is not something to implement until you need it, but when you do need it, understanding these fundamentals will help you build a scalable, maintainable Spring Boot application.
Remember: the best sharding strategy is one that aligns with your specific access patterns and business requirements. There's no one size fits all solution.
Now go build something that scales!