Introduction
In the previous parts of this series, we built:
- Part 1: The Tensor class and computation graph
- Part 2: Automatic differentiation from scratch
- Part 3: A simple neural network trained on MNIST
In Part 3, we manually defined each weight in our SimpleNN class. This works for small networks, but imagine building a 50-layer model: you'd have to track every single parameter by hand!
In this post, we’ll build the foundation of nn.Module, a system to organize layers, manage parameters, and support training and evaluation modes. This is the core of every modern deep learning library.
Missed Part 1?
- Read it here: https://dev.to/zekcrates/lets-build-a-deep-learning-library-from-scratch-using-numpy-part-1-32p9
Want to skip the series and read the full book now?
- Read it for free online: https://zekcrates.quarto.pub/deep-learning-library/
Parameter class
A Parameter is just a Tensor that is marked as learnable. This makes it easy to distinguish weights from intermediate tensors.
In the Tensor class, requires_grad defaults to False.
from babygrad import Tensor
class Parameter(Tensor):
    def __init__(self, data, *args, **kwargs):
        kwargs['requires_grad'] = True
        super().__init__(data, *args, **kwargs)
# Example
a = Tensor([1, 2, 3])
print(a.requires_grad) # False
b = Parameter(a)
print(b.requires_grad) # True
Whenever you see self.weight = Parameter(...), you immediately know it’s a learnable parameter.
Finding Parameters
Now that we have the Parameter class, we need a way to gather all the parameters of a model.
A model might store parameters in attributes, lists, or dictionaries. To collect them automatically, we define _get_parameters().
def _get_parameters(data):
    params = []
    if isinstance(data, Parameter):
        return [data]
    # recurse into submodules so nested layers contribute their parameters
    if isinstance(data, Module):
        return _get_parameters(data.__dict__)
    if isinstance(data, dict):
        for value in data.values():
            params.extend(_get_parameters(value))
    if isinstance(data, (list, tuple)):
        for item in data:
            params.extend(_get_parameters(item))
    return params
This helper function will be used inside the Module class to collect all the parameters.
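As a quick sanity check (a hypothetical snippet, not part of the library itself), we can feed it a mix of containers and confirm that only Parameter instances come back:

params = _get_parameters({
    "w": Parameter([1.0, 2.0]),
    "layers": [Parameter([3.0]), Tensor([4.0])],  # a plain Tensor is skipped
})
print(len(params))  # 2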
Module Base Class
Every layer (Linear, ReLU, BatchNorm, ...) needs to:
- Manage parameters: Find all weights inside itself
- Define a forward pass: Process input data
- Track training state: Know if it's training or evaluating
from typing import List

class Module:
    def __init__(self):
        self.training = True

    def parameters(self) -> List[Parameter]:
        params = _get_parameters(self.__dict__)
        unique_params = []
        seen_ids = set()
        for p in params:
            if id(p) not in seen_ids:
                unique_params.append(p)
                seen_ids.add(id(p))
        return unique_params

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
Now, whenever we define a model, we can simply call model.parameters().
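For example, here is a minimal sketch of a module with a single learnable value (a hypothetical snippet; it assumes elementwise multiplication on Tensors from Part 2):

class Scale(Module):
    def __init__(self):
        super().__init__()
        self.scale = Parameter([2.0])  # one learnable value

    def forward(self, x):
        return x * self.scale

m = Scale()
print(len(m.parameters()))  # 1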
We also need a way to toggle self.training across the whole model. For that we add a helper, _get_modules(), that finds every Module inside the model so we can flip the flag on each one.
def _get_modules(obj) -> list['Module']:
    modules = []
    if isinstance(obj, Module):
        # include the module itself and everything nested inside it
        return [obj] + _get_modules(obj.__dict__)
    if isinstance(obj, dict):
        for value in obj.values():
            modules.extend(_get_modules(value))
    if isinstance(obj, (list, tuple)):
        for item in obj:
            modules.extend(_get_modules(item))
    return modules
class Module:
    # ... methods defined above ...

    def train(self):
        self.training = True
        for m in _get_modules(self.__dict__):
            m.training = True

    def eval(self):
        self.training = False
        for m in _get_modules(self.__dict__):
            m.training = False
This adds two new methods, train() and eval(), to the Module class.
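A quick illustration (hypothetical snippet, reusing the Scale module sketched earlier):

m = Scale()
print(m.training)  # True -- modules start in training mode
m.eval()
print(m.training)  # False
m.train()
print(m.training)  # True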
Now that our base class is done, we can finally build some real layers.
Stateless Layers: ReLU, Sigmoid, Tanh, Flatten
Some layers don’t have learnable parameters. They just apply a function to the input.
class ReLU(Module):
    def forward(self, x):
        return ops.relu(x)

class Sigmoid(Module):
    def forward(self, x):
        return ops.sigmoid(x)

class Tanh(Module):
    def forward(self, x):
        return ops.tanh(x)

class Flatten(Module):
    def forward(self, x):
        batch_size = x.shape[0]
        return x.reshape(batch_size, -1)
NOTE: the ops functions (ops.relu, ops.sigmoid, ...) were covered in Part 2.
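Using one of these layers looks like this (a hypothetical snippet; it assumes the ops functions from Part 2 are hooked into the Tensor graph):

relu = ReLU()
x = Tensor([-1.0, 0.0, 2.0])
print(relu(x))  # negatives are clamped to zero: [0.0, 0.0, 2.0]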
Linear Layer
The Linear layer is the workhorse of neural networks: it applies an affine transformation, out = x @ weight + bias.
It is stateful: it owns a weight and a bias that we need to learn.
class Linear(Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 dtype: str = "float32"):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(Tensor.randn(in_features, out_features))
        self.bias = None
        if bias:
            self.bias = Parameter(Tensor.zeros(1, out_features))

    def forward(self, x: Tensor) -> Tensor:
        # (bs, in) @ (in, out) -> (bs, out)
        out = x @ self.weight
        if self.bias is not None:
            # (1, out) -> (bs, out), broadcasted
            out += self.bias.broadcast_to(out.shape)
        return out
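A quick shape check (hypothetical snippet; Tensor.randn is the same constructor the layer itself uses):

layer = Linear(4, 2)
x = Tensor.randn(8, 4)          # a batch of 8 samples with 4 features each
y = layer(x)
print(y.shape)                  # (8, 2)
print(len(layer.parameters()))  # 2 -- the weight and the bias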
Sequential: Stacking Layers
If we have multiple modules, calling forward on each of them by hand quickly becomes tedious:
class MyModel(Module):
    def __init__(self):
        super().__init__()
        self.w1 = Linear(10, 20)
        self.w2 = Linear(20, 30)
        self.relu = ReLU()
        self.final = Linear(30, 10)

    def forward(self, x):
        x = self.w1(x)
        x = self.relu(x)
        x = self.w2(x)
        x = self.relu(x)
        x = self.final(x)
        return x
Sequential solves this problem by chaining modules together automatically: the output of one layer becomes the input to the next.
class Sequential(Module):
    def __init__(self, *modules):
        super().__init__()
        self.modules = modules

    def forward(self, x):
        for m in self.modules:
            x = m(x)
        return x
Now we can simply do:
model = Sequential(
    Linear(10, 20),
    ReLU(),
    Linear(20, 30),
    ReLU(),
    Linear(30, 10),
)
logits = model(x)
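And because _get_parameters recurses into submodules, collecting every learnable tensor in the stack is a one-liner:

print(len(model.parameters()))  # 6 -- a weight and a bias for each of the three Linear layers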
MSE Loss
class MSELoss(Module):
    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """
        Calculates the Mean Squared Error.
        """
        diff = pred - target
        sq_diff = diff * diff
        return sq_diff.sum() / Tensor(target.data.size)
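MSELoss is just the mean of the squared differences, built entirely out of Tensor ops, so gradients flow through it like any other module. A quick usage sketch (hypothetical values; assumes the backward() pass from Part 2):

loss_fn = MSELoss()
pred = Tensor([0.5, 0.8, 0.1], requires_grad=True)
target = Tensor([1.0, 0.0, 0.0])
loss = loss_fn(pred, target)   # mean((pred - target) ** 2)
loss.backward()                # gradients flow back into pred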
Conclusion
In this post, we built the core nn.Module abstraction that lets us define layers, manage parameters automatically, and compose models cleanly.
With this foundation in place, we can now focus on training instead of bookkeeping.
In the next post, we'll implement optimizers and use them to train models built with nn.Module.
More layers (BatchNorm, LayerNorm, Dropout) are covered in the book!
- Read it for free online: https://zekcrates.quarto.pub/deep-learning-library/