This post is part of our Privacy-Preserving Machine Learning with AIJack series.
- Part 1: Federated Learning
- Part 2: Model Inversion Attack against Federated Learning
- Part 3: Federated Learning with Homomorphic Encryption
- Part 4: Federated Learning with Differential Privacy
- Part 5: Federated Learning with Sparse Gradient
- Part 6: Poisoning Attack against Federated Learning
- Part 7: Federated Learning with FoolsGold
- Part 8: Split Learning
- Part 9: Label Leakage against Split Learning
Overview
In this tutorial, we will learn about Federated Learning, a distributed learning algorithm that allows you to train a neural network while preserving privacy.
While deep learning achieves substantial success in various areas, training deep learning models requires a large amount of data, which often contains sensitive information. Thus, achieving high performance while preserving privacy is challenging. One way to address this problem is Federated Learning, where multiple clients collaboratively train a single global model without sharing their local datasets.
The procedure of a typical Federated Learning scheme is as follows:
1. The central server initializes the global model.
2. The server distributes the global model to each client.
3. Each client locally calculates the gradient of the loss function on its dataset.
4. Each client sends the gradient to the server.
5. The server aggregates the received gradients with some method (e.g., averaging) and updates the global model with the aggregated gradient.
6. Repeat steps 2 to 5 until convergence.
When the aggregation is the weighted average, the update rule can be written as

$$w_{t+1} = w_{t} - \eta \sum_{k=1}^{K} \frac{n_{k}}{N} \nabla \ell(w_{t}; D_{k})$$

, where $w_{t}$ is the parameter of the global model in the $t$-th round, $\nabla \ell(w_{t}; D_{k})$ is the gradient calculated on the $k$-th client's dataset $D_{k}$, $n_{k}$ is the number of samples in the $k$-th client's dataset, $K$ is the number of clients, and $N = \sum_{k=1}^{K} n_{k}$ is the total number of samples.
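To make the aggregation step concrete, here is a minimal NumPy sketch (a toy illustration, not part of AIJack): two clients holding 60 and 40 samples send their gradients, and the server updates the global parameters with the weighted average.

import numpy as np

lr = 0.1
w = np.zeros(3)                         # global model parameters w_t
grads = [np.array([1.0, 2.0, 3.0]),     # gradient from client 1
         np.array([3.0, 0.0, 1.0])]     # gradient from client 2
n = np.array([60, 40])                  # local dataset sizes n_k
N = n.sum()                             # total number of samples

agg = sum((n_k / N) * g for n_k, g in zip(n, grads))  # weighted average of gradients
w = w - lr * agg                        # w_{t+1} = w_t - eta * aggregated gradient
print(w)                                # [-0.18 -0.12 -0.22]

FedAVG performs exactly this kind of weighted averaging over the updates sent by each client.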
Code
Next, we will implement FedAVG [1], one of the most representative methods of Federated Learning. We use AIJack, an open-source library for simulating the security and privacy risks of machine learning algorithms. AIJack supports both a single process and MPI as its backend.
First, we install AIJack with pip.
apt install -y libboost-all-dev
pip install -U pip
pip install "pybind11[global]"
pip install aijack
Single-process
We import the following modules.
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from aijack.collaborative.fedavg import FedAVGClient, FedAVGServer, FedAVGAPI
The hyper-parameters are as follows.
training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
client_size = 2
criterion = F.nll_loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
This tutorial uses the MNIST dataset.
def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader
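Using this helper, we can build the local training dataloaders and the shared test dataloader referenced later in this section. This construction is a sketch added for completeness: the names local_dataloaders and test_dataloader are chosen to match the training call below, and passing c + 1 as myid mirrors the myid - 1 indexing inside prepare_dataloader.

local_dataloaders = [
    prepare_dataloader(client_size, c + 1, train=True) for c in range(client_size)
]
test_dataloader = prepare_dataloader(client_size, 0, train=False)  # myid is unused for the test split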
AIJack allows you to implement the clients and server of Federated Learning with PyTorch models.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output


clients = [FedAVGClient(Net().to(device), user_id=c) for c in range(client_size)]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

server = FedAVGServer(clients, Net().to(device))
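The training API below also takes a custom_action callback that is executed after each communication round. The following evaluation helper is a minimal sketch, assuming the callback receives the API object, as in the MPI example later in this post; here the callback simply closes over the server and device defined above (like its MPI counterpart, the server object can be called directly on a batch to run the current global model).

def evaluate_global_model(dataloader):
    # Returns a callback executed after each round. The `api` argument is
    # unused in this sketch; the closure uses `server` and `device` instead.
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(device), target.to(device)
                output = server(data)  # run the current global model
                test_loss += F.nll_loss(output, target, reduction="sum").item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(f"Test set: Average loss: {test_loss}, Accuracy: {accuracy}")

    return _evaluate_global_model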
Then, you can execute the training via the run method of FedAVGAPI.
api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=num_rounds,
    custom_action=evaluate_global_model(test_dataloader),
)
api.run()
MPI
You can easily convert the above code to MPI-compatible code that runs in a parallel programming environment.
# mpi_FedAVG.py
import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import (
    FedAVGClient,
    FedAVGServer,
    MPIFedAVGAPI,
    MPIFedAVGClientManager,
    MPIFedAVGServerManager,
)
logger = getLogger(__name__)
training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0
def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=False, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=False, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output
def evaluate_global_model(dataloader):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.device), target.to(api.device)
                output = api.party(data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(
            f"Round: {api.party.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model
def main():
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    mpi_client_manager = MPIFedAVGClientManager()
    mpi_server_manager = MPIFedAVGServerManager()
    MPIFedAVGClient = mpi_client_manager.attach(FedAVGClient)
    MPIFedAVGServer = mpi_server_manager.attach(FedAVGServer)

    if myid == 0:
        # Rank 0 acts as the central server and evaluates the global model.
        dataloader = prepare_dataloader(size - 1, myid, train=False)
        client_ids = list(range(1, size))
        server = MPIFedAVGServer(comm, client_ids, model)
        api = MPIFedAVGAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_rounds,
            1,
            custom_action=evaluate_global_model(dataloader),
            device=device,
        )
    else:
        # The remaining ranks act as clients, each holding a disjoint shard of MNIST.
        dataloader = prepare_dataloader(size - 1, myid, train=True)
        client = MPIFedAVGClient(comm, model, user_id=myid)
        api = MPIFedAVGAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            num_rounds,
            1,
            device=device,
        )
    api.run()


if __name__ == "__main__":
    main()
You can run the above code with the standard MPI command. Here, -np 3 launches three processes: rank 0 acts as the server and the remaining two ranks act as clients.
!mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG.py
Summary
In this tutorial, we learned about Federated Learning, a promising approach to training deep learning models without violating privacy. You can find more examples and notebooks in the AIJack documentation. Although this scheme seems safe since each client does not have to share its local dataset, the next tutorial demonstrates that the shared local gradients might leak private information.
Reference
[1] McMahan, Brendan, et al. "Communication-efficient learning of deep networks from decentralized data." Artificial intelligence and statistics. PMLR, 2017.