DEV Community

Favil Orbedios

Working through the fast.ai book in Rust - Part 4

Introduction

In Part 3, we covered creating Tensors from our images and loading them onto the device we're performing our matrix multiplication on.

In this part, I want to go over actually constructing the ResNet-34 model that the fast.ai book uses.

This is more difficult than it seems, because in addition to loading the weights from the internet, we also need to figure out how to cut off the last few layers, and add on some new layers that will give us our new categories.

The original model was trained on 1000 different categories, but in this first chapter we're just determining whether an image is a cat or not, so the output Tensor has size 2.

So join me on this journey of discovery.

Let's build the model

First off, we need to enable support for convolutions. A convolution takes a group of pixels, typically a small square around a central pixel, and combines them into a single value. It's a fun and relatively deep concept in math. I found this 3blue1brown video to be very approachable and entertaining.
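To make that concrete, here is a minimal sketch in plain Rust (no dfdx) of one 3x3 convolution pass over a tiny grayscale image. The function name `convolve3x3` and the sample image are mine, purely for illustration. It uses no padding, so the output loses one pixel on each edge, and like most deep learning libraries it does not flip the kernel.

```rust
/// Slide a 3x3 kernel over `image` (a list of pixel rows) with stride 1 and no padding.
/// Each output value is the weighted sum of one 3x3 neighbourhood.
fn convolve3x3(image: &[Vec<f32>], kernel: &[[f32; 3]; 3]) -> Vec<Vec<f32>> {
    let height = image.len();
    let width = if height > 0 { image[0].len() } else { 0 };
    if height < 3 || width < 3 {
        // Too small to fit a single 3x3 window.
        return Vec::new();
    }
    let mut out = vec![vec![0.0; width - 2]; height - 2];
    for y in 0..height - 2 {
        for x in 0..width - 2 {
            let mut sum = 0.0;
            for ky in 0..3 {
                for kx in 0..3 {
                    sum += image[y + ky][x + kx] * kernel[ky][kx];
                }
            }
            out[y][x] = sum;
        }
    }
    out
}

fn main() {
    // A 4x4 image and a box-blur kernel that averages each neighbourhood.
    let image = vec![
        vec![1.0, 2.0, 3.0, 4.0],
        vec![5.0, 6.0, 7.0, 8.0],
        vec![9.0, 10.0, 11.0, 12.0],
        vec![13.0, 14.0, 15.0, 16.0],
    ];
    let blur = [[1.0 / 9.0; 3]; 3];
    println!("{:?}", convolve3x3(&image, &blur));
}
```

A real convolution layer does the same thing across many input and output channels at once, with the kernel values being the weights the model learns.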

Convolutions in dfdx are, unfortunately, only available on the nightly Rust compiler. So first things first, let's enable the nightly channel in a rust-toolchain.toml file.

[toolchain]
channel = "nightly"

Now I'm going to cheat a little. The lovely creators of dfdx left an example of the ResNet-18 structure in their repo. This isn't exactly what I need for ResNet-34, but it's pretty close, and it gives us a great jumping-off point.

So let's get started by stealing that code. I'm going to make a simple change and add a new type definition for the Tail of the structure. I plan to use this to change the shape of the final layers for re-training the model on different classes.

use dfdx::prelude::*;

type BasicBlock<const C: usize> = Residual<(
    Conv2D<C, C, 3, 1, 1>,
    BatchNorm2D<C>,
    ReLU,
    Conv2D<C, C, 3, 1, 1>,
    BatchNorm2D<C>,
)>;

type Downsample<const C: usize, const D: usize> = GeneralizedResidual<
    (
        Conv2D<C, D, 3, 2, 1>,
        BatchNorm2D<D>,
        ReLU,
        Conv2D<D, D, 3, 1, 1>,
        BatchNorm2D<D>,
    ),
    (Conv2D<C, D, 1, 2, 0>, BatchNorm2D<D>),
>;

type Head = (
    Conv2D<3, 64, 7, 2, 3>,
    BatchNorm2D<64>,
    ReLU,
    MaxPool2D<3, 2, 1>,
);

pub type Tail<const NUM_CLASSES: usize> = (AvgPoolGlobal, Linear<512, NUM_CLASSES>);

pub type Resnet18<const NUM_CLASSES: usize> = (
    Head,
    (BasicBlock<64>, ReLU, BasicBlock<64>, ReLU),
    (Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
    (Downsample<128, 256>, ReLU, BasicBlock<256>, ReLU),
    (Downsample<256, 512>, ReLU, BasicBlock<512>, ReLU),
    Tail<NUM_CLASSES>,
);

Now, if you're anything like me, this is a bunch of incomprehensible nonsense. But luckily, we don't have to worry about why all these things exist; we just need to understand the basic structure.
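One part that can be decoded without much theory is the numbers. As far as I can tell from the dfdx docs, the const parameters on Conv2D are input channels, output channels, kernel size, stride, and padding, and MaxPool2D takes kernel size, stride, and padding. Assuming that reading, and assuming a 224x224 input image (my assumption, nothing in the code above fixes it), this sketch traces how the image size shrinks through the model. The helper `out_size` is a made-up name that applies the usual output-size formula.

```rust
/// Output size along one spatial dimension for a convolution or pooling window:
/// floor((input + 2 * padding - kernel) / stride) + 1.
/// Assumes the padded input is at least as large as the kernel.
fn out_size(input: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

fn main() {
    // Assumption: a 224x224 input image.
    let input = 224;
    // Head: Conv2D<3, 64, 7, 2, 3> followed by MaxPool2D<3, 2, 1>.
    let after_conv = out_size(input, 7, 2, 3);
    let after_pool = out_size(after_conv, 3, 2, 1);
    // BasicBlock: Conv2D<C, C, 3, 1, 1> leaves the size unchanged.
    let after_basic = out_size(after_pool, 3, 1, 1);
    // Downsample main path: Conv2D<C, D, 3, 2, 1>.
    let main_path = out_size(after_basic, 3, 2, 1);
    // Downsample shortcut path: Conv2D<C, D, 1, 2, 0>.
    let shortcut = out_size(after_basic, 1, 2, 0);
    println!("{input} -> {after_conv} -> {after_pool} -> {after_basic} -> {main_path} and {shortcut}");
}
```

The detail worth noticing is that both paths of Downsample end up the same size, which they have to if their outputs are going to be added together.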

For understanding what everything does, Chapter 14 of the fast.ai book contains a wealth of information about exactly what goes into a ResNet model. But that doesn't help us much right now, since we're only on Chapter 1!

I still wanted a rough idea of what the large parts of this do and, in particular, how to change it to ResNet-34 or even larger models. While searching for the weights for the ResNet models, I landed on the Hugging Face website. This is an awesome resource: it definitely has the weights we need, and it also has a diagram of the layer structure of this model in particular. I found this diagram in the model card of the ResNet-34 model.

ResNet-34 model diagram

Well, I was able to connect the dots, and I noticed the Residual words in the ResNet-18 types. The top diagram in that image is almost describing the BasicBlock and Downsample parts of this code. The 4 tuples between Head and Tail line up almost perfectly with those colorful blocks.

BasicBlock diagram
Downsample diagram

So BasicBlock corresponds to the first image. It's a pair of convolution layers, evidently separated by a ReLU layer, whatever that is. And Downsample must correspond to the second image with the dashed line. It looks like it takes the 64s to 128s, and that looks like what this line is doing.

    (Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
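The Residual and GeneralizedResidual wrappers are the skip connections from those diagrams. As I understand the dfdx docs, Residual<F> computes F(x) + x, and GeneralizedResidual<F, R> computes F(x) + R(x). ReLU, for its part, just replaces negative values with zero. Here is a toy sketch of those three ideas using plain closures over slices of f32. The function names are mine, not dfdx's, and the closures stand in for the real convolution stacks.

```rust
/// ReLU keeps positive values and replaces negative ones with zero.
fn relu(v: &[f32]) -> Vec<f32> {
    v.iter().map(|a| a.max(0.0)).collect()
}

/// Toy residual connection: run `f` on the input, then add the input back.
fn residual(x: &[f32], f: impl Fn(&[f32]) -> Vec<f32>) -> Vec<f32> {
    f(x).iter().zip(x).map(|(fx, xi)| fx + xi).collect()
}

/// Toy generalized residual: the shortcut goes through its own function `r`.
fn generalized_residual(
    x: &[f32],
    f: impl Fn(&[f32]) -> Vec<f32>,
    r: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    f(x).iter().zip(r(x).iter()).map(|(fx, rx)| fx + rx).collect()
}

fn main() {
    let x = [1.0_f32, -2.0, 3.0];
    // Stand-ins for the main path and the shortcut path.
    let f = |v: &[f32]| v.iter().map(|a| a * -2.0).collect::<Vec<f32>>();
    let r = |v: &[f32]| v.iter().map(|a| a * 3.0).collect::<Vec<f32>>();

    // Mirrors `(BasicBlock, ReLU)`: a residual block followed by ReLU.
    println!("{:?}", relu(&residual(&x, f)));
    // Mirrors `Downsample`: both paths transform the input before the sum.
    println!("{:?}", generalized_residual(&x, f, r));
}
```

Downsample needs the generalized form because its main path changes the channel count and image size, so the untouched input could no longer be added to the result.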

Now, the difference seems to be just the number of BasicBlocks. ResNet-18 has 4 groups with 2 blocks in each group. But ResNet-34 seems to have 4 groups, with 3 in the first, then 4, then 6, then finally 3.

I still wasn't convinced, and I didn't trust my basic counting skills, so I went searching for how the fastai library defines the resnet34 model. That led me to this code in PyTorch's torchvision library. And sure enough, there are those numbers again!

def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)

def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)

def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)

4 groups of 2, and [3, 4, 6, 3]. It even says BasicBlock! So it looks like the dfdx creators knew what they were doing. Who'd have thought?
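Those numbers also explain the model names, at least under the usual counting convention, which I'm taking as given here: every BasicBlock contributes two convolutions, and the head convolution plus the final linear layer add two more. The 1x1 shortcut convolutions in Downsample aren't counted. A quick sketch to check the arithmetic, with `weighted_layers` being a name I made up:

```rust
/// Count weighted layers: 2 convolutions per BasicBlock, plus the head
/// convolution and the final linear layer.
fn weighted_layers(blocks_per_group: &[usize]) -> usize {
    let blocks: usize = blocks_per_group.iter().sum();
    2 * blocks + 2
}

fn main() {
    println!("resnet18: {}", weighted_layers(&[2, 2, 2, 2]));
    println!("resnet34: {}", weighted_layers(&[3, 4, 6, 3]));
}
```

ResNet-50 has the same [3, 4, 6, 3] grouping but uses Bottleneck instead of BasicBlock, so this particular formula doesn't apply to it.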

So, it seems I just need to modify the structure of those middle tuples to have 3, 4, 6, and 3 blocks. Let's see what we can do.

// Layer clusters are in groups of [3, 4, 6, 3]
pub type Resnet34<const NUM_CLASSES: usize> = (
    Head,
    (
        BasicBlock<64>,
        ReLU,
        BasicBlock<64>,
        ReLU,
        BasicBlock<64>,
        ReLU,
    ),
    (
        Downsample<64, 128>,
        ReLU,
        BasicBlock<128>,
        ReLU,
        BasicBlock<128>,
        ReLU,
        BasicBlock<128>,
        ReLU,
    ),
    (
        Downsample<128, 256>,
        ReLU,
        BasicBlock<256>,
        ReLU,
        BasicBlock<256>,
        ReLU,
        BasicBlock<256>,
        ReLU,
        BasicBlock<256>,
        ReLU,
        BasicBlock<256>,
        ReLU,
    ),
    (
        Downsample<256, 512>,
        ReLU,
        BasicBlock<512>,
        ReLU,
        BasicBlock<512>,
        ReLU,
    ),
    Tail<NUM_CLASSES>,
);


I can't tell for sure, but this looks reasonable. Let's call it good for now.

You may have noticed that I included the ResNet-50 definition earlier. I wanted it there to point out that I have no idea what Bottleneck does. At this point, I think it's fair to say it doesn't matter for Chapter 1; we only need the ResNet-34 model.

What can we do about the weights

Now, Hugging Face has the weights we need. From what I've been reading, the safetensors format looks like the fastest and safest weight format, and it just so happens to be supported by dfdx; we just need to enable the safetensors feature flag. Easy enough. Let's modify the top-level Cargo.toml file.

dfdx = { version = "0.13", features = ["safetensors"] }

Now I've realized that I gave my Url enum a silly name. What kind of Url? Let's change the name to DatasetUrl, since it is used for downloading the dataset. I just did a "Rename variable" operation in my IDE, and it took care of every place the name was used.

So let's create a new enum in tardyai/src/download.rs

const HF_BASE: &'static str = "https://huggingface.co/";

#[derive(Debug, Clone, Copy)]
pub(crate) enum ModelUrl {
    Resnet18,
    Resnet34,
}

impl ModelUrl {
    pub(crate) fn url(self) -> String {
        match self {
            ModelUrl::Resnet18 => {
                format!("{HF_BASE}microsoft/resnet-18/resolve/main/model.safetensors?download=true")
            }
            ModelUrl::Resnet34 => {
                format!("{HF_BASE}microsoft/resnet-34/resolve/main/model.safetensors?download=true")
            }
        }
    }
}

And now I want to create a new wrapper type that can hold onto a model, and provide methods for it, like download_model(). This will allow ease of use at the original call site. The format I'm aiming for is the following.

    let mut model = Resnet34Model::<1000, f32>::build(dev);
    model.download_model()?;

Now dfdx seems to have a very complicated type system, so this is going to look really awful. I'll go over the worst bits.

pub struct Resnet34Model<const NUM_CLASSES: usize, E>
where
    E: Dtype,
    Resnet34<NUM_CLASSES>: BuildOnDevice<AutoDevice, E>,
    AutoDevice: Device<E>,
{
    model: <Resnet34<NUM_CLASSES> as BuildOnDevice<AutoDevice, E>>::Built,
}

impl<E, const N: usize> Resnet34Model<N, E>
where
    E: Dtype,
    AutoDevice: Device<E>,
    Resnet34<N>: BuildOnDevice<AutoDevice, E>,
{
    pub fn build(dev: AutoDevice) -> Self {
        let model = dev.build_module::<Resnet34<N>, E>();
        Self { model }
    }
}

The most important part here is the line:

    model: <Resnet34<NUM_CLASSES> as BuildOnDevice<AutoDevice, E>>::Built,

This line captures the difference between the type that we specified and the concrete type that gets built by the AutoDevice::build_module() method. We can't just store a field with type Resnet34<NUM_CLASSES>, because that isn't usable directly. In particular, it doesn't have any notion of the datatype that will be used for the model, whether it is f32 or f64. So we have to specify that Resnet34<NUM_CLASSES> implements the BuildOnDevice<> trait, and use the associated type, BuildOnDevice::Built.

We ensure that our type does implement BuildOnDevice with this line:

    Resnet34<NUM_CLASSES>: BuildOnDevice<AutoDevice, E>,

We next need to ensure that the device supports the datatype we are using with the line:

    AutoDevice: Device<E>,

The build() method now just takes in the Device we create in main.rs, and constructs the model with the AutoDevice::build_module method.
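If the associated-type pattern is unfamiliar, here is a stripped-down sketch of it in plain Rust, without dfdx. Everything in it (`Build`, `LinearSpec`, `BuiltLinear`, `Model`) is a made-up stand-in: a zero-sized "spec" type describes the shape, and a trait with an associated type names the concrete thing that gets built for a given element type.

```rust
/// Stand-in for `BuildOnDevice`: a spec type that knows how to build a concrete value.
trait Build<E> {
    type Built;
    fn build() -> Self::Built;
}

/// The "spec": it carries the shape in its type but holds no data.
struct LinearSpec<const I: usize, const O: usize>;

/// The concrete type: it actually stores weights of element type `E`.
struct BuiltLinear<E, const I: usize, const O: usize> {
    weights: Vec<E>,
}

impl<E: Default + Clone, const I: usize, const O: usize> Build<E> for LinearSpec<I, O> {
    type Built = BuiltLinear<E, I, O>;

    fn build() -> Self::Built {
        BuiltLinear { weights: vec![E::default(); I * O] }
    }
}

/// A wrapper that stores the built type, named through the associated type.
struct Model<E>
where
    LinearSpec<512, 2>: Build<E>,
{
    inner: <LinearSpec<512, 2> as Build<E>>::Built,
}

impl<E> Model<E>
where
    LinearSpec<512, 2>: Build<E>,
{
    fn new() -> Self {
        Self { inner: <LinearSpec<512, 2> as Build<E>>::build() }
    }
}

fn main() {
    // The element type is only chosen here, at the call site.
    let model = Model::<f32>::new();
    println!("{} weights", model.inner.weights.len());
}
```

The `where` clause does the same job as in the real code: without it, the compiler has no reason to believe the spec can be built for an arbitrary `E`, so the associated type can't be named.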

First source of frustration

Now, this looks like it should work. We've got a model that is relatively concise, and looks very similar to the ResNet-18 model.

So, why when I build this do I get the following horrendous error message?

error[E0277]: the trait bound `((dfdx::prelude::Conv2D<3, 64, 7, 2, 3>, 
dfdx::prelude::BatchNorm2D<64>, ReLU, MaxPool2D<3, 2, 1>), 
(dfdx::prelude::Residual<(dfdx::prelude::Conv2D<64, 64, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<64>, ReLU, dfdx::prelude::Conv2D<64, 64, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<64>)>, ReLU, dfdx::prelude::Residual<(dfdx::prelude::Conv2D<64, 
64, 3, 1, 1>, dfdx::prelude::BatchNorm2D<64>, ReLU, dfdx::prelude::Conv2D<64, 64, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<64>)>, ReLU, dfdx::prelude::Residual<(dfdx::prelude::Conv2D<64, 
64, 3, 1, 1>, dfdx::prelude::BatchNorm2D<64>, ReLU, dfdx::prelude::Conv2D<64, 64, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<64>)>, ReLU), (GeneralizedResidual<(dfdx::prelude::Conv2D<64, 
128, 3, 2, 1>, dfdx::prelude::BatchNorm2D<128>, ReLU, dfdx::prelude::Conv2D<128, 128, 3, 1, 
1>, dfdx::prelude::BatchNorm2D<128>), (dfdx::prelude::Conv2D<64, 128, 1, 2>, 
dfdx::prelude::BatchNorm2D<128>)>, ReLU, 
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<128, 128, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<128>, ReLU, dfdx::prelude::Conv2D<128, 128, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<128>)>, ReLU, 
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<128, 128, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<128>, ReLU, dfdx::prelude::Conv2D<128, 128, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<128>)>, ReLU, 
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<128, 128, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<128>, ReLU, dfdx::prelude::Conv2D<128, 128, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<128>)>, ReLU), (GeneralizedResidual<(dfdx::prelude::Conv2D<128, 
256, 3, 2, 1>, dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 
1>, dfdx::prelude::BatchNorm2D<256>), (dfdx::prelude::Conv2D<128, 256, 1, 2>, 
dfdx::prelude::BatchNorm2D<256>)>, ReLU, 
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<256, 256, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<256>)>, ReLU, 
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<256, 256, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<256>)>, ReLU, 
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<256, 256, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<256>)>, ReLU, 
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<256, 256, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<256>)>, ReLU, 
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<256, 256, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<256>)>, ReLU), (GeneralizedResidual<(dfdx::prelude::Conv2D<256, 
512, 3, 2, 1>, dfdx::prelude::BatchNorm2D<512>, ReLU, dfdx::prelude::Conv2D<512, 512, 3, 1, 
1>, dfdx::prelude::BatchNorm2D<512>), (dfdx::prelude::Conv2D<256, 512, 1, 2>, 
dfdx::prelude::BatchNorm2D<512>)>, ReLU, 
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<512, 512, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<512>, ReLU, dfdx::prelude::Conv2D<512, 512, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<512>)>, ReLU, 
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<512, 512, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<512>, ReLU, dfdx::prelude::Conv2D<512, 512, 3, 1, 1>, 
dfdx::prelude::BatchNorm2D<512>)>, ReLU), (AvgPoolGlobal, dfdx::prelude::Linear<512, _>)): 
BuildOnDevice<Cpu, _>` is not satisfied

Well our first clue to figure this monstrosity out comes from the little section at the end.

BuildOnDevice<Cpu, _>` is not satisfied

So, that means our giant model type, Resnet34, which just so happens to expand out to that awful tuple above, doesn't implement BuildOnDevice. And in the previous section we required it to do exactly that.

The next clue comes a few lines down from that:

    = help: the following other types implement trait `BuildOnDevice<D, E>`:
              ()
              (M1,)
              (M1, M2)
              (M1, M2, M3)
              (M1, M2, M3, M4)
              (M1, M2, M3, M4, M5)
              (M1, M2, M3, M4, M5, M6)

Aha! Evidently tuples of models only implement BuildOnDevice for sizes up to 6-tuples. We have an 8-tuple and a 12-tuple! So it looks like we just need to split our too-large tuples into smaller, bite-sized pieces.
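My understanding of why such a limit exists: Rust has no way to implement a trait for tuples of every length at once, so libraries typically use a macro to stamp out one impl per tuple size up to some cutoff. The sketch below is not dfdx's code, just the general pattern, with a made-up `CountLayers` trait that only goes up to 3-tuples. It also shows why nesting gets around the cutoff.

```rust
/// Made-up trait for illustration: how many leaf layers does a type contain?
trait CountLayers {
    fn count() -> usize;
}

struct Conv;
struct Relu;

impl CountLayers for Conv {
    fn count() -> usize {
        1
    }
}

impl CountLayers for Relu {
    fn count() -> usize {
        1
    }
}

/// Stamp out one impl per tuple size. Any tuple larger than 3 gets no impl.
macro_rules! impl_count_for_tuple {
    ($($name:ident),+) => {
        impl<$($name: CountLayers),+> CountLayers for ($($name,)+) {
            fn count() -> usize {
                0 $(+ $name::count())+
            }
        }
    };
}

impl_count_for_tuple!(A);
impl_count_for_tuple!(A, B);
impl_count_for_tuple!(A, B, C);

// A flat 4-tuple such as `(Conv, Relu, Conv, Relu)` has no impl here, so using it
// would fail with a "trait bound is not satisfied" error.
// Nesting two 2-tuples inside a 2-tuple works, because every level is small enough.
type Nested = ((Conv, Relu), (Conv, Relu));

fn main() {
    println!("{}", <Nested as CountLayers>::count());
}
```

Each inner tuple is itself a type that implements the trait, so the outer tuple only ever sees two members, however many layers are buried inside them.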

Let's do just that. Here is the new model now.

pub type Resnet34<const NUM_CLASSES: usize> = (
    Head,
    (
        BasicBlock<64>,
        ReLU,
        BasicBlock<64>,
        ReLU,
        BasicBlock<64>,
        ReLU,
    ),
    (
        // tuples are only supported with up to 6 items in `dfdx`
        (Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
        (BasicBlock<128>, ReLU, BasicBlock<128>, ReLU),
    ),
    (
        // tuples are only supported with up to 6 items in `dfdx`
        (
            Downsample<128, 256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
        ),
        (
            BasicBlock<256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
            BasicBlock<256>,
            ReLU,
        ),
    ),
    (
        Downsample<256, 512>,
        ReLU,
        BasicBlock<512>,
        ReLU,
        BasicBlock<512>,
        ReLU,
    ),
    Tail<NUM_CLASSES>,
);

And that builds! Excellent, it was a simple fix.

Moving on to downloading models

Now that we have a model that can be concretely represented, and the code builds, we need to add some code to download the model files from Hugging Face.

I'm going to go ahead and refactor tardyai/src/download.rs while I'm here, so we can reuse our old download logic.

// v--- refactor out this logic, so it's shorter in the other functions.
fn get_home_dir() -> Result<PathBuf, Error> {
    let home = homedir::get_my_home()?
        .expect("home directory needs to exist")
        .join(".tardyai");
    Ok(home)
}

pub fn untar_images(url: DatasetUrl) -> Result<PathBuf, Error> {
    let home = get_home_dir()?;
    let dest_dir = home.join("archive");
    ensure_dir(&dest_dir)?;
    let archive_file = download_file(url.url(), &dest_dir, None)?;

    let dest_dir = home.join("data");
    let dir = extract_archive(&archive_file, &dest_dir)?;

    Ok(dir)
}

// v--- Add a crate public function that will download from a `ModelUrl`
pub(crate) fn download_model(url: ModelUrl) -> Result<PathBuf, Error> {
    let home = get_home_dir()?;
    let dest_dir = home.join("models");
    ensure_dir(&dest_dir)?;
    let model_file = download_file(url.url(), &dest_dir, Some(&format!("{url:?}.safetensors")))?;
    Ok(model_file)
}

// v--- Change the name to something more generic
fn download_file(
    url: String,
    dest_dir: &Path,
    // v--- This was needed because the filenames we download from Hugging Face 
    //      are pretty ugly looking strings of hex digits.
    default_name: Option<&str>,
) -> Result<PathBuf, Error> {
    let mut response = reqwest::blocking::get(&url)?;
    let file_name = default_name
        .or(response.url().path_segments().and_then(|s| s.last()))
        .and_then(|name| if name.is_empty() { None } else { Some(name) })
        //     v--- Add a new `Error` variant
        .ok_or(Error::DownloadNameNotSpecified(url.clone()))?;

    let downloaded_file = dest_dir.join(file_name);

    // TODO: check if the archive is valid and exists
    if downloaded_file.exists() {
        log::info!("File already exists: {}", downloaded_file.display());
        return Ok(downloaded_file);
    }

    log::info!("Downloading {} to: {}", &url, downloaded_file.display());
    let mut dest = File::create(&downloaded_file)?;
    response.copy_to(&mut dest)?;
    Ok(downloaded_file)
}
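The file-name selection in download_file is the only subtle part, so here is that fallback logic on its own as a small sketch. `pick_file_name` is a hypothetical helper, the paths in `main` are invented, and it works on a bare path string rather than a parsed URL, so it ignores things like query strings.

```rust
/// Pick the file name to save under: prefer the explicit name, otherwise fall
/// back to the last path segment, and reject empty names.
fn pick_file_name<'a>(explicit: Option<&'a str>, url_path: &'a str) -> Option<&'a str> {
    explicit
        .or_else(|| url_path.rsplit('/').next())
        .filter(|name| !name.is_empty())
}

fn main() {
    // An explicit name wins over whatever the path ends with.
    let from_default = pick_file_name(Some("Resnet34.safetensors"), "/some/path/0123abcd");
    // Without one, the last path segment is used.
    let from_url = pick_file_name(None, "/some/path/oxford-iiit-pet.tgz");
    // A trailing slash leaves nothing usable, which maps to the error case.
    let missing = pick_file_name(None, "/some/path/");
    println!("{from_default:?} {from_url:?} {missing:?}");
}
```

Using `filter` here does the same job as the `and_then` with an `if` in the real function; it is just a slightly shorter way to drop empty names.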

And that's done, so let's create the download_model method on our concrete model type. We can even make it call load_safetensors() while we're at it.

impl<E, const N: usize> Resnet34Model<N, E>
where
    E: Dtype + dfdx::tensor::safetensors::SafeDtype,
    AutoDevice: Device<E>,
    Resnet34<N>: BuildOnDevice<AutoDevice, E>,
{
    // ...

    pub fn download_model(&mut self) -> Result<(), Error> {
        log::info!("Downloading model from {}", ModelUrl::Resnet34.url());
        let model_file = download_model(ModelUrl::Resnet34)?;
        self.model.load_safetensors(&model_file)?;
        Ok(())
    }
}

So that's it, we're done!

Not so fast

Ah, it builds fine now, but I'm now getting this error when it runs.

➜   cargo run
   Compiling tardyai v0.1.0 (/home/klah/git/articles/fastai-rust/tardyai/tardyai)
   Compiling chapter1 v0.1.0 (/home/klah/git/articles/fastai-rust/tardyai/chapter1)
    Finished dev [unoptimized + debuginfo] target(s) in 4.37s
     Running `target/debug/chapter1`
[2023-11-22T00:01:42Z INFO  tardyai::download] File already exists: /home/klah/.tardyai/archive/oxford-iiit-pet.tgz
[2023-11-22T00:01:42Z INFO  tardyai::download] Extracting archive /home/klah/.tardyai/archive/oxford-iiit-pet.tgz to: /home/klah/.tardyai/data
[2023-11-22T00:01:42Z INFO  tardyai::download] Archive already extracted to: /home/klah/.tardyai/data/oxford-iiit-pet/
[2023-11-22T00:01:42Z INFO  chapter1] Images are in: /home/klah/.tardyai/data/oxford-iiit-pet/images
[2023-11-22T00:01:42Z INFO  chapter1] Found 7390 files
[2023-11-22T00:01:48Z INFO  tardyai::models::resnet] Downloading model from https://huggingface.co/microsoft/resnet-34/resolve/main/model.safetensors?download=true
[2023-11-22T00:01:49Z INFO  tardyai::download] Downloading https://huggingface.co/microsoft/resnet-34/resolve/main/model.safetensors?download=true to: /home/klah/.tardyai/models/Resnet34.safetensors
Error:
   0: Error with safetensors file: SafeTensorError(TensorNotFound("0.0.weight"))

Location:
   chapter1/src/main.rs:38

Backtrace omitted. Run with RUST_BACKTRACE=1 environment variable to display it.
Run with RUST_BACKTRACE=full to include source snippets.

Well, dang it. Evidently safetensors files have names for the layers that they are storing the weights for. I guess I'm going to have to figure out what this file actually contains, and load them individually into our model.

dfdx supports safetensors with the safetensors crate. So I'll add that dependency and let's get to debugging.

This page of the safetensors docs mentions using the memmap2 crate, so I'll go ahead and add that as well.

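// At the top of the file:
use std::fs::File;
use memmap2::MmapOptions;
use safetensors::SafeTensors;

    // Inside the `impl` block:
    pub fn download_model(&mut self) -> Result<(), Error> {
        log::info!("Downloading model from {}", ModelUrl::Resnet34.url());
        let model_file = download_model(ModelUrl::Resnet34)?;
        // Temporarily skip loading so we can inspect the file instead.
        // self.model.load_safetensors(&model_file)?;

        let file = File::open(model_file).unwrap();
        // Memory-map the file; `unsafe` because the file could change underneath us.
        let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
        let tensors = SafeTensors::deserialize(&buffer).unwrap();
        // Sort by name so related tensors are listed together.
        let mut names = tensors.tensors();
        names.sort_by_key(|t| t.0.clone());
        for (name, tensor) in names {
            log::info!("Name: {name}: {:?}", tensor.shape());
        }
        Ok(())
    }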

Running this code gives us the following output.

[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: classifier.1.bias: [1000]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: classifier.1.weight: [1000, 512]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.embedder.embedder.convolution.weight: [64, 3, 7, 7]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.embedder.embedder.normalization.bias: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.embedder.embedder.normalization.num_batches_tracked: []
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.embedder.embedder.normalization.running_mean: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.embedder.embedder.normalization.running_var: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.embedder.embedder.normalization.weight: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.convolution.weight: [64, 64, 3, 3]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.normalization.bias: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.normalization.num_batches_tracked: []
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.normalization.running_mean: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.normalization.running_var: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.normalization.weight: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.convolution.weight: [64, 64, 3, 3]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.normalization.bias: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.normalization.num_batches_tracked: []
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.normalization.running_mean: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.normalization.running_var: [64]
[2023-11-22T00:25:47Z INFO  tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.normalization.weight: [64]
...

I've cut off most of the output, because there is an awful lot of it.

Well, the shape of resnet.embedder.embedder.convolution.weight is very similar to the convolution in Head:

type Head = (
    Conv2D<3, 64, 7, 2, 3>,
    BatchNorm2D<64>,
    ReLU,
    MaxPool2D<3, 2, 1>,
);

So I bet they correspond, and this all looks fairly structured. The stages index runs from 0 to 3, so there are 4 stages, just like we have 4 tuples of BasicBlocks.

Stage 2 has the first layer key running up to 5, so it has 6 layers, which is the number of BasicBlocks in the third tuple.
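To double-check that count without scrolling through the log, a small sketch like this could tally the blocks per stage straight from the tensor names. It is plain string handling with no safetensors dependency. `blocks_per_stage` is a hypothetical helper, and the names in `main` are a hand-picked sample that follows the naming pattern in the log, not the full contents of the file.

```rust
use std::collections::BTreeMap;

/// Parse names like `resnet.encoder.stages.2.layers.5.layer.0.convolution.weight`
/// and return, for each stage index, the highest `layers.N` index plus one.
fn blocks_per_stage<'a>(names: impl IntoIterator<Item = &'a str>) -> BTreeMap<usize, usize> {
    let mut counts: BTreeMap<usize, usize> = BTreeMap::new();
    for name in names {
        let parts: Vec<&str> = name.split('.').collect();
        // Look for the pattern: "stages", <stage>, "layers", <layer>.
        for w in parts.windows(4) {
            if w[0] == "stages" && w[2] == "layers" {
                if let (Ok(stage), Ok(layer)) = (w[1].parse::<usize>(), w[3].parse::<usize>()) {
                    let entry = counts.entry(stage).or_insert(0);
                    *entry = (*entry).max(layer + 1);
                }
            }
        }
    }
    counts
}

fn main() {
    // Sample names only; names without a stage, like the classifier, are ignored.
    let sample = [
        "classifier.1.bias",
        "resnet.encoder.stages.0.layers.0.layer.0.convolution.weight",
        "resnet.encoder.stages.2.layers.5.layer.0.convolution.weight",
    ];
    for (stage, blocks) in blocks_per_stage(sample) {
        println!("stage {stage}: at least {blocks} block(s)");
    }
}
```

Fed the complete list of names from the real file, I would expect this to report the same 3, 4, 6, 3 grouping as the model definition, though I haven't wired it up to the actual file here.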

So I think we have the pattern pretty well mapped out as to which tensor goes where. Now we need to figure out how to actually load the tensors into the various weights, ideally without specifying each tuple entry manually.

This will probably be done with some form of TensorVisitor. But this article is getting pretty long, so let's save that for next time.

Conclusion

In Part 4, we were able to construct the model and download the weights in the form of a safetensors file from Hugging Face. But we ran into an issue with actually loading the weights into the model, because the tensors aren't named the way dfdx expects. Check out the code for this part on GitHub, or check out the article-4 tag.

git checkout article-4

Stay tuned for Part 5 where we figure out how to solve this conundrum.
