Introduction
In Part 3, we covered creating Tensors from our images and loading them onto the device we're performing our matrix multiplication on.
In this part, I want to go over actually constructing the ResNet-34 model that the fast.ai book uses.
This is more difficult than it seems, because in addition to loading the weights from the internet, we also need to figure out how to cut off the last few layers and add on some new layers that will give us our new categories.
The original model was trained on 1000 different categories, but in this first chapter we're just determining whether an image is a cat or not, so our output Tensor is size 2.
So join me on this journey of discovery.
Let's build the model
First off, we need to enable support for convolutions. A convolution takes a group of pixels, typically a square around a central pixel, and convolves them into a single value. This is a fun, relatively deep concept in math. I found this 3blue1brown video to be very approachable and entertaining.
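To make that concrete, here is a tiny, self-contained sketch of one convolution step: we multiply a 3x3 kernel element-wise with the 3x3 patch of pixels around a central pixel and sum the products. The kernel and patch values below are made up purely for illustration.
// One convolution step: element-wise multiply a 3x3 kernel with the 3x3
// patch of pixels around a central pixel, then sum the products.
fn convolve_3x3(patch: [[f32; 3]; 3], kernel: [[f32; 3]; 3]) -> f32 {
    let mut sum = 0.0;
    for y in 0..3 {
        for x in 0..3 {
            sum += patch[y][x] * kernel[y][x];
        }
    }
    sum
}

fn main() {
    // Made-up values: an edge-detection style kernel and an arbitrary patch.
    let kernel = [[-1.0, -1.0, -1.0], [-1.0, 8.0, -1.0], [-1.0, -1.0, -1.0]];
    let patch = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]];
    println!("convolved value: {}", convolve_3x3(patch, kernel));
}
A full convolution layer just slides this operation across every pixel position of the image.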
Convolutions in dfdx are, unfortunately, only available on the nightly Rust compiler. So first things first, let's enable the nightly channel by adding a rust-toolchain.toml file to the project root.
[toolchain]
channel = "nightly"
Now I'm going to cheat a little. The lovely creators of dfdx left an example of the ResNet-18 structure in their repo. This isn't exactly what I need for ResNet-34, but it's pretty close, and it gives us a great jumping-off point.
So let's get started by stealing that code. I'm going to make a simple change and add a new type definition for the Tail of the structure. I plan to use this to change the shape of the final layers for re-training the model on different classes.
use dfdx::prelude::*;
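// Conv2D's const generic parameters are <IN_CHAN, OUT_CHAN, KERNEL_SIZE, STRIDE, PADDING>.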
type BasicBlock<const C: usize> = Residual<(
Conv2D<C, C, 3, 1, 1>,
BatchNorm2D<C>,
ReLU,
Conv2D<C, C, 3, 1, 1>,
BatchNorm2D<C>,
)>;
type Downsample<const C: usize, const D: usize> = GeneralizedResidual<
(
Conv2D<C, D, 3, 2, 1>,
BatchNorm2D<D>,
ReLU,
Conv2D<D, D, 3, 1, 1>,
BatchNorm2D<D>,
),
(Conv2D<C, D, 1, 2, 0>, BatchNorm2D<D>),
>;
type Head = (
Conv2D<3, 64, 7, 2, 3>,
BatchNorm2D<64>,
ReLU,
MaxPool2D<3, 2, 1>,
);
pub type Tail<const NUM_CLASSES: usize> = (AvgPoolGlobal, Linear<512, NUM_CLASSES>);
pub type Resnet18<const NUM_CLASSES: usize> = (
Head,
(BasicBlock<64>, ReLU, BasicBlock<64>, ReLU),
(Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
(Downsample<128, 256>, ReLU, BasicBlock<256>, ReLU),
(Downsample<256, 512>, ReLU, BasicBlock<512>, ReLU),
Tail<NUM_CLASSES>,
);
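As an aside, since NUM_CLASSES is a const generic, re-targeting the network for our cat/not-cat task should be as simple as a type alias. This is just a sketch of my own; the CatClassifier name is a hypothetical invention for illustration.
// The pretrained network was built for 1000 ImageNet classes; for the
// cat/not-cat task we only need 2 outputs. Hypothetical aliases:
type PretrainedResnet18 = Resnet18<1000>;
type CatClassifier = Resnet18<2>;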
Now, if you're anything like me, this looks like a bunch of incomprehensible nonsense. But luckily, we don't have to worry about why all these things exist; we just need to understand the basic structure.
For understanding what everything does, I found that Chapter 14 of the fast.ai book contains a wealth of information about the exact structure of a ResNet model. But that doesn't help us much right now; we're only on Chapter 1!
So, in my search to understand the broad strokes of this structure, and in particular how to change it to ResNet-34 and even larger models, I went looking for the weights for the ResNet models and landed on the Hugging Face website. This is an awesome resource, and it definitely has the weights we need, but it also had a diagram of the structure of the layers of this model in particular. In the model card of the ResNet-34 model, I found this diagram.
Well, I was able to connect the dots, and I noticed the Residual names in the ResNet-18 types. The top diagram in that image almost exactly describes the BasicBlock and Downsample parts of this code. The 4 tuples between Head and Tail line up almost perfectly with those colorful blocks.
So BasicBlock corresponds to the first image. It's a pair of convolution layers, evidently separated by a ReLU layer, whatever that is. And Downsample must correspond to the second image with the dashed line. It looks like it takes the 64s to 128s, and that looks like what this line is doing:
(Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
Now, the difference seems to be just the number of BasicBlocks. ResNet-18 has 4 groups with 2 blocks in each group, but ResNet-34 seems to have 4 groups with 3 in the first, then 4, then 6, and finally 3.
I still wasn't convinced, and I didn't trust my basic counting skills, so I went searching for how the fastai library defines the resnet34 model. That led me to this code in the pytorch library. And sure enough, there are those numbers again!
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
4 groups of 2, and [3, 4, 6, 3]. It even says BasicBlock! So it looks like the dfdx creators knew what they were doing. Who'd have thought?
So, it seems I just need to modify the structure of those middle tuples to have 3, 4, 6, and 3 layers. Let's see what we can do.
// Layer clusters are in groups of [3, 4, 6, 3]
pub type Resnet34<const NUM_CLASSES: usize> = (
Head,
(
BasicBlock<64>,
ReLU,
BasicBlock<64>,
ReLU,
BasicBlock<64>,
ReLU,
),
(
Downsample<64, 128>,
ReLU,
BasicBlock<128>,
ReLU,
BasicBlock<128>,
ReLU,
BasicBlock<128>,
ReLU,
),
(
Downsample<128, 256>,
ReLU,
BasicBlock<256>,
ReLU,
BasicBlock<256>,
ReLU,
BasicBlock<256>,
ReLU,
BasicBlock<256>,
ReLU,
BasicBlock<256>,
ReLU,
),
(
Downsample<256, 512>,
ReLU,
BasicBlock<512>,
ReLU,
BasicBlock<512>,
ReLU,
),
Tail<NUM_CLASSES>,
);
I can't tell yet whether it's right, but it looks reasonable. Let's call it good for now.
You may have noticed that I included the ResNet-50 definition earlier. I wanted it there to point out that I have no idea what Bottleneck does, and at this point I think it's fair to say it doesn't matter: for Chapter 1, we only need the ResNet-34 model.
What can we do about the weights?
Now, Hugging Face has the weights we need. I've been reading up, and it looks like the safetensors format is the fastest and safest weight format. It just so happens to be supported by dfdx; we just need to enable the safetensors feature flag. Easy enough. Let's modify the top-level Cargo.toml file.
dfdx = { version = "0.13", features = ["safetensors"] }
Now, I've realized that I gave my Url enum a silly name. What kind of Url? Well, let's change the name to DatasetUrl, since it is used for downloading the dataset. I just did a "Rename variable" operation in my IDE, and it took care of all the locations it was used for me.
So let's create a new enum in tardyai/src/download.rs.
const HF_BASE: &str = "https://huggingface.co/";
#[derive(Debug, Clone, Copy)]
pub(crate) enum ModelUrl {
Resnet18,
Resnet34,
}
impl ModelUrl {
pub(crate) fn url(self) -> String {
match self {
ModelUrl::Resnet18 => {
format!("{HF_BASE}microsoft/resnet-18/resolve/main/model.safetensors?download=true")
}
ModelUrl::Resnet34 => {
format!("{HF_BASE}microsoft/resnet-34/resolve/main/model.safetensors?download=true")
}
}
}
}
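As a quick sanity check on the URL construction, here's a small test I might add. It's just a sketch of my own, not something the project currently contains.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resnet34_url_points_at_the_right_repo() {
        // The url() method splices the repo path onto the Hugging Face base.
        let url = ModelUrl::Resnet34.url();
        assert!(url.starts_with("https://huggingface.co/microsoft/resnet-34/"));
        assert!(url.ends_with("model.safetensors?download=true"));
    }
}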
And now I want to create a new wrapper type that can hold onto a model and provide methods for it, like download_model(). This will allow ease of use at the original call site. The format I'm aiming for is the following.
let mut model = Resnet34Model::<1000, f32>::build(dev);
model.download_model()?;
Now, dfdx seems to have a very complicated type system, so this is going to look really awful. I'll go over the worst bits.
pub struct Resnet34Model<const NUM_CLASSES: usize, E>
where
E: Dtype,
Resnet34<NUM_CLASSES>: BuildOnDevice<AutoDevice, E>,
AutoDevice: Device<E>,
{
model: <Resnet34<NUM_CLASSES> as BuildOnDevice<AutoDevice, E>>::Built,
}
impl<E, const N: usize> Resnet34Model<N, E>
where
E: Dtype,
AutoDevice: Device<E>,
Resnet34<N>: BuildOnDevice<AutoDevice, E>,
{
pub fn build(dev: AutoDevice) -> Self {
let model = dev.build_module::<Resnet34<N>, E>();
Self { model }
}
}
The most important part here is the line:
model: <Resnet34<NUM_CLASSES> as BuildOnDevice<AutoDevice, E>>::Built,
This line bridges the gap between the type we specified and the concrete type that gets built by the AutoDevice::build_module() method. We can't just store a field of type Resnet34<NUM_CLASSES>; that type isn't usable directly. In particular, it doesn't have any notion of the datatype that will be used for the model, whether that's f32 or f64. So we have to require that Resnet34<NUM_CLASSES> implements the BuildOnDevice trait, and use its associated type, BuildOnDevice::Built.
We ensure that our type does implement BuildOnDevice with this line:
Resnet34<NUM_CLASSES>: BuildOnDevice<AutoDevice, E>,
We next need to ensure that the device supports the datatype we are using with the line:
AutoDevice: Device<E>,
The build() method now just takes in the Device we create in main.rs, and constructs the model with the AutoDevice::build_module() method.
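To see why the Dtype parameter E is worth carrying around, the same network type can be built for different float widths. This is a hypothetical sketch of my own; I'm assuming the device is constructed with AutoDevice::default() and supports both dtypes.
// Hypothetical call sites: the same architecture built with two dtypes.
let dev = AutoDevice::default();
let model_f32 = Resnet34Model::<1000, f32>::build(dev.clone());
let model_f64 = Resnet34Model::<1000, f64>::build(dev);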
First source of frustration
Now, this looks like it should work. We've got a model that is relatively concise, and looks very similar to the ResNet-18 model.
So why, when I build this, do I get the following horrendous error message?
error[E0277]: the trait bound `((dfdx::prelude::Conv2D<3, 64, 7, 2, 3>,
dfdx::prelude::BatchNorm2D<64>, ReLU, MaxPool2D<3, 2, 1>),
(dfdx::prelude::Residual<(dfdx::prelude::Conv2D<64, 64, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<64>, ReLU, dfdx::prelude::Conv2D<64, 64, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<64>)>, ReLU, dfdx::prelude::Residual<(dfdx::prelude::Conv2D<64,
64, 3, 1, 1>, dfdx::prelude::BatchNorm2D<64>, ReLU, dfdx::prelude::Conv2D<64, 64, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<64>)>, ReLU, dfdx::prelude::Residual<(dfdx::prelude::Conv2D<64,
64, 3, 1, 1>, dfdx::prelude::BatchNorm2D<64>, ReLU, dfdx::prelude::Conv2D<64, 64, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<64>)>, ReLU), (GeneralizedResidual<(dfdx::prelude::Conv2D<64,
128, 3, 2, 1>, dfdx::prelude::BatchNorm2D<128>, ReLU, dfdx::prelude::Conv2D<128, 128, 3, 1,
1>, dfdx::prelude::BatchNorm2D<128>), (dfdx::prelude::Conv2D<64, 128, 1, 2>,
dfdx::prelude::BatchNorm2D<128>)>, ReLU,
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<128, 128, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<128>, ReLU, dfdx::prelude::Conv2D<128, 128, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<128>)>, ReLU,
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<128, 128, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<128>, ReLU, dfdx::prelude::Conv2D<128, 128, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<128>)>, ReLU,
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<128, 128, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<128>, ReLU, dfdx::prelude::Conv2D<128, 128, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<128>)>, ReLU), (GeneralizedResidual<(dfdx::prelude::Conv2D<128,
256, 3, 2, 1>, dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1,
1>, dfdx::prelude::BatchNorm2D<256>), (dfdx::prelude::Conv2D<128, 256, 1, 2>,
dfdx::prelude::BatchNorm2D<256>)>, ReLU,
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<256, 256, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<256>)>, ReLU,
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<256, 256, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<256>)>, ReLU,
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<256, 256, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<256>)>, ReLU,
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<256, 256, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<256>)>, ReLU,
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<256, 256, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<256>, ReLU, dfdx::prelude::Conv2D<256, 256, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<256>)>, ReLU), (GeneralizedResidual<(dfdx::prelude::Conv2D<256,
512, 3, 2, 1>, dfdx::prelude::BatchNorm2D<512>, ReLU, dfdx::prelude::Conv2D<512, 512, 3, 1,
1>, dfdx::prelude::BatchNorm2D<512>), (dfdx::prelude::Conv2D<256, 512, 1, 2>,
dfdx::prelude::BatchNorm2D<512>)>, ReLU,
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<512, 512, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<512>, ReLU, dfdx::prelude::Conv2D<512, 512, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<512>)>, ReLU,
dfdx::prelude::Residual<(dfdx::prelude::Conv2D<512, 512, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<512>, ReLU, dfdx::prelude::Conv2D<512, 512, 3, 1, 1>,
dfdx::prelude::BatchNorm2D<512>)>, ReLU), (AvgPoolGlobal, dfdx::prelude::Linear<512, _>)):
BuildOnDevice<Cpu, _>` is not satisfied
Well, our first clue to figuring this monstrosity out comes from the little section at the end.
BuildOnDevice<Cpu, _>` is not satisfied
So, that means our giant model type Resnet34, which just so happens to expand out to that awful tuple above, doesn't implement BuildOnDevice. Well, in the previous section we stated that we needed it to do just that.
The next clue comes a few lines further down:
= help: the following other types implement trait `BuildOnDevice<D, E>`:
()
(M1,)
(M1, M2)
(M1, M2, M3)
(M1, M2, M3, M4)
(M1, M2, M3, M4, M5)
(M1, M2, M3, M4, M5, M6)
Aha! Evidently tuples of models only implement BuildOnDevice for sizes up to 6-tuples. We have an 8-tuple and a 12-tuple! So it looks like we just need to split our too-large tuples into smaller, bite-sized pieces.
Let's do just that. Here is the new model now.
pub type Resnet34<const NUM_CLASSES: usize> = (
Head,
(
BasicBlock<64>,
ReLU,
BasicBlock<64>,
ReLU,
BasicBlock<64>,
ReLU,
),
(
// tuples are only supported with up to 6 items in `dfdx`
(Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
(BasicBlock<128>, ReLU, BasicBlock<128>, ReLU),
),
(
// tuples are only supported with up to 6 items in `dfdx`
(
Downsample<128, 256>,
ReLU,
BasicBlock<256>,
ReLU,
BasicBlock<256>,
ReLU,
),
(
BasicBlock<256>,
ReLU,
BasicBlock<256>,
ReLU,
BasicBlock<256>,
ReLU,
),
),
(
Downsample<256, 512>,
ReLU,
BasicBlock<512>,
ReLU,
BasicBlock<512>,
ReLU,
),
Tail<NUM_CLASSES>,
);
And that builds! Excellent, it was a simple fix. Nesting the tuples shouldn't change the computation, either: as far as I can tell, tuples in dfdx apply their elements in sequence, so ((A, B), (C, D)) computes the same function as (A, B, C, D).
Moving on to downloading models
Now that we have a model that can be concretely represented, and the code builds, we need to add some code to download the model files from Hugging Face.
I'm going to go ahead and refactor tardyai/src/download.rs while I'm here, so we can reuse our old download logic.
// v--- refactor out this logic, so it's shorter in the other functions.
fn get_home_dir() -> Result<PathBuf, Error> {
let home = homedir::get_my_home()?
.expect("home directory needs to exist")
.join(".tardyai");
Ok(home)
}
pub fn untar_images(url: DatasetUrl) -> Result<PathBuf, Error> {
let home = get_home_dir()?;
let dest_dir = home.join("archive");
ensure_dir(&dest_dir)?;
let archive_file = download_file(url.url(), &dest_dir, None)?;
let dest_dir = home.join("data");
let dir = extract_archive(&archive_file, &dest_dir)?;
Ok(dir)
}
// v--- Add a crate public function that will download from a `ModelUrl`
pub(crate) fn download_model(url: ModelUrl) -> Result<PathBuf, Error> {
let home = get_home_dir()?;
let dest_dir = home.join("models");
ensure_dir(&dest_dir)?;
let model_file = download_file(url.url(), &dest_dir, Some(&format!("{url:?}.safetensors")))?;
Ok(model_file)
}
// v--- Change the name to something more generic
fn download_file(
url: String,
dest_dir: &Path,
// v--- This was needed because the filenames we download from Hugging Face
// are pretty ugly looking strings of hex digits.
default_name: Option<&str>,
) -> Result<PathBuf, Error> {
let mut response = reqwest::blocking::get(&url)?;
let file_name = default_name
.or(response.url().path_segments().and_then(|s| s.last()))
.and_then(|name| if name.is_empty() { None } else { Some(name) })
// v--- Add a new `Error` variant
.ok_or(Error::DownloadNameNotSpecified(url.clone()))?;
let downloaded_file = dest_dir.join(file_name);
// TODO: check if the archive is valid and exists
if downloaded_file.exists() {
log::info!("File already exists: {}", downloaded_file.display());
return Ok(downloaded_file);
}
log::info!("Downloading {} to: {}", &url, downloaded_file.display());
let mut dest = File::create(&downloaded_file)?;
response.copy_to(&mut dest)?;
Ok(downloaded_file)
}
And that's done, so let's create the download_model() method on our concrete model type. We can even make it call load_safetensors() while we're at it.
impl<E, const N: usize> Resnet34Model<N, E>
where
E: Dtype + dfdx::tensor::safetensors::SafeDtype,
AutoDevice: Device<E>,
Resnet34<N>: BuildOnDevice<AutoDevice, E>,
{
// ...
pub fn download_model(&mut self) -> Result<(), Error> {
log::info!("Downloading model from {}", ModelUrl::Resnet34.url());
let model_file = download_model(ModelUrl::Resnet34)?;
self.model.load_safetensors(&model_file)?;
Ok(())
}
}
So that's it, we're done!
Not so fast
Ah, it builds fine now, but I'm getting this error when it runs.
➜ cargo run
Compiling tardyai v0.1.0 (/home/klah/git/articles/fastai-rust/tardyai/tardyai)
Compiling chapter1 v0.1.0 (/home/klah/git/articles/fastai-rust/tardyai/chapter1)
Finished dev [unoptimized + debuginfo] target(s) in 4.37s
Running `target/debug/chapter1`
[2023-11-22T00:01:42Z INFO tardyai::download] File already exists: /home/klah/.tardyai/archive/oxford-iiit-pet.tgz
[2023-11-22T00:01:42Z INFO tardyai::download] Extracting archive /home/klah/.tardyai/archive/oxford-iiit-pet.tgz to: /home/klah/.tardyai/data
[2023-11-22T00:01:42Z INFO tardyai::download] Archive already extracted to: /home/klah/.tardyai/data/oxford-iiit-pet/
[2023-11-22T00:01:42Z INFO chapter1] Images are in: /home/klah/.tardyai/data/oxford-iiit-pet/images
[2023-11-22T00:01:42Z INFO chapter1] Found 7390 files
[2023-11-22T00:01:48Z INFO tardyai::models::resnet] Downloading model from https://huggingface.co/microsoft/resnet-34/resolve/main/model.safetensors?download=true
[2023-11-22T00:01:49Z INFO tardyai::download] Downloading https://huggingface.co/microsoft/resnet-34/resolve/main/model.safetensors?download=true to: /home/klah/.tardyai/models/Resnet34.safetensors
Error:
0: Error with safetensors file: SafeTensorError(TensorNotFound("0.0.weight"))
Location:
chapter1/src/main.rs:38
Backtrace omitted. Run with RUST_BACKTRACE=1 environment variable to display it.
Run with RUST_BACKTRACE=full to include source snippets.
Well, dang it. Evidently safetensors files have names for the layers whose weights they store. I guess I'm going to have to figure out what this file actually contains, and load the tensors individually into our model.
dfdx supports safetensors via the safetensors crate. So I'll add that dependency directly, and let's get to debugging.
This page of the safetensors docs mentions using the memmap2 crate, so I'll go ahead and add that as well.
pub fn download_model(&mut self) -> Result<(), Error> {
log::info!("Downloading model from {}", ModelUrl::Resnet34.url());
let model_file = download_model(ModelUrl::Resnet34)?;
// self.model.load_safetensors(&model_file)?;
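// Memory-map the downloaded file and deserialize the safetensors header so
// we can inspect what the file actually contains. (This assumes `File`,
// `MmapOptions`, and `SafeTensors` are imported from std::fs, memmap2, and
// safetensors respectively.)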
let file = File::open(model_file).unwrap();
let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
let tensors = SafeTensors::deserialize(&buffer).unwrap();
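// Collect the (name, view) pairs and sort by name so the log output is
// stable and easy to scan.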
let mut names = tensors.tensors();
names.sort_by_key(|t| t.0.clone());
for (name, tensor) in names {
log::info!("Name: {name}: {:?}", tensor.shape());
}
Ok(())
}
Running this code gives us the following output.
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: classifier.1.bias: [1000]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: classifier.1.weight: [1000, 512]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.embedder.embedder.convolution.weight: [64, 3, 7, 7]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.embedder.embedder.normalization.bias: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.embedder.embedder.normalization.num_batches_tracked: []
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.embedder.embedder.normalization.running_mean: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.embedder.embedder.normalization.running_var: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.embedder.embedder.normalization.weight: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.convolution.weight: [64, 64, 3, 3]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.normalization.bias: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.normalization.num_batches_tracked: []
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.normalization.running_mean: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.normalization.running_var: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.0.normalization.weight: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.convolution.weight: [64, 64, 3, 3]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.normalization.bias: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.normalization.num_batches_tracked: []
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.normalization.running_mean: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.normalization.running_var: [64]
[2023-11-22T00:25:47Z INFO tardyai::models::resnet] Name: resnet.encoder.stages.0.layers.0.layer.1.normalization.weight: [64]
...
I've cut off most of the output, because there is an awful lot of it.
Well, the shape of resnet.embedder.embedder.convolution.weight is very similar to the convolution in Head:
type Head = (
Conv2D<3, 64, 7, 2, 3>,
BatchNorm2D<64>,
ReLU,
MaxPool2D<3, 2, 1>,
);
So I bet they correlate, and this all looks fairly structured. stages runs from 0 to 3, so there are 4 stages, just like we have 4 tuples of BasicBlocks.
Stage 2 has layers keys running from 0 up to 5, so it has 6 layers, which is the number of BasicBlocks in the third tuple.
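To spot-check that reading, here's a little sketch that generates the convolution weight names I'd expect from the [3, 4, 6, 3] block counts. The helper is my own invention based purely on the names in the log above, and it deliberately ignores the downsample shortcut convolutions, whose names we haven't seen yet in the truncated output.
// Generate the expected Hugging Face tensor names for the main-path
// convolutions, based on the pattern observed in the log output.
fn hf_conv_weight(stage: usize, layer: usize, conv: usize) -> String {
    format!("resnet.encoder.stages.{stage}.layers.{layer}.layer.{conv}.convolution.weight")
}

fn main() {
    let blocks_per_stage: [usize; 4] = [3, 4, 6, 3];
    for (stage, &blocks) in blocks_per_stage.iter().enumerate() {
        for layer in 0..blocks {
            // Each BasicBlock has two convolutions: layer.0 and layer.1.
            for conv in 0..2 {
                println!("{}", hf_conv_weight(stage, layer, conv));
            }
        }
    }
}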
So I think we have the pattern of which tensor goes where pretty well mapped out. Now we need to figure out how to actually load the tensors into the various weights, ideally without specifying each tuple entry manually.
This will probably be done with some form of TensorVisitor. But this article is getting pretty long, so let's save that for next time.
Conclusion
In Part 4, we were able to construct the model and download the weights, in the form of a safetensors file, from Hugging Face. But we ran into an issue actually loading the weights into the model, because they aren't named the way dfdx expects. Check out the code for this part on GitHub, or check out the article-4 tag.
git checkout article-4
Stay tuned for Part 5 where we figure out how to solve this conundrum.