wellallyTech

Privacy-First AI: Training ECG Models with Federated Learning and Flower 🩺🛡️

In the era of data-driven medicine, we face a paradox: Artificial Intelligence needs massive datasets to save lives, but medical data, especially ECG (electrocardiogram) recordings, is incredibly sensitive. Sharing raw ECG recordings across hospitals isn't just a technical challenge; it's a privacy nightmare governed by strict regulations like HIPAA and GDPR.

Enter Federated Learning (FL). Instead of bringing data to the model, we bring the model to the data. In this tutorial, we will use the Flower (flwr) framework and PyTorch to train a collaborative model that identifies Premature Ventricular Contractions (PVCs) without a single byte of raw patient data ever leaving the local edge device. Only model weights travel between clients and server over gRPC, so we can build useful diagnostic tools while raw records stay private.

The Architecture: How Federated Learning Works

In a typical centralized setup, you upload everything to a cloud server. In Federated Learning, the "Global Model" lives on a central server, but the training happens on "Clients" (like hospital servers or even wearable devices).

sequenceDiagram
    participant S as Federated Server (Global)
    participant C1 as Hospital A (Edge Client)
    participant C2 as Hospital B (Edge Client)

    Note over S: Initialize Global Weights
    S->>C1: Send Global Weights (W1)
    S->>C2: Send Global Weights (W1)

    Note over C1: Local Training on Private ECGs
    Note over C2: Local Training on Private ECGs

    C1->>S: Send Weight Updates (ΔW_A)
    C2->>S: Send Weight Updates (ΔW_B)

    Note over S: Aggregate Updates (FedAvg)
    Note over S: Update Global Weights (W2)
    S->>C1: Send Refined Model (W2)
    S->>C2: Send Refined Model (W2)

Prerequisites

To follow along with this advanced-level guide, you'll need:

  • Tech Stack: Python 3.9+, PyTorch (for the neural network), and Flower (flwr) for the federation logic.
  • Knowledge: Basic understanding of 1D Convolutional Neural Networks (CNNs).

Step 1: Defining the Heart (The PyTorch Model)

We need a model that can process 1D signal data. A small CNN is a good fit for picking up the abnormal beat shape of a PVC.

import torch
import torch.nn as nn
import torch.nn.functional as F

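class ECGNet(nn.Module):
    """Small 1D CNN for single-lead ECG segments of 185 samples."""

    def __init__(self):
        super().__init__()
        # Input shape: (batch, 1, 185). No padding, so the conv output length is 185 - 5 + 1 = 181.
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5)
        # Pooling by 2 floors 181 / 2 down to 90 time steps.
        self.pool = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(16 * 90, 128)  # 16 channels x 90 steps; valid only for 185-length ECG segments
        self.fc2 = nn.Linear(128, 2)  # Normal vs PVC

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = torch.flatten(x, 1)  # Keep the batch dimension, flatten channels and time
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # Raw logits; CrossEntropyLoss applies the softmax itself
        return x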

# Initialize model, criterion, and optimizer
net = ECGNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
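Where does 16 * 90 come from? Here's a quick sanity check in plain Python. It's a sketch that assumes the layer settings used in ECGNet (kernel size 5, no padding, pooling by 2) and applies the output-length formula from the PyTorch docs. The helper names conv1d_out_len and ecgnet_flat_features are made up for this illustration.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1D convolution or pooling layer."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def ecgnet_flat_features(segment_length, channels=16):
    """Number of features that reach fc1 for a given segment length."""
    # nn.Conv1d(1, 16, kernel_size=5): stride 1, no padding
    after_conv = conv1d_out_len(segment_length, kernel_size=5)
    # nn.MaxPool1d(2): the stride defaults to the kernel size
    after_pool = conv1d_out_len(after_conv, kernel_size=2, stride=2)
    return channels * after_pool


print(ecgnet_flat_features(185))  # 1440, i.e. 16 * 90
```

If your segments have a different length, fc1 has to change with them. For example, 187-sample segments would give 16 * 91 = 1456 input features instead.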

Step 2: Creating the Flower Client

The Flower Client is the brains of the operation at the edge. It runs the local training loop and makes sure that only model weights, never raw ECG data, are sent back over gRPC.

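import flwr as fl
import torch
from collections import OrderedDict

# Assumes net, criterion, and optimizer from Step 1 are in scope, and that
# local_train_loader / local_test_loader are DataLoaders over this site's private ECGs.

class ECGClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        # Export the weights as a list of NumPy arrays, in state_dict order
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def set_parameters(self, parameters):
        # Load the server's weights back into the local model, key by key
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        net.train()
        # --- Local Training Loop Start ---
        # (Imagine loading local_train_loader here)
        for epoch in range(3):
            for signals, labels in local_train_loader:
                optimizer.zero_grad()
                criterion(net(signals), labels).backward()
                optimizer.step()
        # --- Local Training Loop End ---
        # The example count tells the server how much weight this client gets in the average
        return self.get_parameters(config={}), len(local_train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        net.eval()
        # (Imagine loading local_test_loader here)
        loss, correct = 0.0, 0
        with torch.no_grad():
            for signals, labels in local_test_loader:
                outputs = net(signals)
                # criterion returns a batch mean, so scale it by the batch size
                loss += criterion(outputs, labels).item() * labels.size(0)
                correct += (outputs.argmax(dim=1) == labels).sum().item()
        num_examples = len(local_test_loader.dataset)
        return loss / num_examples, num_examples, {"accuracy": correct / num_examples}

if __name__ == "__main__":
    # Assumption: the Step 3 server is running on this machine, port 8080.
    # This is the legacy entry point that pairs with start_server; depending on
    # your installed Flower version it may be deprecated.
    fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=ECGClient())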
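The client code in Step 2 takes local_train_loader and local_test_loader for granted. If you just want to smoke-test the pipeline before plugging in real recordings, a stand-in built from random tensors could look like the sketch below. To be clear, this is synthetic noise, not ECG data, and the helper make_synthetic_loader, the batch size, and the dataset sizes are all assumptions made up for this example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

SEGMENT_LENGTH = 185  # matches the input length ECGNet assumes in Step 1


def make_synthetic_loader(num_examples, batch_size=32, seed=0):
    """Build a DataLoader of random stand-in segments (NOT real ECG data)."""
    generator = torch.Generator().manual_seed(seed)
    # Shape (N, 1, 185): one channel per segment, as nn.Conv1d expects
    signals = torch.randn(num_examples, 1, SEGMENT_LENGTH, generator=generator)
    # Random binary labels: 0 = normal, 1 = PVC
    labels = torch.randint(0, 2, (num_examples,), generator=generator)
    return DataLoader(TensorDataset(signals, labels), batch_size=batch_size, shuffle=True)


local_train_loader = make_synthetic_loader(num_examples=256, seed=0)
local_test_loader = make_synthetic_loader(num_examples=64, seed=1)
```

A model trained on this will hover around chance accuracy, which is expected. The point is only to confirm that shapes, dtypes, and the federation plumbing line up.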

Step 3: Launching the Federated Strategy

On the server side, we use a strategy like FedAvg (Federated Averaging). The server collects weights from the participating clients, averages them (weighting each client by its number of training examples), and pushes the new global brain back to the devices.
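To make the averaging step concrete, here is a simplified NumPy sketch of the idea. It is not Flower's internal code. The function fedavg and the two toy "hospital" results are invented for illustration, and each model is shrunk to a single tiny layer.

```python
import numpy as np


def fedavg(results):
    """Weighted average of client weights.

    results: list of (weights, num_examples) pairs, where weights is a list
    of NumPy arrays, one per layer, like the first value fit() returns.
    """
    total_examples = sum(num_examples for _, num_examples in results)
    num_layers = len(results[0][0])
    averaged = []
    for layer in range(num_layers):
        # Each client's layer counts in proportion to its share of the data
        layer_sum = sum(weights[layer] * num_examples for weights, num_examples in results)
        averaged.append(layer_sum / total_examples)
    return averaged


# Two hypothetical hospitals with toy one-layer "models"
hospital_a = ([np.array([1.0, 2.0])], 100)
hospital_b = ([np.array([3.0, 6.0])], 300)
global_weights = fedavg([hospital_a, hospital_b])
print(global_weights[0])  # values 2.5 and 5.0
```

Hospital B holds three times as many examples, so the result sits closer to its weights. In Flower, the FedAvg strategy does this aggregation for you.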

import flwr as fl

# Define strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    min_fit_clients=2, 
    min_available_clients=2,
)

# Start Flower server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)
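One gap worth knowing about: each client's evaluate() returns its own accuracy, but FedAvg only combines those custom metrics if you give it a function to do so. A minimal sketch of such a function is below. The name weighted_average and the sample numbers are assumptions for this example; the input shape (a list of (num_examples, metrics) pairs) follows what Flower passes to evaluate_metrics_aggregation_fn.

```python
from typing import Dict, List, Tuple


def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Combine per-client accuracies, weighting each by its example count."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    weighted_correct = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_correct / total_examples}


# Hypothetical results from two clients: (num_examples, metrics)
sample = [(100, {"accuracy": 0.90}), (300, {"accuracy": 0.80})]
print(weighted_average(sample))  # accuracy is about 0.825
```

To use it, you would pass evaluate_metrics_aggregation_fn=weighted_average when constructing the FedAvg strategy.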

The "Official" Way: Leveling Up Your Privacy Stack 🚀

While this tutorial covers the basic mechanics of Federated Learning using Flower, shared weight updates can still leak information about the data they were trained on. Production environments therefore require more robust handling of differential privacy, secure multi-party computation, and model versioning.

For more production-ready examples and advanced patterns regarding secure data orchestration and AI privacy, I highly recommend checking out the WellAlly Tech Blog. They dive deep into how these technologies are applied in high-stakes industries where data security is non-negotiable.

Conclusion

Federated Learning represents a seismic shift in how we approach Machine Learning. By using Flower and PyTorch, we’ve shown how to set up training for an ECG diagnostic model without raw patient data ever leaving the site that owns it.

Key Takeaways:

  1. Data Sovereignty: The data stays where it's generated.
  2. Reduced Latency: Models are updated and run on the edge, so predictions don't need a round trip to the cloud.
  3. Collaboration: Competitors (or different hospitals) can train shared models for the greater good without sharing trade secrets or private data.

Are you ready to move your training to the edge? Let me know in the comments if you’ve tried implementing FL in your own projects! 🥑💻


Happy coding! If you enjoyed this "Learning in Public" session, don't forget to ❤️ and Bookmark!