Favil Orbedios

Working through the fast.ai book in Rust - Part 5

Introduction

In Part 4, we managed to construct the model and download the weights from Hugging Face.

In this part, I've figured out an admittedly ugly way to load those weights into the model. Let's run some classification with them and see whether the output is sensible at all.

Some cleanup first

I'll admit up front: without the changes coming to dfdx on the GitHub main branch, this is going to be remarkably hacky.

But before we get to that, I did some cleanup. While reading the fast.ai book, I stumbled across some terms we were using incorrectly in our model, so in the interest of being true to the book, I've renamed Head and Tail to Stem and Head. In particular, the head of a model is actually the last bit that performs the classification task, and I figured the old names might confuse readers coming from a more ML background.

type Stem = (
    Conv2D<3, 64, 7, 2, 3>,
    BatchNorm2D<64>,
    ReLU,
    MaxPool2D<3, 2, 1>,
);

pub type Head<const NUM_CLASSES: usize> = (AvgPoolGlobal, Linear<512, NUM_CLASSES>);

I've also split Resnet18 and Resnet34 into two parts: a body and the Head we just discussed. This will hopefully make it easier to do transfer learning and change the shape of the head later down the road. I've only copied the Resnet34 version here, since it's what we're working on.

// Layer clusters are in groups of [3, 4, 6, 3]
pub type Resnet34Body = (
    Stem,
    (
        BasicBlock<64>,
        ReLU,
        BasicBlock<64>,
        ReLU,
        BasicBlock<64>,
        ReLU,
    ),
    (
        (Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
        (BasicBlock<128>, ReLU, BasicBlock<128>, ReLU),
    ),
    (
        (
            Downsample<128, 256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
        ),
        (
            BasicBlock<256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
        ),
    ),
    (
        Downsample<256, 512>,
        ReLU,
        BasicBlock<512>,
        ReLU,
        BasicBlock<512>,
        ReLU,
    ),
);

pub type Resnet34<const NUM_CLASSES: usize> = (Resnet34Body, Head<NUM_CLASSES>);

There, that looks better.
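
One nice consequence of the split: a transfer-learning variant is just a different second element in the tuple. As a purely hypothetical example (not code from this project), a two-class head for our cats-vs-dogs data would be a one-liner:

// Hypothetical: reuse the pretrained body, swap in a fresh 2-class head.
type PetsResnet34 = (Resnet34Body, Head<2>);

The pretrained Resnet34Body weights could then be kept while only the small Head gets retrained.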

What about those weights?

Now we can start getting to the ugly stuff. The dfdx interface for visiting each of the tensors in a model and updating them doesn't allow for a tree-like structure. It just iterates through the tensors one at a time and passes each one, in whatever form your Viewer asks for, to your handler.

So I came up with the naive idea of just creating a giant array of all of the names in the weights file, manually sorted into the order our model's tensors will be visited in. It's really not pretty, but it "works". (The scare quotes are there because it's almost certainly wrong somehow, and we'll have to wait until we try to classify something to see whether these are sane values.)

I'm not going to paste the whole array, because it's over 200 lines, and that feels like a giant waste of space. If you want to go look at it, feel free to do so here.

pub(crate) const RESNET34_LAYERS: [&'static str; 254] = [
    "resnet.embedder.embedder.convolution.weight",
    "resnet.embedder.embedder.normalization.weight",
    "resnet.embedder.embedder.normalization.bias",
    "resnet.embedder.embedder.normalization.running_mean",
    "resnet.embedder.embedder.normalization.running_var",
    // listed twice on purpose (see the note below)
    "resnet.embedder.embedder.normalization.num_batches_tracked",
    "resnet.embedder.embedder.normalization.num_batches_tracked",
    // [SNIP]
];

This snippet includes a part I wanted to point out. The dfdx BatchNorm2D type contains two scalar values, epsilon and momentum, but the safetensors file downloaded from Hugging Face doesn't have either. What it does have is one scalar per normalization layer called num_batches_tracked, so in the interest of expediency, I just listed that name twice. This is almost certainly not the correct thing to do, but the values I saw during my debugging process were all zeros, so I'm hoping it doesn't matter in the long run.

Now let's look at the struct that takes that list and applies the tensors to the model.

pub(crate) struct NamedTensorVisitor<'a> {
    names: &'a [&'static str],
    idx: usize,
    tensors: &'a SafeTensors<'a>,
}

impl<'a> NamedTensorVisitor<'a> {
    pub(crate) fn new(names: &'a [&'static str], tensors: &'a SafeTensors<'a>) -> Self {
        Self {
            names,
            idx: 0,
            tensors,
        }
    }
}

So this struct is dead simple: every time it visits a tensor or scalar, it looks up the name at the current index, fetches those weights from the SafeTensors field, and bumps the index.

Implementing the TensorVisitor trait requires us to add the num-traits crate to our package, because we need to reference the NumCast trait in the definition.

impl<'a, E: Dtype + SafeDtype, D: Device<E>> TensorVisitor<E, D> for NamedTensorVisitor<'a> {
    type Viewer = ViewTensorMut;
    type Err = Error;
    type E2 = E;
    type D2 = D;

    fn visit<S: Shape>(
        &mut self,
        _opts: TensorOptions<S, E, D>,
        t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>>,
    ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err> {
        log::debug!(
            "Loading tensor shape: {:?}, {:?}",
            t.shape(),
            &self.names.get(self.idx)
        );
        // Look up the name for this position and load the matching weights.
        t.load_safetensor(
            self.tensors,
            self.names.get(self.idx).ok_or(Error::NotEnoughNames)?,
        )?;
        self.idx += 1;
        Ok(None)
    }

    fn visit_scalar<N: NumCast>(
        &mut self,
        _opts: ScalarOptions<N>,
        n: <Self::Viewer as TensorViewer>::View<'_, N>,
    ) -> Result<Option<N>, Self::Err> {
        log::debug!("Loading scalar: {:?}", &self.names.get(self.idx));
        let tensor = self
            .tensors
            .tensor(self.names.get(self.idx).ok_or(Error::NotEnoughNames)?)?;
        // The scalar payload is 8 little-endian bytes encoding an f64.
        let data = tensor.data();
        let mut array = [0u8; 8];
        array.copy_from_slice(data);
        let val = f64::from_le_bytes(array);
        // Narrow the f64 down to whatever scalar type N the model stores.
        *n = N::from(val).ok_or(Error::NumberFormatException)?;

        self.idx += 1;
        Ok(None)
    }
}

This trait is a little complicated. We'll start by going over the associated types.

Viewer there is just assigned to ViewTensorMut, a marker type that means we will receive a &mut Tensor in the visit method.

Err is just our crate's Error enumeration.

E2 and D2 are there in case you want to change the tensors, either converting the datatype or moving them to a different device. Since we aren't returning anything, they could be almost anything, but I'll just keep them the same as the inputs.
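
If those identity choices feel abstract, here's a hypothetical visitor that leans on them to do nothing but count parameters. This is a sketch against the same trait, not code from this project; it assumes dfdx's read-only ViewTensorRef marker and the Shape::num_elements() helper, and I haven't compiled it.

// Hypothetical parameter counter: E2/D2 stay equal to the inputs, and
// visit() only reads, so Viewer is the read-only ViewTensorRef.
struct ParamCounter {
    count: usize,
}

impl<E: Dtype, D: Device<E>> TensorVisitor<E, D> for ParamCounter {
    type Viewer = ViewTensorRef;
    type Err = std::convert::Infallible;
    type E2 = E; // same dtype out as in
    type D2 = D; // same device out as in

    fn visit<S: Shape>(
        &mut self,
        _opts: TensorOptions<S, E, D>,
        t: &Tensor<S, E, D>,
    ) -> Result<Option<Tensor<S, E, D>>, Self::Err> {
        // Tally the elements of every tensor we are shown.
        self.count += t.shape().num_elements();
        Ok(None) // None: we aren't building a converted model
    }

    fn visit_scalar<N: NumCast>(
        &mut self,
        _opts: ScalarOptions<N>,
        _n: &N,
    ) -> Result<Option<N>, Self::Err> {
        Ok(None) // scalars contribute no parameters
    }
}

With that detour out of the way, back to our actual loader.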

visit() is pretty straightforward: load the tensor stored under the current name, and advance the index by one.

visit_scalar() was an interesting one, and I'm copying from the dfdx crate's own implementation. Scalars in SafeTensors objects are stored as f64 in little-endian format, but the only way to get at the data is to copy it into a [u8; 8] and convert that to an f64. Finally, we store it in the scalar by converting the f64 to N (f32 in our case), and advance the index by one.
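
As a standalone illustration of that byte dance, here's the round trip in plain Rust (nothing project-specific beyond the num-traits crate we just added):

fn scalar_round_trip() {
    let bytes = 0.1f64.to_le_bytes(); // what the file stores: 8 little-endian bytes
    let mut buf = [0u8; 8];
    buf.copy_from_slice(&bytes);
    let val = f64::from_le_bytes(buf); // back to an f64
    // num_traits::cast handles the f64 -> f32 narrowing, returning None on failure
    let narrowed: f32 = num_traits::cast(val).unwrap();
    println!("{narrowed}"); // 0.1
}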

Now we just have to actually use this struct to load up the tensors in the download_model method we created in the previous part.

    pub fn download_model(&mut self) -> Result<(), Error> {
        log::info!("Downloading model from {}", ModelUrl::Resnet34.url());
        let model_file = download_model(ModelUrl::Resnet34)?;

        // TODO: Somehow make something like this work
        // self.model.load_safetensors(&model_file)?;

        let file = File::open(model_file).unwrap();
        let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
        let tensors = SafeTensors::deserialize(&buffer).unwrap();

        let _ = <<Resnet34<N> as BuildOnDevice<AutoDevice, E>>::Built as TensorCollection<
            E,
            AutoDevice,
        >>::iter_tensors(&mut RecursiveWalker {
            m: &mut self.model,
            f: &mut NamedTensorVisitor::new(&RESNET34_LAYERS, &tensors),
        })?;

        Ok(())
    }

That complicated generic type after we load the tensors from the file really just boils down to interpreting the Built type, which we store as our model instance, as a TensorCollection<>, and calling the iter_tensors() associated function on that type.

iter_tensors() takes a ModuleVisitor. I'm using the included RecursiveWalker, which just recurses through the tensors in the model and calls visit() and visit_scalar() on each one.

Actually, let's go ahead and clean that complicated generic type up a bit.

type Resnet34Built<const NUM_CLASSES: usize, E> =
    <Resnet34<NUM_CLASSES> as BuildOnDevice<AutoDevice, E>>::Built;

This will let us clean up our model type definition.

pub struct Resnet34Model<const NUM_CLASSES: usize, E>
where
    E: Dtype + SafeDtype,
    Resnet34<NUM_CLASSES>: BuildOnDevice<AutoDevice, E>,
    AutoDevice: Device<E>,
{
    pub model: Resnet34Built<NUM_CLASSES, E>,
}

And then we can actually clean up the ugly iter_tensors() call.

    let _ = <Resnet34Built<N, E> as TensorCollection<E, AutoDevice>>::iter_tensors(
        &mut RecursiveWalker {
            m: &mut self.model,
            f: &mut NamedTensorVisitor::new(&RESNET34_LAYERS, &tensors),
        },
    )?;

That looks better. And we finally have the weights loaded into the model!

What can we do with it?

Let's try to use the model as it is to classify some of our Pets images, and see if the tensor it returns correlates to the cat or dog classes among the ImageNet categories.

Applying the model to an image just means calling forward() on the image's tensor. So let's do just that.

    model.download_model()?;

    let (image, is_cat) = dataset.get(1)?;
    let categories = model.model.forward(image);

    log::info!("Is Cat? {}", is_cat);
    log::trace!("Categories: {:#?}", categories.array());

    let max_category = categories
        .softmax()
        .array()
        .into_iter()
        .map(ordered_float::OrderedFloat)
        .enumerate()
        .max_by_key(|t| t.1)
        .unwrap();
    log::info!("(Category, Weight): {:?}", max_category);

This fetches the 2nd image in our dataset, which, for me at least, is a cat. Then we apply the model to its tensor.

I've added a log statement for the categories tensor that comes back. This is a Rank1<1000> tensor that is supposed to have its largest value at the index corresponding to the category of object in the input image.

So let's figure out which index that is. I apply softmax() to the categories, which takes the ugly raw output and massages the values so they all add up to 1 while preserving their order, so the biggest logit stays the biggest probability. Unfortunately, f32 in Rust doesn't implement the Ord trait, so we can't just fetch the maximum value from the tensor directly; we have to convert it to a type that does implement Ord. That is where the ordered-float crate comes into play: as long as we don't have any NaNs in our data, its total ordering matches the one you'd expect.
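
Here's that trick in miniature, outside the model (plain Rust plus the ordered-float crate, nothing project-specific):

use ordered_float::OrderedFloat;

fn main() {
    // f32 is only PartialOrd; wrapping in OrderedFloat gives a total order,
    // so enumerate + max_by_key finds the argmax directly.
    let probs = [0.1f32, 0.7, 0.05, 0.15];
    let (idx, weight) = probs
        .into_iter()
        .map(OrderedFloat)
        .enumerate()
        .max_by_key(|&(_, p)| p)
        .unwrap();
    println!("(Category, Weight): ({idx}, {weight})"); // (1, 0.7)
}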

And now if we run this we get (89, OrderedFloat(0.0099451095)). Hmm, if the weight of the category is supposed to correlate to how certain the model is, barely 1% doesn't inspire me with confidence. Let's check what that category is supposed to represent.
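
Checking means mapping the index back to a label list. Programmatically that's only a few lines; here's a sketch that assumes the labels are a flat JSON array of 1000 strings saved locally as imagenet_labels.json (a placeholder name), using the serde_json crate:

use std::{error::Error, fs};

// Sketch: map a class index to its label. Assumes a flat JSON array of strings.
fn label_for(index: usize) -> Result<String, Box<dyn Error>> {
    let raw = fs::read_to_string("imagenet_labels.json")?;
    let labels: Vec<String> = serde_json::from_str(&raw)?;
    Ok(labels.get(index).cloned().unwrap_or_else(|| "unknown".to_string()))
}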

I found this json file that is supposed to be the categories that the original ResNet models were trained on. If we look in there for index 89, we get... sulphur-crested_cockatoo...

Well, that's a little frustrating; I don't think that is a cat. Something tells me that our model may not have been set up correctly. I'll attempt to get it working between articles. But this has taken a long time, and we may as well just train the model from scratch to see if we even have anything worth saving.

I'm planning for the next article to cover using the model in a learner with an optimizer, like the penultimate line of the Python code we're targeting.

learn = vision_learner(dls, resnet34, metrics=error_rate)

If I get this model working with transfer learning in the meantime, I'll switch it up and talk about how I managed it.

Anyway, until then, the code is available on GitHub as always, and if you're following along in your own editor, you can check out the article-5 tag.

git pull origin
git checkout article-5

Stay tuned, and I hope to see you in the next article with a learner in tow.
