## Introduction

In Part 4, we managed to construct the model and download the weights from Hugging Face.

In this part, I think I've figured out an ugly way to load those weights into the model. Let's see if we can do some classification with these weights and see if it gives something sensible at all.

## Some cleanup first

So I'm going to admit that without the changes that are coming to `dfdx` in the GitHub `main` branch, this is going to be remarkably hacky.

But before we get to that, I did some cleanup. While reading the fast.ai book, I stumbled across some terms that we were using incorrectly in our model. So in the interest of being true to the book, I've renamed `Head` and `Tail` to `Stem` and `Head`. In particular, the *head* of the model is actually the last bit that performs the classification task. I figured the old names might be confusing to people who are reading this with a more ML background.

```
type Stem = (
    Conv2D<3, 64, 7, 2, 3>,
    BatchNorm2D<64>,
    ReLU,
    MaxPool2D<3, 2, 1>,
);

pub type Head<const NUM_CLASSES: usize> = (AvgPoolGlobal, Linear<512, NUM_CLASSES>);
```

I've also split `Resnet18` and `Resnet34` into two parts: the body, and a `Head`, which we've previously discussed. This will hopefully make it easier to do transfer learning and change the shape of the head later down the road. I've only copied the `Resnet34` version here, since it's what we're working on.

```
// Layer clusters are in groups of [3, 4, 6, 3]
pub type Resnet34Body = (
    Stem,
    (
        BasicBlock<64>,
        ReLU,
        BasicBlock<64>,
        ReLU,
        BasicBlock<64>,
        ReLU,
    ),
    (
        (Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
        (BasicBlock<128>, ReLU, BasicBlock<128>, ReLU),
    ),
    (
        (
            Downsample<128, 256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
        ),
        (
            BasicBlock<256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
        ),
    ),
    (
        Downsample<256, 512>,
        ReLU,
        BasicBlock<512>,
        ReLU,
        BasicBlock<512>,
        ReLU,
    ),
);

pub type Resnet34<const NUM_CLASSES: usize> = (Resnet34Body, Head<NUM_CLASSES>);
```

There, that looks better.
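As a quick sanity check on the grouping, here is a small sketch of the arithmetic behind the name "ResNet-34". It assumes the usual counting convention (each basic block contributes two convolutions, a `Downsample` counts as one block, and the 1x1 projection shortcuts are not counted). The constant and function names are made up for illustration; they are not part of the project or of `dfdx`.

```rust
/// Residual blocks per stage in ResNet-34 (assumption: a `Downsample`
/// counts as one block of its stage).
pub const RESNET34_BLOCKS: [usize; 4] = [3, 4, 6, 3];

/// Residual blocks per stage in ResNet-18, for comparison.
pub const RESNET18_BLOCKS: [usize; 4] = [2, 2, 2, 2];

/// Each basic block holds two 3x3 convolutions.
const CONVS_PER_BLOCK: usize = 2;

/// Counts the weighted layers that give the network its name:
/// the stem convolution, the convolutions inside the blocks,
/// and the final linear layer in the head.
/// Projection shortcuts are deliberately left out of the count.
pub fn weighted_layer_count(blocks_per_stage: &[usize]) -> usize {
    let block_convs = blocks_per_stage.iter().sum::<usize>() * CONVS_PER_BLOCK;
    1 + block_convs + 1
}
```

Calling `weighted_layer_count(&RESNET34_BLOCKS)` should give 34, and the same function gives 18 for `RESNET18_BLOCKS`, which is a cheap way to confirm that the tuple above has the right number of blocks in each group.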

## What about those weights?

Now we can start getting to the ugly stuff. The interface for visiting each of the tensors in a model and updating them doesn't expose the tree-like structure of the model. It just iterates through the tensors one at a time, and passes each one to the handler you give it.

So I came up with the naive idea to just create a giant array of all of the names in the weights file, and manually sort them into the order that will be iterated over from the model we created. It's really not pretty, but it "works". (The scare quotes are there because it almost certainly is wrong somehow, and we'll have to wait until we try to classify something to see if these are sane values.)

I'm not going to paste the whole array, because it's > 200 lines, and that feels like a giant waste of space. If you want to go look at it, feel free to do so here.

```
pub(crate) const RESNET34_LAYERS: [&'static str; 254] = [
    "resnet.embedder.embedder.convolution.weight",
    "resnet.embedder.embedder.normalization.weight",
    "resnet.embedder.embedder.normalization.bias",
    "resnet.embedder.embedder.normalization.running_mean",
    "resnet.embedder.embedder.normalization.running_var",
    "resnet.embedder.embedder.normalization.num_batches_tracked",
    "resnet.embedder.embedder.normalization.num_batches_tracked",
    // [SNIP]
];
```

This snippet includes a part that I wanted to point out. The `dfdx` `BatchNorm2D` type contains two scalar values, `epsilon` and `momentum`, but the safetensors file downloaded from Hugging Face does not. It does have one scalar per normalization layer called `num_batches_tracked`. So in the interest of expediency, I just duplicated this value. This is most certainly not the correct thing to do, but the values I saw during my debugging process were all zeros, so I'm hoping it doesn't matter in the long run.

Now let's see about the class that takes that and applies the tensors to the model.

```
pub(crate) struct NamedTensorVisitor<'a> {
    names: &'a [&'static str],
    idx: usize,
    tensors: &'a SafeTensors<'a>,
}

impl<'a> NamedTensorVisitor<'a> {
    pub(crate) fn new(names: &'a [&'static str], tensors: &'a SafeTensors<'a>) -> Self {
        Self {
            names,
            idx: 0,
            tensors,
        }
    }
}
```

So this struct is dead simple: just increase the index every time you visit a tensor or scalar, and use the name that is at that index to fetch the weights from the `SafeTensors` field.
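To make the bookkeeping concrete without pulling in `dfdx` or `safetensors`, here is a minimal sketch of just the "name at the current index" part. `NameCursor` and this `NotEnoughNames` type are hypothetical stand-ins invented for the example; the real visitor keeps the same two fields plus the tensor file.

```rust
/// Hypothetical stand-in for the visitor's bookkeeping:
/// an ordered list of names and a cursor into it.
pub struct NameCursor<'a> {
    names: &'a [&'static str],
    idx: usize,
}

/// Hypothetical error returned when the model has more tensors than names.
#[derive(Debug, PartialEq)]
pub struct NotEnoughNames;

impl<'a> NameCursor<'a> {
    pub fn new(names: &'a [&'static str]) -> Self {
        Self { names, idx: 0 }
    }

    /// Returns the name for the current visit and advances the cursor.
    /// The cursor only moves when a name was actually available.
    pub fn next_name(&mut self) -> Result<&'static str, NotEnoughNames> {
        let name = *self.names.get(self.idx).ok_or(NotEnoughNames)?;
        self.idx += 1;
        Ok(name)
    }

    /// Number of names that were never consumed.
    pub fn remaining(&self) -> usize {
        self.names.len() - self.idx
    }
}
```

One thing this sketch suggests: after walking the whole model, checking that `remaining()` is zero would catch the case where the name list is longer than the number of tensors visited, which the index-only approach otherwise accepts silently.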

Implementing the `TensorVisitor` trait requires us to add the `num-traits` crate to our package, because we need to reference the `NumCast` trait in the definition.

```
impl<'a, E: Dtype + SafeDtype, D: Device<E>> TensorVisitor<E, D> for NamedTensorVisitor<'a> {
    type Viewer = ViewTensorMut;
    type Err = Error;
    type E2 = E;
    type D2 = D;

    fn visit<S: Shape>(
        &mut self,
        _opts: TensorOptions<S, E, D>,
        t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>>,
    ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err> {
        log::debug!(
            "Loading tensor shape: {:?}, {:?}",
            t.shape(),
            &self.names.get(self.idx)
        );
        t.load_safetensor(
            &mut self.tensors,
            &self.names.get(self.idx).ok_or(Error::NotEnoughNames)?,
        )?;
        self.idx += 1;
        Ok(None)
    }

    fn visit_scalar<N: NumCast>(
        &mut self,
        _opts: ScalarOptions<N>,
        n: <Self::Viewer as TensorViewer>::View<'_, N>,
    ) -> Result<Option<N>, Self::Err> {
        log::debug!("Loading scalar: {:?}", &self.names.get(self.idx));
        let tensor = self
            .tensors
            .tensor(self.names.get(self.idx).ok_or(Error::NotEnoughNames)?)?;
        let data = tensor.data();
        let mut array = [0; 8];
        array.copy_from_slice(data);
        let val = f64::from_le_bytes(array);
        *n = N::from(val).ok_or(Error::NumberFormatException)?;
        self.idx += 1;
        Ok(None)
    }
}
```

This trait is a little complicated. We'll start by going over the associated types.

`Viewer` there is just assigned to `ViewTensorMut`, which is a marker type that means we will receive a `&mut Tensor<>` in the `visit` method.

`Err` is just our crate's `Error` enumeration.

`E2` and `D2` are there in case you want to make changes to the tensors, either to the datatype or moving it to a different device. Since we aren't returning anything, this could be any number of things, but I'll just keep it set to the same as the inputs.

`visit()` is pretty straightforward, just load the tensor from the current name, and advance the index by one.

`visit_scalar()` was an interesting one, and I'm copying from the `dfdx` crate's implementation. Scalars in `SafeTensors` objects are stored as `f64` in little-endian format, but the only way we can get to the data is by copying it into a `[u8; 8]`, and converting that to an `f64`. Finally, we store it in the scalar by converting it to `N` (`f32` in our case) from the `f64`, and advance the index by one.
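The byte shuffling in `visit_scalar()` can be tried out on its own with nothing but the standard library. This is a sketch, not the `dfdx` implementation: the function name `scalar_from_le_bytes` is invented here, and it returns `None` on a length mismatch because `copy_from_slice` panics when the source and destination lengths differ.

```rust
/// Decodes a scalar stored as 8 little-endian bytes (an `f64`)
/// and narrows it to `f32`.
/// Returns `None` if the slice is not exactly 8 bytes long, instead of
/// letting `copy_from_slice` panic on the length mismatch.
pub fn scalar_from_le_bytes(data: &[u8]) -> Option<f32> {
    if data.len() != 8 {
        return None;
    }
    let mut bytes = [0u8; 8];
    bytes.copy_from_slice(data);
    Some(f64::from_le_bytes(bytes) as f32)
}
```

For example, `scalar_from_le_bytes(&0.1f64.to_le_bytes())` yields `Some(0.1f64 as f32)`. Note that this only makes sense if the bytes really encode an `f64`. PyTorch keeps `num_batches_tracked` as a 64-bit integer counter, so reinterpreting those bytes as a float is only harmless while the counter is zero, since eight zero bytes decode to `0.0` either way.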

Now we just have to actually use this struct to load up the tensors in the `download_model` method we created in the previous part.

```
pub fn download_model(&mut self) -> Result<(), Error> {
    log::info!("Downloading model from {}", ModelUrl::Resnet34.url());
    let model_file = download_model(ModelUrl::Resnet34)?;

    // TODO: Somehow make something like this work
    // self.model.load_safetensors(&model_file)?;

    let file = File::open(model_file).unwrap();
    let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
    let tensors = SafeTensors::deserialize(&buffer).unwrap();

    let _ = <<Resnet34<N> as BuildOnDevice<AutoDevice, E>>::Built as TensorCollection<
        E,
        AutoDevice,
    >>::iter_tensors(&mut RecursiveWalker {
        m: &mut self.model,
        f: &mut NamedTensorVisitor::new(&RESNET34_LAYERS, &tensors),
    })?;
    Ok(())
}
```

That complicated generic type after we load the tensors from the file really just boils down to interpreting the `Built` type, which we store as our model instance, as a `TensorCollection<>`, and calling the `iter_tensors()` associated function on that type.

`iter_tensors()` takes a `ModuleVisitor`. I'm using the included `RecursiveWalker` that just recurses through the tensors in the model, and calls `visit()` and `visit_scalar()` on them.

Actually, let's go ahead and clean that complicated generic type up a bit.

```
type Resnet34Built<const NUM_CLASSES: usize, E> =
    <Resnet34<NUM_CLASSES> as BuildOnDevice<AutoDevice, E>>::Built;
```

This will let us clean up our model type definition.

```
pub struct Resnet34Model<const NUM_CLASSES: usize, E>
where
    E: Dtype + SafeDtype,
    Resnet34<NUM_CLASSES>: BuildOnDevice<AutoDevice, E>,
    AutoDevice: Device<E>,
{
    pub model: Resnet34Built<NUM_CLASSES, E>,
}
```

And then actually clean up the ugly `iter_tensors()` call:

```
let _ = <Resnet34Built<N, E> as TensorCollection<E, AutoDevice>>::iter_tensors(
    &mut RecursiveWalker {
        m: &mut self.model,
        f: &mut NamedTensorVisitor::new(&RESNET34_LAYERS, &tensors),
    },
)?;
```

That looks better. And we finally have the weights loaded into the model!

## What can we do with it?

Let's try to use the model as it is to classify some of our `Pets` images, and see if the tensor it returns correlates to the `cat` or `dog` classes in the ImageNet dataset.

To apply the model to an image just means we call `forward()` on the tensor from the image. So let's do just that.

```
model.download_model()?;

let (image, is_cat) = dataset.get(1)?;
let categories = model.model.forward(image);
log::info!("Is Cat? {}", is_cat);
log::trace!("Categories: {:#?}", categories.array());

let max_category = categories
    .softmax()
    .array()
    .into_iter()
    .map(ordered_float::OrderedFloat)
    .enumerate()
    .max_by_key(|t| t.1)
    .unwrap();
log::info!("(Category, Weight): {:?}", max_category);
```

This fetches the 2nd image in our dataset which, for me at least, is a cat. Then we apply the model to the tensor.

I've added a log statement for the categories tensor that is returned. This is a `Rank1<1000>` tensor that is supposed to have the largest value in the index that corresponds to the category of object in the input image.

So let's figure out which that is. I apply `softmax()` to the categories, which takes the ugly output of the original `categories` tensor, and massages the values so that they all add up to 1, while keeping the largest value at the same index. Unfortunately, `f32` in Rust doesn't implement the `Ord` trait, so we can't just fetch the maximum value from the tensor directly; we have to convert it to a type that does implement `Ord`. That is where the `ordered-float` crate comes into play. Its `OrderedFloat` wrapper gives us a total ordering, and as long as we don't have any `NaN`s in our data, the maximum it finds is the one we want.
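For anyone who wants to poke at this step without `dfdx`, here is a standard-library sketch of the same idea on a plain slice. It is an alternative to the approach above rather than what the article's code does: it uses `f32::total_cmp` instead of `ordered-float`, and the `softmax` and `argmax` helpers are written for this example, not taken from any crate.

```rust
/// Numerically stable softmax: subtract the largest logit before
/// exponentiating so that `exp` cannot overflow.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

/// Index and value of the largest element, or `None` for an empty slice.
/// `total_cmp` provides a total order over `f32`, so no wrapper type is needed.
pub fn argmax(values: &[f32]) -> Option<(usize, f32)> {
    values
        .iter()
        .copied()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(&b.1))
}
```

Because softmax is monotonic, `argmax` picks the same index before and after it, so the softmax only matters if you want to read the value as a probability. As a reference point, 1000 equal logits come out as 0.001 each, so that is what a completely undecided 1000-class model would report for every category.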

And now if we run this we get `(89, OrderedFloat(0.0099451095))`. Hmm, if the weight of the category is supposed to correlate to how certain the model is, roughly 1% doesn't inspire me with confidence. Let's check what that category is supposed to represent.

I found this json file that is supposed to be the categories that the original ResNet models were trained on. If we look in there for index 89, we get... `sulphur-crested_cockatoo`...

Well, that's a little frustrating, I don't think that is a cat. Something tells me that our model may not have been set up correctly. I'll attempt to get it working between articles. But this has taken a long time, and we may as well just train the model from scratch to see if we even have anything worth saving.

I'm going to plan on the next article talking about using the model in a learner with an optimizer, like the penultimate line of the Python code we're targeting.

```
learn = vision_learner(dls, resnet34, metrics=error_rate)
```

If I get this model working with transfer learning in the meantime, I'll switch it up and talk about how I managed it.

Anyway, until then, the code is available on GitHub as always, and if you're following along in your own editor, you can check out the `article-5` tag.

```
git pull origin
git checkout article-5
```

Stay tuned, and I hope to see you in the next article with a learner in tow.
