Dheeraj Ramasahayam

Building an AI Model That Predicts Datacenter Network Failure

Modern hyperscale datacenters generate enormous volumes of telemetry data every second. Detecting the subtle signals that precede catastrophic network failures is a major challenge for cloud reliability engineering.

In this post, I'll walk through how I engineered a custom Attention-Guided LSTM in PyTorch to predict physical hardware and BGP routing outages up to 60 seconds before they cascade across a datacenter.

The Problem: Network Outages vs. Static Thresholds

Predicting a network failure is notoriously difficult for classical machine learning and static alert systems:

  1. The Need for Sequences: Failures rarely happen instantly. They are the result of compounding micro-faults (soft failures) building up over time. Standard Random Forests evaluate each log in isolation, completely missing the temporal "story" of the degradation.
  2. Static Thresholds Fail: Traditional Z-score limits trigger excessive false positives because datacenter network traffic is inherently bursty.

We needed a model that could read a sliding window of time and understand the chronological context of the telemetry.

Telemetry Data

To ensure the model wasn't just memorizing one specific hardware topology, I validated the architecture against two massive, open-source datasets:

  • Gigabit Optical Failure Dataset: Captures granular microsecond readings of physical light-loss and optical transponder degradation.
  • Cisco BGP Telemetry Dataset: Contains 740,000 continuous milliseconds of abstract Layer-3 packet routing drops and BGP instability.

Handling Class Imbalance

Here lies the hardest part of AIOps: network faults are rare. In the Cisco dataset, only about 35,000 failure events were buried among 700,000 healthy logs. If you train a model naively on this data, it can predict "Healthy" 100% of the time and still reach roughly 95% validation accuracy while catching zero failures.
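
The arithmetic behind that accuracy figure is easy to check. The sketch below uses only the counts quoted above and scores a hypothetical model that always answers "Healthy"; it illustrates why accuracy is the wrong metric here and is not output from the actual pipeline.

```python
# Counts quoted in the post for the Cisco BGP dataset.
n_healthy = 700_000
n_failure = 35_000

# A degenerate model that always answers "Healthy".
true_neg = n_healthy   # every healthy log is classified correctly
false_neg = n_failure  # every failure is missed
true_pos = 0
false_pos = 0

accuracy = (true_pos + true_neg) / (n_healthy + n_failure)
recall = true_pos / (true_pos + false_neg)

print(f"accuracy = {accuracy:.3f}")  # accuracy = 0.952
print(f"recall   = {recall:.3f}")    # recall   = 0.000
```

High accuracy with zero recall is why the results later in the post are reported as F1 rather than accuracy.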

To force the neural network to care about the anomalies, I implemented two safeguards in the training pipeline:

1. Synthetic Minority Over-sampling (SMOTE)
I applied SMOTE to the 2D (samples by features) training matrix to synthesize additional failure examples, raising failures to roughly 20-30% of the training data so the model sees enough examples of what an outage actually looks like.

from imblearn.over_sampling import SMOTE

# Inflate minority failure signatures up to 30% of the majority class
smote = SMOTE(sampling_strategy=0.3, random_state=42)
X_train_resampled, y_train_resampled = smote.fit_resample(X_train_scaled, y_train_raw)

2. Dynamic Objective Weighting
In PyTorch, I configured BCEWithLogitsLoss with a pos_weight so that the model is penalized more heavily whenever it misses a real failure (a false negative).

import torch
import torch.nn as nn

# Ratio of healthy samples to failed samples in the training labels
num_neg = (y_train == 0).sum()
num_pos = (y_train == 1).sum()
pos_weight_val = float(num_neg) / float(num_pos) if num_pos > 0 else 1.0

# Scale the loss on positive (failure) samples so that a missed failure costs more
pos_weight = torch.tensor([pos_weight_val], dtype=torch.float32, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
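
To see what pos_weight does to an individual sample, the per-sample loss can be written out by hand. The sketch below reimplements the documented formula in plain Python instead of calling PyTorch. It assumes the weight is computed from the raw counts quoted earlier (700,000 healthy, 35,000 failed); if y_train is the SMOTE-resampled label set, the ratio would be closer to 3.3. The function name and the example logits are hypothetical.

```python
import math

def weighted_bce_with_logits(logit, target, pos_weight=1.0):
    """Per-sample loss: -[pos_weight * y * log(s(x)) + (1 - y) * log(1 - s(x))],
    where s is the sigmoid function."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(pos_weight * target * math.log(p)
             + (1.0 - target) * math.log(1.0 - p))

# Assumption: weight taken from the raw class counts quoted in the post.
pos_weight = 700_000 / 35_000  # 20.0

# Two equally confident mistakes, one on each class.
missed_failure = weighted_bce_with_logits(logit=-2.0, target=1.0, pos_weight=pos_weight)
false_alarm = weighted_bce_with_logits(logit=2.0, target=0.0, pos_weight=pos_weight)

print(f"missed failure loss: {missed_failure:.2f}")
print(f"false alarm loss:    {false_alarm:.2f}")
print(f"ratio:               {missed_failure / false_alarm:.1f}")
```

With this weight, a missed failure costs 20 times as much as an equally confident false alarm, which pushes the model toward recall at the expense of precision.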

The LSTM Architecture

To capture the temporal degradation, I used a Long Short-Term Memory (LSTM) backbone. Instead of feeding the model single logs, my preprocessing.py script chunks the telemetry into rolling 15-step sequence windows. The LSTM retains the long-term context of the latency spikes across those 15 steps.
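
The preprocessing.py script itself is not shown in the post, so the following is only a minimal sketch of the windowing idea. The function name make_windows and the labelling rule (each window takes the label of its final step) are assumptions made for illustration.

```python
import numpy as np

WINDOW = 15  # sequence length used in the post

def make_windows(features, labels, window=WINDOW):
    """Slice (T, F) telemetry into overlapping (N, window, F) sequences.

    Assumption: each window is labelled with the label of its final step.
    """
    n = len(features) - window + 1
    X = np.stack([features[i:i + window] for i in range(n)])
    y = labels[window - 1:]
    return X.astype(np.float32), y.astype(np.float32)

# Toy telemetry: 20 time steps, 2 features, with a failure at the last step.
features = np.arange(40, dtype=np.float32).reshape(20, 2)
labels = np.zeros(20)
labels[-1] = 1

X, y = make_windows(features, labels)
print(X.shape, y.shape)  # (6, 15, 2) (6,)
```

For early warning, the label could instead mark whether a failure occurs within some horizon after the window ends. Which rule the original script uses is not stated.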

The Attention Mechanism

Not every step in a 15-step crash sequence is equally important. To improve the model's accuracy, I layered an attention mechanism on top of the LSTM: a single learned linear layer scores each hidden state, and a softmax over the time axis turns those scores into weights.

Attention lets the network assign higher weights to the most critical steps of the sequence (the moments the router begins to choke) and down-weight ordinary background noise.
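
The weighting is easier to follow on a single window. The NumPy sketch below scores each hidden state with a vector, normalizes the scores with a softmax over time, and takes the weighted sum. The scoring vector w is a hypothetical stand-in for learned weights, and the hidden states are made up so that one step clearly stands out.

```python
import numpy as np

def attention_pool(hidden, w, b=0.0):
    """hidden: (T, H) hidden states; w: (H,) scoring vector.

    Returns the softmax weights over time (T,) and the pooled context (H,).
    """
    scores = hidden @ w + b                # one scalar score per time step
    scores = scores - scores.max()         # shift for numerical stability
    weights = np.exp(scores) / np.exp(scores).sum()
    context = weights @ hidden             # attention-weighted sum over time
    return weights, context

# Toy window: 15 steps, 4 hidden units; step 11 carries a strong signal.
hidden = np.zeros((15, 4))
hidden[11] = 3.0
w = np.ones(4)  # hypothetical stand-in for the learned scoring weights

weights, context = attention_pool(hidden, w)
print(weights.argmax(), weights.sum().round(6), context.shape)
```

Nearly all of the weight lands on step 11, so the context vector is dominated by that step's hidden state.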

Here is the entire PyTorch architecture:

import torch
import torch.nn as nn

class AttentionLSTM(nn.Module):
    def __init__(self, input_size, hidden_size=128, num_layers=2):
        super(AttentionLSTM, self).__init__()

        # 1. Sequence Processor
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, 
                            batch_first=True, dropout=0.3)

        # 2. Attention scorer: a linear layer that assigns one score per time step
        self.attention = nn.Linear(hidden_size, 1)

        # 3. Dense Classification Output
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # x: (batch, seq_len, input_size) -> lstm_out: (batch, seq_len, hidden_size)
        lstm_out, _ = self.lstm(x)

        # One score per time step, normalized over the time axis: (batch, seq_len, 1)
        attn_weights = torch.softmax(self.attention(lstm_out), dim=1)

        # Attention-weighted sum of the hidden states: (batch, hidden_size)
        context = torch.sum(attn_weights * lstm_out, dim=1)

        # Return raw logits of shape (batch, 1); BCEWithLogitsLoss applies the sigmoid itself
        out = self.fc(context)
        return out

The Training Pipeline

The execution script fires up the DataLoader, shuffles the sequence windows, and runs a standard Adam optimization loop. The model emits logits of shape (batch, 1), so the trailing dimension is squeezed to match the (batch,) label tensor that the BCE loss expects.

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

for epoch in range(12):
    model.train()
    total_loss = 0.0
    for batch_X, batch_y in train_loader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        optimizer.zero_grad()

        # (batch, 1) -> (batch,); squeeze(-1) stays correct for a batch of one
        outputs = model(batch_X).squeeze(-1)

        # BCEWithLogitsLoss expects float targets with the same shape as the logits
        batch_y = batch_y.float().view(-1)

        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean training loss per batch for this epoch
    print(f"epoch {epoch + 1}: mean loss {total_loss / len(train_loader):.4f}")

Results

Because predicting the exact microsecond of a network failure is impossible, I evaluated the model using an actionable 60-Second Window Metric. If the model triggered a critical alert within roughly 60 seconds of a true blackout, it was considered a successful early-warning interception.
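
The post does not spell out the exact matching rule, so the following is one plausible reading of the 60-second window metric, not the evaluation code from the repository. It assumes an alert counts as a hit when a failure follows it within 60 seconds, and a failure counts as caught when some alert precedes it by at most 60 seconds. The function name and the timestamps are hypothetical.

```python
HORIZON_S = 60.0  # early-warning window from the post, in seconds

def window_f1(alert_times, failure_times, horizon=HORIZON_S):
    """Precision, recall, and F1 under a lead-time window (times in seconds)."""
    def leads(alert, failure):
        # Assumption: the alert must come before the failure, by at most `horizon`.
        return 0.0 <= failure - alert <= horizon

    hit_alerts = sum(any(leads(a, f) for f in failure_times) for a in alert_times)
    caught = sum(any(leads(a, f) for a in alert_times) for f in failure_times)

    precision = hit_alerts / len(alert_times) if alert_times else 0.0
    recall = caught / len(failure_times) if failure_times else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1

# Hypothetical timestamps: two alerts lead a failure, one is a false alarm,
# and one failure arrives with no warning.
alerts = [100.0, 400.0, 905.0]
failures = [130.0, 950.0, 2000.0]

precision, recall, f1 = window_f1(alerts, failures)
print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
# precision=0.667 recall=0.667 f1=0.667
```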

In an ablation study, the plain LSTM reached an F1-score of 0.31. Adding the attention layer raised it to 0.46 on the heavily imbalanced 700k-row Cisco dataset, outperforming stateless Random Forests.

Detecting failures a full minute before equipment blackout suggests that proactive, predictive maintenance is feasible with temporal sequence learning. With 60 seconds of warning, automated scripts can drain optical queues and re-route BGP traffic, shifting site reliability from reactive incident management toward autonomous, proactive remediation.

GitHub Repository

You can review the full open-source Python implementation, reproduce the training environment, and read the IEEE-formatted research paper in my GitHub repository:

🔗 datacenter-network-failure-predictor
