
Emirhan Akdeniz for SerpApi

Posted on • Originally published at serpapi.com

Spiking Neural Networks

In this blog post, we will discuss the differences between spiking and non-spiking neural networks and the potential use cases of these algorithms, and share an open-source example script that compares a simple SNN model to an ANN model.


What is a spiking neural network?

A spiking neural network (SNN) is a type of artificial neural network that more closely mimics the behavior of biological neurons. Unlike traditional artificial neural networks (ANNs), which use continuous activation functions, SNNs use discrete events called spikes. These spikes mark the times at which neurons fire and let the network process information in a way that is closer to how neurons in the brain communicate. This event-driven mechanism is what can give SNNs an advantage over ANNs in inference time on temporal data (data organized in time steps).

What is the difference between spiking and non-spiking neural networks?

The primary difference between spiking and non-spiking neural networks lies in how they handle information processing:

  • Spiking neural networks: Use spikes to represent information, with neurons firing only when their membrane potential reaches a certain threshold. These neuron models are often associated with neuromorphic computing and with learning rules such as spike-timing-dependent plasticity (STDP), although they can also be trained with backpropagation through surrogate gradients, which is what snntorch does in the script below.

  • Non-spiking neural networks (e.g. ANNs): Use continuous activation functions like ReLU or sigmoid to process information and typically use backpropagation for learning.
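
To make the contrast concrete, here is a minimal sketch in plain Python that does not come from any library. `relu` returns a continuous value for every input, while a hypothetical `lif_step` helper integrates its input into a membrane potential and emits a binary spike only when a threshold is crossed. The decay factor `beta`, the threshold of 1.0, and the reset-by-subtraction rule are illustrative assumptions.

```python
def relu(x):
    """Continuous activation: the output grows with the input."""
    return max(0.0, x)


def lif_step(current, mem, beta=0.9, threshold=1.0):
    """One step of a leaky integrate-and-fire neuron (illustrative only).

    The membrane potential decays by `beta`, integrates the input current,
    and is reduced by `threshold` whenever a spike is emitted.
    Returns (spike, new_membrane_potential).
    """
    mem = beta * mem + current
    spike = 1 if mem >= threshold else 0
    mem -= spike * threshold  # reset by subtraction
    return spike, mem


inputs = [0.3, 0.3, 0.3, 0.3, 0.0, 0.0, 0.9, 0.9]

# Non-spiking neuron: one continuous output per input
relu_out = [relu(x) for x in inputs]

# Spiking neuron: state is carried from step to step
mem = 0.0
spikes = []
for x in inputs:
    spk, mem = lif_step(x, mem)
    spikes.append(spk)

print("ReLU output:", relu_out)
print("LIF spikes :", spikes)
```

The spiking neuron stays silent while weak input accumulates and fires only once enough charge has built up, so the information is carried by when the spikes occur rather than by the magnitude of each output.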

How does a spiking neural network work?

SNNs work by simulating the behavior of biological neurons. When a neuron’s membrane potential exceeds a certain threshold, it generates a spike that propagates to other spiking neurons. These spikes can modify synaptic weights through learning rules such as spike-timing-dependent plasticity (STDP), enabling the network to learn from temporal patterns in the data.
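
As a rough illustration of the STDP idea, the sketch below implements a common pair-based form of the rule in plain Python: a synapse is strengthened when the presynaptic neuron fires shortly before the postsynaptic one and weakened when the order is reversed. The function name `stdp_delta_w`, the amplitudes `a_plus` and `a_minus`, and the time constant `tau` are assumptions chosen for the example, not values from this post or from snntorch.

```python
import math


def stdp_delta_w(t_pre, t_post, a_plus=0.01, a_minus=0.012, tau=20.0):
    """Pair-based STDP weight change for one pre/post spike pair (times in ms)."""
    dt = t_post - t_pre
    if dt > 0:  # pre fired before post: strengthen the synapse
        return a_plus * math.exp(-dt / tau)
    if dt < 0:  # post fired before pre: weaken the synapse
        return -a_minus * math.exp(dt / tau)
    return 0.0


w = 0.5  # initial synaptic weight
spike_pairs = [(10.0, 15.0), (30.0, 32.0), (50.0, 45.0)]  # (t_pre, t_post)

for t_pre, t_post in spike_pairs:
    dw = stdp_delta_w(t_pre, t_post)
    w = min(max(w + dw, 0.0), 1.0)  # keep the weight within [0, 1]
    print(f"pre={t_pre:5.1f} ms  post={t_post:5.1f} ms  dw={dw:+.5f}  w={w:.5f}")
```

The closer the two spikes are in time, the larger the change, which is how the rule picks up on temporal patterns.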

What are the most effective ways to use neural networks for pattern recognition?

For pattern recognition, deep learning models such as convolutional neural networks (CNNs) are highly effective. However, SNNs are gaining attention for their ability to recognize spatiotemporal patterns (patterns that span both space and time) with high precision and low power consumption.

What could it mean for the future of scraping?

In my humble opinion, SNNs could hold a place in finding patterns within changing and evolving HTML structures. Instead of classifying items and parsing them, SNNs may be useful for identifying where specific parts of the HTML are within the overall body. This could reduce human interaction and pave the way for the future of fully automated parsers with higher precision and lower inference times.

SNN vs ANN Comparison

The following is a simple demonstration script that compares SNN and ANN models under the same conditions. A disclaimer: this is for demonstration purposes only, not a proof or a definitive benchmark. As I repeat time and again, I am not an expert in machine learning; I am just an enthusiast.

Let's import the libraries. We will use NumPy to generate the data, PyTorch as the framework, scikit-learn for splitting the dataset, and snntorch for building SNN models in PyTorch:


import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import snntorch as snn
from sklearn.model_selection import train_test_split
import time


Let's create a function that simulates motion data (a kind of temporal data):


def generate_motion_data(num_samples, event_length, num_events, noise_level):
    X = []
    y = []
    for _ in range(num_samples):
        motion_indices = np.random.randint(0, event_length, size=num_events)
        event_data = np.zeros(event_length)
        event_data[motion_indices] = 1
        noise = np.random.normal(0, noise_level, size=event_length)
        event_data += noise

        # Introduce variability in the patterns
        if np.random.rand() < 0.5:
            event_data = np.roll(event_data, np.random.randint(1, event_length))

        X.append(event_data)
        y.append(1 if np.sum(event_data) > 0 else 0)
    return np.array(X), np.array(y)


Each sample is labeled with a binary value: 1 for motion and 0 for no motion, where a sample counts as motion when its summed signal is positive. We also add Gaussian noise with a standard deviation of noise_level to make the data more consistent with real-world data. Alongside the noise, we randomly shift about half of the samples in time to make the task harder. The model should be able to take all of these factors into account and predict whether motion occurs within the series.
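
Before training on synthetic data, it can help to check how the labels are distributed. The sketch below is a standard-library re-implementation of the labeling rule described above, written for this check and not part of the original script. The helper name `label_fraction` and the seed are assumptions, and the default parameters mirror the ones used later in this post.

```python
import random


def label_fraction(num_samples=1000, event_length=100, num_events=100,
                   noise_level=0.1, seed=0):
    """Fraction of samples labeled 1 under the rule 'sum of the signal > 0'."""
    rng = random.Random(seed)
    positives = 0
    for _ in range(num_samples):
        signal = [0.0] * event_length
        for _ in range(num_events):
            # Repeated indices overlap, as with np.random.randint
            signal[rng.randrange(event_length)] = 1.0
        # Add Gaussian noise to every time step and sum the series.
        # Rolling the series in time would not change this sum.
        total = sum(s + rng.gauss(0.0, noise_level) for s in signal)
        positives += 1 if total > 0 else 0
    return positives / num_samples


print("Fraction labeled 1 with 100 events:", label_fraction())
print("Fraction labeled 1 with 0 events  :", label_fraction(num_events=0))
```

With 100 events spread over 100 time steps, the summed signal is far above zero in practically every sample, so the first fraction should come out at or very near 1.0. A task with a more balanced label distribution would make any accuracy comparison between the two models more informative.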

Let's create our data:


# Parameters
num_samples = 1000
event_length = 100
num_events = 100
noise_level = 0.1

# Generate data
X, y = generate_motion_data(num_samples, event_length, num_events, noise_level)

# Convert to PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

# Split into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)


Let's define an SNN model and train it:


# Define SNN model
class SpikingNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SpikingNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.lif1 = snn.Leaky(beta=0.9)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

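    def forward(self, x):
        # Single time step: start from a fresh (zero) membrane potential
        mem1 = self.lif1.init_leaky()
        x = self.fc1(x)
        # snn.Leaky returns (spikes, membrane potential), in that order
        spk1, mem1 = self.lif1(x, mem1)
        x = self.fc2(spk1)
        return x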

# Model, loss function, and optimizer for SNN
input_dim = event_length
hidden_dim = 64
output_dim = 1  # Binary classification

snn_model = SpikingNN(input_dim, hidden_dim, output_dim)
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy loss
optimizer = optim.Adam(snn_model.parameters(), lr=0.001)

# Training loop for SNN
num_epochs = 100
snn_training_start = time.time()

for epoch in range(num_epochs):
    snn_model.train()
    optimizer.zero_grad()
    outputs = snn_model(X_train)
    loss = criterion(outputs.squeeze(), y_train)
    loss.backward()
    optimizer.step()

    # Calculate training loss
    train_loss = loss.item()

    # Validation
    snn_model.eval()
    with torch.no_grad():
        val_outputs = snn_model(X_val)
        val_loss = criterion(val_outputs.squeeze(), y_val)
        val_loss = val_loss.item()

    print(f'SNN Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

snn_training_time = time.time() - snn_training_start
print(f"SNN Training Time: {snn_training_time:.4f} seconds")


Let's create an ANN model and train it for comparison:


# Define ANN model

class ANN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ANN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Model, loss function, and optimizer for ANN
ann_model = ANN(input_dim, hidden_dim, output_dim)
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy loss
optimizer = optim.Adam(ann_model.parameters(), lr=0.001)

# Training loop for ANN
ann_training_start = time.time()

for epoch in range(num_epochs):
    ann_model.train()
    optimizer.zero_grad()
    outputs = ann_model(X_train)
    loss = criterion(outputs.squeeze(), y_train)
    loss.backward()
    optimizer.step()

    # Calculate training loss
    train_loss = loss.item()

    # Validation
    ann_model.eval()
    with torch.no_grad():
        val_outputs = ann_model(X_val)
        val_loss = criterion(val_outputs.squeeze(), y_val)
        val_loss = val_loss.item()

    print(f'ANN Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

ann_training_time = time.time() - ann_training_start
print(f"ANN Training Time: {ann_training_time:.4f} seconds")


Let's define a function to run predictions, calculate the inference time, and compare the two models:


# Function to predict and measure inference time

def predict_and_measure_time(model, new_data):
    start_time = time.time()
    model.eval()
    with torch.no_grad():
        new_data_tensor = torch.tensor(new_data, dtype=torch.float32)
        outputs = model(new_data_tensor)
    inference_time = time.time() - start_time
    return outputs, inference_time

# Generate new test data
X_test, y_test = generate_motion_data(5, event_length, num_events, noise_level)

# Predictions with SNN
snn_outputs, snn_inference_time = predict_and_measure_time(snn_model, X_test)
snn_predictions = torch.round(torch.sigmoid(snn_outputs)).squeeze().numpy()
print("SNN Predictions:", snn_predictions)
print(f"SNN Inference Time: {snn_inference_time:.4f} seconds")

# Predictions with ANN
ann_outputs, ann_inference_time = predict_and_measure_time(ann_model, X_test)
ann_predictions = torch.round(torch.sigmoid(ann_outputs)).squeeze().numpy()
print("ANN Predictions:", ann_predictions)
print(f"ANN Inference Time: {ann_inference_time:.4f} seconds")

# Comparison Summary
print(f"Comparison Summary:")
print(f"SNN Training Time: {snn_training_time:.4f} seconds")
print(f"ANN Training Time: {ann_training_time:.4f} seconds")
print(f"SNN Inference Time: {snn_inference_time:.4f} seconds")
print(f"ANN Inference Time: {ann_inference_time:.4f} seconds")

# Final validation accuracies (from the last epoch)
snn_model.eval()
with torch.no_grad():
    snn_val_outputs = snn_model(X_val)
    snn_val_accuracy = ((torch.sigmoid(snn_val_outputs) > 0.5).squeeze().float() == y_val).float().mean().item()

ann_model.eval()
with torch.no_grad():
    ann_val_outputs = ann_model(X_val)
    ann_val_accuracy = ((torch.sigmoid(ann_val_outputs) > 0.5).squeeze().float() == y_val).float().mean().item()

print(f"Final SNN Validation Accuracy: {snn_val_accuracy:.4f}")
print(f"Final ANN Validation Accuracy: {ann_val_accuracy:.4f}")


The following numbers are not definitive, as the task is simple, but they should give an idea of where an SNN can beat an ANN:

Comparison Summary:
SNN Training Time: 0.6785 seconds
ANN Training Time: 0.3952 seconds
SNN Inference Time: 0.0007 seconds
ANN Inference Time: 0.0017 seconds
Final SNN Validation Accuracy: 1.0000
Final ANN Validation Accuracy: 1.0000

In this run, training the SNN took longer, likely because the framework and the architecture add overhead during training. In the future, we may see different approaches that reduce the training time of SNN models. At inference, the SNN was faster than the ANN in this measurement, although a single sub-millisecond timing on five samples is noisy. This is where energy efficiency could come into play: if an SNN is cheaper to execute at the same accuracy, it consumes less CPU power.
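
Single timings this small are sensitive to caching and scheduler noise. One way to get a steadier number is to repeat the measurement and take the median, as in the standard-library sketch below. The helper name `median_runtime`, the repeat counts, and the stand-in `workload` function are assumptions made for illustration; in the script above, the workload would be a forward pass such as `snn_model(X_val)` inside `torch.no_grad()`.

```python
import statistics
import time


def median_runtime(fn, repeats=200, warmup=10):
    """Median wall-clock time of fn() in seconds over several runs."""
    for _ in range(warmup):  # warm-up runs are not recorded
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def workload():
    """Hypothetical stand-in for a model's forward pass."""
    return sum(i * i for i in range(10_000))


print(f"Median runtime: {median_runtime(workload):.6f} seconds")
```

`time.perf_counter` is used here because it is intended for measuring short durations, and the median is less affected by occasional slow runs than the mean.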

Are spiking neural networks the future?

Spiking neural networks (SNNs) are increasingly being considered as a promising frontier in the future of artificial intelligence, particularly due to their closer resemblance to the neural computation seen in biological brains. Leveraging principles from neuroscience, SNNs process information through spikes, or action potentials, which offers a unique form of temporal coding and rate coding.
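
The two coding schemes mentioned above can be sketched in a few lines of plain Python. In rate coding, a larger value produces more spikes over a time window; in temporal (latency) coding, a larger value produces an earlier spike. The function names, the 20-step window, and the linear latency mapping are assumptions chosen for illustration.

```python
import random


def rate_encode(value, num_steps, rng):
    """Rate coding: at each step, spike with probability `value` (0 to 1)."""
    return [1 if rng.random() < value else 0 for _ in range(num_steps)]


def latency_encode(value, num_steps):
    """Latency coding: one spike, placed earlier for larger `value` (0 to 1)."""
    spike_step = round((1.0 - value) * (num_steps - 1))
    return [1 if t == spike_step else 0 for t in range(num_steps)]


rng = random.Random(0)  # fixed seed so the run is repeatable
for value in (0.1, 0.9):
    print(f"value={value}")
    print("  rate   :", rate_encode(value, 20, rng))
    print("  latency:", latency_encode(value, 20))
```

Rate coding is robust to noise but needs many time steps to convey a value, while latency coding uses a single spike and therefore fewer events.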

In contrast to traditional deep neural networks, which rely on gradient descent and backpropagation, SNNs can use spike-based learning algorithms and synaptic plasticity, which can make them more efficient for certain types of neural computation. Setting up these spiking networks involves building multi-layer architectures capable of handling the temporal dynamics and correlations within spike trains.

One of the significant advantages of SNNs is their potential for lower energy consumption, especially when implemented on neuromorphic hardware. These processors mimic the brain’s architecture and function, enabling real-time processing with minimal latency. This is particularly beneficial in applications like robotics, computer vision, and large-scale network models, where real-time and efficient computations are crucial.

SNNs also offer improved interpretability compared to traditional deep-learning models. Each single neuron in an SNN can be examined for its specific role in the network, which aids in understanding how neural computations propagate through the system. Feedforward and recurrent neural networks can both be implemented within the SNN framework, providing versatility in handling different types of data and tasks.

Despite these advantages, SNNs face challenges in terms of learning algorithms and network models. The nonlinear nature of spike-based communication and the need for precise temporal synchronization complicate the development of effective supervised learning techniques. Additionally, the number of spikes and their timing (latency) play a crucial role in the plausibility and performance of SNNs.

Recent advances in state-of-the-art neuromorphic processors and spiking neuron models show promise for overcoming these hurdles. As research in neuroscience and artificial intelligence continues to converge, SNNs may become more viable for practical applications, enhancing the capabilities of both AI and computational neuroscience.

In summary, spiking neural networks hold significant potential for the future of AI, particularly in areas requiring efficient, real-time processing with low energy consumption. Their biologically inspired approach offers a plausible and powerful alternative to traditional deep learning, potentially revolutionizing fields such as robotics, computer vision, and beyond.

Conclusion

I’m grateful to the reader for their attention. I hope this blog post will help people understand the potential of spiking neural networks and their use cases. The reader may find the full comparison script below:

Full script:


!pip install torch snntorch scikit-learn

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import snntorch as snn
from sklearn.model_selection import train_test_split  
import time

# Generate synthetic event-based motion data
def generate_motion_data(num_samples, event_length, num_events, noise_level):
    X = []
    y = []
    for _ in range(num_samples):
        motion_indices = np.random.randint(0, event_length, size=num_events)
        event_data = np.zeros(event_length)
        event_data[motion_indices] = 1
        noise = np.random.normal(0, noise_level, size=event_length)
        event_data += noise

        # Introduce variability in the patterns
        if np.random.rand() < 0.5:
            event_data = np.roll(event_data, np.random.randint(1, event_length))

        X.append(event_data)
        y.append(1 if np.sum(event_data) > 0 else 0)
    return np.array(X), np.array(y)

# Parameters
num_samples = 1000
event_length = 100
num_events = 100
noise_level = 0.1

# Generate data
X, y = generate_motion_data(num_samples, event_length, num_events, noise_level)

# Convert to PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

# Split into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Define SNN model
class SpikingNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SpikingNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.lif1 = snn.Leaky(beta=0.9)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

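    def forward(self, x):
        # Single time step: start from a fresh (zero) membrane potential
        mem1 = self.lif1.init_leaky()
        x = self.fc1(x)
        # snn.Leaky returns (spikes, membrane potential), in that order
        spk1, mem1 = self.lif1(x, mem1)
        x = self.fc2(spk1)
        return x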

# Model, loss function, and optimizer for SNN
input_dim = event_length
hidden_dim = 64
output_dim = 1  # Binary classification

snn_model = SpikingNN(input_dim, hidden_dim, output_dim)
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy loss
optimizer = optim.Adam(snn_model.parameters(), lr=0.001)

# Training loop for SNN
num_epochs = 100
snn_training_start = time.time()

for epoch in range(num_epochs):
    snn_model.train()
    optimizer.zero_grad()
    outputs = snn_model(X_train)
    loss = criterion(outputs.squeeze(), y_train)
    loss.backward()
    optimizer.step()

    # Calculate training loss
    train_loss = loss.item()

    # Validation
    snn_model.eval()
    with torch.no_grad():
        val_outputs = snn_model(X_val)
        val_loss = criterion(val_outputs.squeeze(), y_val)
        val_loss = val_loss.item()

    print(f'SNN Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

snn_training_time = time.time() - snn_training_start
print(f"SNN Training Time: {snn_training_time:.4f} seconds")

# Define ANN model
class ANN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ANN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Model, loss function, and optimizer for ANN
ann_model = ANN(input_dim, hidden_dim, output_dim)
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy loss
optimizer = optim.Adam(ann_model.parameters(), lr=0.001)

# Training loop for ANN
ann_training_start = time.time()

for epoch in range(num_epochs):
    ann_model.train()
    optimizer.zero_grad()
    outputs = ann_model(X_train)
    loss = criterion(outputs.squeeze(), y_train)
    loss.backward()
    optimizer.step()

    # Calculate training loss
    train_loss = loss.item()

    # Validation
    ann_model.eval()
    with torch.no_grad():
        val_outputs = ann_model(X_val)
        val_loss = criterion(val_outputs.squeeze(), y_val)
        val_loss = val_loss.item()

    print(f'ANN Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

ann_training_time = time.time() - ann_training_start
print(f"ANN Training Time: {ann_training_time:.4f} seconds")

# Function to predict and measure inference time
def predict_and_measure_time(model, new_data):
    start_time = time.time()
    model.eval()
    with torch.no_grad():
        new_data_tensor = torch.tensor(new_data, dtype=torch.float32)
        outputs = model(new_data_tensor)
    inference_time = time.time() - start_time
    return outputs, inference_time

# Generate new test data
X_test, y_test = generate_motion_data(5, event_length, num_events, noise_level)

# Predictions with SNN
snn_outputs, snn_inference_time = predict_and_measure_time(snn_model, X_test)
snn_predictions = torch.round(torch.sigmoid(snn_outputs)).squeeze().numpy()
print("SNN Predictions:", snn_predictions)
print(f"SNN Inference Time: {snn_inference_time:.4f} seconds")

# Predictions with ANN
ann_outputs, ann_inference_time = predict_and_measure_time(ann_model, X_test)
ann_predictions = torch.round(torch.sigmoid(ann_outputs)).squeeze().numpy()
print("ANN Predictions:", ann_predictions)
print(f"ANN Inference Time: {ann_inference_time:.4f} seconds")

# Comparison Summary
print(f"Comparison Summary:")
print(f"SNN Training Time: {snn_training_time:.4f} seconds")
print(f"ANN Training Time: {ann_training_time:.4f} seconds")
print(f"SNN Inference Time: {snn_inference_time:.4f} seconds")
print(f"ANN Inference Time: {ann_inference_time:.4f} seconds")

# Final validation accuracies (from the last epoch)
snn_model.eval()
with torch.no_grad():
    snn_val_outputs = snn_model(X_val)
    snn_val_accuracy = ((torch.sigmoid(snn_val_outputs) > 0.5).squeeze().float() == y_val).float().mean().item()

ann_model.eval()
with torch.no_grad():
    ann_val_outputs = ann_model(X_val)
    ann_val_accuracy = ((torch.sigmoid(ann_val_outputs) > 0.5).squeeze().float() == y_val).float().mean().item()

print(f"Final SNN Validation Accuracy: {snn_val_accuracy:.4f}")
print(f"Final ANN Validation Accuracy: {ann_val_accuracy:.4f}")

