Introduction
In Part 5, we failed to correctly load weights into our model.
In this part, I will be talking about what we need to write to train our model, as if we were doing it from scratch. Eventually we will get the weights loaded, which will make this less time-consuming, but for now, we need a training loop.
The Training Loop
Here is the code I want to be able to write to actually train our model.
let model = Resnet34Model::<2, f32>::build(dev.clone());
// model.download_model();
let mut learner = VisualLearner::builder(dev.clone())
    .dataset(dataset)
    .model(model)
    .build();
learner.train(10)?;
This looks pretty close to the Python code that we're probably pretty familiar with at this point.
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2, seed=42,
    label_func=is_cat, item_tfms=Resize(224))
learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
But since this is Rust, I decided to use the builder pattern to construct the learner. This lets us specify required and optional arguments in a type-safe and easy-to-read way.
So how do we go about creating the learner? Let's dig into the code.
I first created a new module in tardyai/src/learners/visual.rs.
pub struct VisualLearner<'a> {
    device: AutoDevice,
    dataset: DirectoryImageDataset<'a>,
    model: Resnet34Model<2, f32>,
    optimizer: Adam<Resnet34Built<2, f32>, f32, AutoDevice>,
}
I'm hard-coding all the model specifics here for now; we can generalize this later. But for now, we have the device, the dataset, the model, and something called an optimizer. The optimizer takes the model and calculates the changes to the weights that need to happen in order to get closer to the desired output. In this code we're just going to use the Adam optimizer, because that is what the fast.ai book does.
So let's look at how we construct one of these. We can't just build it up from scratch in our main function; that wouldn't look good, and we don't necessarily want to construct all the options if we're just going to use the defaults. For now, I'm forcing the optimizer to be a default value, but we will be changing that in the future.
impl<'a> VisualLearner<'a> {
    pub fn builder(device: AutoDevice) -> builder::Builder<'a, builder::WithoutDataset> {
        builder::Builder {
            device,
            dataset: None,
            model: None,
            _phantom: PhantomData,
        }
    }
}
That creates a Builder with some None values and a PhantomData marker. Let's look at how Builder is defined to see what's going on. I always like to put my builders inside their own internal module, in order to separate out the namespace and keep the implementation cleaner.
pub mod builder {
    use super::*;
    use crate::{datasets::DirectoryImageDataset, models::resnet::Resnet34Model};

    // 2 ZSTs
    pub struct WithoutDataset;
    pub struct WithoutModel;
    pub struct Ready;

    pub struct Builder<'a, T> {
        pub(super) device: AutoDevice,
        pub(super) dataset: Option<DirectoryImageDataset<'a>>,
        pub(super) model: Option<Resnet34Model<2, f32>>,
        // 1 PhantomData for typesafe builder pattern
        pub(super) _phantom: PhantomData<T>,
    }

    // 3 The first type of builder we get, can only add dataset
    impl<'a> Builder<'a, WithoutDataset> {
        pub fn dataset(
            self,
            dataset: DirectoryImageDataset<'a>,
        ) -> Builder<'a, WithoutModel> {
            Builder {
                device: self.device,
                dataset: Some(dataset),
                model: None,
                _phantom: PhantomData,
            }
        }
    }

    // 4 The second builder we get, still needs a model
    impl<'a> Builder<'a, WithoutModel> {
        pub fn model(self, model: Resnet34Model<2, f32>) -> Builder<'a, Ready> {
            Builder {
                device: self.device,
                dataset: self.dataset,
                model: Some(model),
                _phantom: PhantomData,
            }
        }
    }

    // 5 Finally construct the `VisualLearner`
    impl<'a> Builder<'a, Ready> {
        pub fn build(self) -> VisualLearner<'a> {
            let model = self.model.unwrap();
            VisualLearner::new(self.device, self.dataset.unwrap(), model)
        }
    }
}
So there is a lot going on here. Why do we have this PhantomData (1) thing with some weird T generic type tied to it? Well, that is actually how we can create builders that can't call build() until they are completely specified. In this case, we want the dataset and model to be specified before we allow the user to actually build the VisualLearner.
That is where those zero-sized types (2) come into play. They are used as marker types to keep track of what we're allowed to do with our builder. WithoutDataset implies that the builder is missing its dataset. WithoutModel means the same about the model. And we have Ready, which typically would also hold all the methods for adding optional arguments to the builder.
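If the pattern is new to you, here is a minimal standalone sketch of the same typestate idea, independent of tardyai. The names Missing and Present and the single u32 field are just for illustration; the point is that the compiler won't let you call build() until the required field has been set.

use std::marker::PhantomData;

struct Missing;
struct Present;

struct Builder<State> {
    value: Option<u32>,
    _phantom: PhantomData<State>,
}

impl Builder<Missing> {
    fn new() -> Self {
        Builder { value: None, _phantom: PhantomData }
    }

    // Setting the required field transitions the builder to the `Present` state
    fn value(self, v: u32) -> Builder<Present> {
        Builder { value: Some(v), _phantom: PhantomData }
    }
}

impl Builder<Present> {
    // `build()` only exists once the required field is present
    fn build(self) -> u32 {
        self.value.unwrap()
    }
}

fn main() {
    let v = Builder::new().value(42).build(); // compiles
    // let bad = Builder::new().build(); // compile error: no `build` in state `Missing`
    println!("{v}");
}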
In (3) we are defining the first method for the empty builder, allowing us to set the dataset. In (4) we do the same for the model, and finally in (5) we construct the VisualLearner itself with a private constructor that we need to write.
Here is that constructor.
impl<'a> VisualLearner<'a> {
    // [snip]
    fn new(
        device: AutoDevice,
        dataset: DirectoryImageDataset<'a>,
        model: Resnet34Model<2, f32>,
    ) -> Self {
        let adam = Adam::new(&model.model, AdamConfig::default());
        Self {
            device,
            dataset,
            model,
            optimizer: adam,
        }
    }
}
This code takes the arguments supplied by the builder, constructs the optimizer, and returns a fully specified VisualLearner.
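Right now the optimizer is hard-wired to AdamConfig::default(). If we later want to tune it, the config can be filled in explicitly instead. A minimal sketch, assuming dfdx's AdamConfig field names; the values shown are just Adam's conventional defaults, not anything tardyai requires:

let config = AdamConfig {
    lr: 1e-3,            // learning rate
    betas: [0.9, 0.999], // decay rates for the first and second moment estimates
    eps: 1e-8,           // small constant to avoid division by zero
    weight_decay: None,  // no weight decay for now
};
let adam = Adam::new(&model.model, config);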
Now that we've got the learner, how do we train it?
impl<'a> VisualLearner<'a> {
    // [snip]
    pub fn train(&mut self, epochs: usize) -> Result<(), Error> {
        let mut rng = rand::thread_rng();
        // 1 Allocate space for gradients in the device
        let mut grads = self.model.model.alloc_grads();
        for epoch in 0..epochs {
            log::info!("Epoch {}", epoch);
            for (image, is_cat) in self
                .dataset
                .shuffled(&mut rng) // 2 Shuffle the dataset randomly each iteration
                .map(Result::unwrap)
                .map(|(image, is_cat)| { // 3 Convert the label into the one-hotted
                                         //   representation that these models return
                    let mut one_hotted = [0.0; 2];
                    one_hotted[is_cat as usize] = 1.0;
                    (image, self.device.tensor(one_hotted))
                })
                .batch_exact(Const::<16>) // 4 Split it into batches of 16 images each
                .collate() // 5 Split the tuples into two lists
                .stack() // 6 Combine lists to single batch tensors
            {
                // 7 Trace the gradients, and 8 apply the model to the image batch
                let logits = self.model.model.forward_mut(image.traced(grads));
                // 9 Calculate the difference between the output and the label
                let loss = cross_entropy_with_logits_loss(logits, is_cat);
                // 10 Calculate the gradients
                grads = loss.backward();
                // 11 Apply the optimizer
                self.optimizer.update(&mut self.model.model, &grads)?;
                // 12 Zero the gradients for the next loop
                self.model.model.zero_grads(&mut grads);
            }
        }
        Ok(())
    }
}
The first (1) thing a trainer needs to do is allocate room for the gradients. The gradients are determined by the slope of the many-dimensional surface that is constructed by our giant tensor equation, and by how large the loss is. So dfdx is calculating the derivative of all of the operations we're performing on our tensors, and applying the updates backward through the model.
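To build intuition for what that means, here is a toy sketch with a single weight: we compute the derivative of a tiny squared-error loss by hand and step the weight against the slope, which is exactly what the optimizer automates across millions of weights.

fn main() {
    let (x, y) = (2.0_f32, 10.0_f32); // one input and its desired output
    let mut w = 0.0_f32;              // our single "weight"
    let lr = 0.05;                    // learning rate
    for step in 0..20 {
        let loss = (w * x - y).powi(2);   // squared error
        let grad = 2.0 * (w * x - y) * x; // d(loss)/dw
        w -= lr * grad;                   // step against the slope
        println!("step {step}: w = {w:.3}, loss = {loss:.3}");
    }
}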
Next (2) we're shuffling the dataset so we aren't overfitting to the same ordering of the data. Then (3) we need to convert the label from a boolean into a one-hotted tensor, where the index for true or false (1 or 0) is set to 1 and all the others are zero.
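The one-hot conversion is small enough to sketch on its own; this is the same logic as the closure in the loop above, just pulled out into a function:

fn one_hot(is_cat: bool) -> [f32; 2] {
    let mut encoded = [0.0; 2];
    encoded[is_cat as usize] = 1.0; // true -> index 1, false -> index 0
    encoded
}

fn main() {
    assert_eq!(one_hot(false), [1.0, 0.0]);
    assert_eq!(one_hot(true), [0.0, 1.0]);
}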
In (4) we split the input into batches of 16 images to make the training go faster. Then (5) we collate the tuples into two side-by-side lists. And finally (6) we stack those lists into tensors that are now 4-dimensional.
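Conceptually, the collate step is just an unzip: a batch of (image, label) pairs becomes a pair of same-length lists, which the stack step then turns into single batch tensors. A plain-Rust sketch of that reshaping, with strings standing in for images:

fn main() {
    let batch = vec![("img0", [1.0, 0.0]), ("img1", [0.0, 1.0]), ("img2", [0.0, 1.0])];
    // Split the list of pairs into a pair of lists
    let (images, labels): (Vec<&str>, Vec<[f64; 2]>) = batch.into_iter().unzip();
    println!("{images:?}");
    println!("{labels:?}");
}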
Inside our loop we (7) trace our input with our gradients, which just lets dfdx know that it needs to keep track of the differences so it can calculate the derivative. Then we (8) apply the model to the batch of images.
The (9) loss function figures out how far off we are from the actual label. This is used to (10) calculate the derivatives. We are trying to minimize this loss as much as possible, and it is what drives our optimizer (11). The final step (12) is to zero out the gradients so we aren't accumulating results from previous runs.
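For the curious, here is a rough sketch of what a softmax cross-entropy loss computes for a single two-class sample; dfdx's cross_entropy_with_logits_loss does this batched and differentiably on the device.

fn cross_entropy_with_logits(logits: [f32; 2], one_hot: [f32; 2]) -> f32 {
    // log-softmax with the usual max-subtraction for numerical stability
    let max = logits[0].max(logits[1]);
    let log_sum_exp = logits.iter().map(|z| (z - max).exp()).sum::<f32>().ln();
    // negative log-likelihood of the labeled class
    logits
        .iter()
        .zip(one_hot.iter())
        .map(|(z, y)| -y * ((z - max) - log_sum_exp))
        .sum()
}

fn main() {
    // A confident, correct prediction gives a small loss...
    println!("{}", cross_entropy_with_logits([0.0, 5.0], [0.0, 1.0]));
    // ...and a confident, wrong one gives a large loss.
    println!("{}", cross_entropy_with_logits([5.0, 0.0], [0.0, 1.0]));
}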
Whew, that was a lot, but now we are done, and our code works! Unfortunately for me, my desktop with CUDA support is out of commission while I wait for new parts to arrive, and the laptop I'm using to write this article doesn't have an Nvidia card. So until dfdx can support WebGPU, I'm stuck running it with a CPU Device, which is remarkably slow. So I can't train this model, but I invite anyone reading this to try it out and let me know how it goes in the comments.
Progress bars and stuff
One thing our library has been missing that fastai provides by default is an indication of progress: progress while downloading models, while training, through the epochs. Let's see about adding that to our library.
To do this, we are going to use the indicatif crate. It provides progress bars, spinners, and the like, and easy ways to update them.
Adding a progress bar to our training is just a two-line change. Import the indicatif::ProgressIterator trait, which provides the progress() method, then add one call to the iterator chain.
for (image, is_cat) in self
    .dataset
    .shuffled(&mut rng)
    .map(Result::unwrap)
    .map(|(image, is_cat)| {
        let mut one_hotted = [0.0; 2];
        one_hotted[is_cat as usize] = 1.0;
        (image, self.device.tensor(one_hotted))
    })
    .batch_exact(Const::<16>)
    .collate()
    .stack()
    .progress() // <--- This is new
{
    // [snip]
}
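If you want to see the trait in isolation, it works on any iterator with a known length (ExactSizeIterator); here's a minimal standalone sketch:

use indicatif::ProgressIterator;

fn main() {
    // `progress()` wraps the iterator and draws a bar as items are consumed
    let total: u32 = (0..1000u32).progress().sum();
    println!("{total}");
}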
Adding a progress bar to our download is going to be only slightly trickier. In the download_file() function, we just need this.
// [snip]
log::info!("Downloading {} to: {}", &url, downloaded_file.display());
let mut dest = File::create(&downloaded_file)?;
let pb = indicatif::ProgressBar::new(response.content_length().unwrap_or(0));
let mut buf = [0; 262144]; // 256KiB buffer
loop {
    // Only write (and count) the bytes actually read; the final
    // chunk will usually be smaller than the buffer
    let len = response.read(&mut buf)?;
    if len == 0 {
        break;
    }
    dest.write_all(&buf[..len])?;
    pb.inc(len as u64);
}
Ok(downloaded_file)
That creates a progress bar whose length is the content length of the response, and advances it by the number of bytes actually read on each pass through the loop (up to 256 KiB).
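By default that bar is plain; indicatif can also format the position as bytes, which reads nicer for downloads. A hedged sketch, assuming indicatif 0.17's ProgressStyle::with_template and its built-in {bytes} placeholders; the total here is an arbitrary stand-in for response.content_length():

use indicatif::{ProgressBar, ProgressStyle};

fn main() {
    let total = 10 * 1024 * 1024; // stand-in for the real content length
    let pb = ProgressBar::new(total);
    pb.set_style(
        ProgressStyle::with_template("{bar:40} {bytes}/{total_bytes} ({eta})")
            .expect("template is valid"),
    );
    for _ in 0..(total / 262_144) {
        pb.inc(262_144); // same 256KiB chunks as the download loop
    }
    pb.finish();
}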
Conclusion
And that's it for this article. I added training to our model, and some fancy progress bars. I think we can call Chapter 1 done: there was only the one project, and it's time to move on to the next thing.
As always, the code is available on GitHub, or you can fetch the article-6 tag from the repo directly.
git pull
git checkout article-6
Stay tuned, and have a great day!