Introduction
In Part 4, we managed to construct the model and download the weights from Hugging Face.
In this part, I think I've figured out an ugly way to load those weights into the model. Then let's try some classification with these weights and see whether it gives anything sensible at all.
Some cleanup first
So I'm going to admit that, without the changes coming to dfdx in the GitHub main branch, this is going to be remarkably hacky.
But before we get to that, I did some cleanup. While reading the fast.ai book, I stumbled across some terms that we were using incorrectly in our model. So, in the interest of being true to the book, I've renamed Head and Tail to Stem and Head. In particular, the head of a model is actually the last part, the bit that performs the classification task. I figured the old names might confuse readers coming from more of an ML background.
type Stem = (
Conv2D<3, 64, 7, 2, 3>,
BatchNorm2D<64>,
ReLU,
MaxPool2D<3, 2, 1>,
);
pub type Head<const NUM_CLASSES: usize> = (AvgPoolGlobal, Linear<512, NUM_CLASSES>);
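To get a feel for what the Stem does to an image, here's a small sketch of the usual output-size formula for a convolution or pooling window, (input + 2 × padding − kernel) / stride + 1 with integer division, applied to Conv2D<3, 64, 7, 2, 3> and MaxPool2D<3, 2, 1>. It assumes a 224×224 input (the common ImageNet size, which may not match our Pets images) and, for the later stages, that each Downsample block halves the resolution with a 3×3, stride-2, padding-1 convolution as in the standard ResNet. The function names are hypothetical helpers, not part of dfdx or the project.

```rust
/// Output size of a convolution or pooling window along one spatial axis:
/// (input + 2 * padding - kernel) / stride + 1, using integer division.
pub fn window_out(input: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

/// Spatial size after the Stem: Conv2D<3, 64, 7, 2, 3> followed by MaxPool2D<3, 2, 1>.
pub fn stem_out(input: usize) -> usize {
    let after_conv = window_out(input, 7, 2, 3);
    window_out(after_conv, 3, 2, 1)
}

/// Spatial size at the output of each of the four stages.
/// Assumption: stages 2-4 open with a Downsample that uses a 3x3, stride-2,
/// padding-1 convolution; stage 1 keeps the resolution it gets from the Stem.
pub fn stage_output_sizes(input: usize) -> [usize; 4] {
    let mut size = stem_out(input);
    let mut sizes = [0; 4];
    for (i, slot) in sizes.iter_mut().enumerate() {
        if i > 0 {
            size = window_out(size, 3, 2, 1);
        }
        *slot = size;
    }
    sizes
}
```

Under those assumptions a 224×224 image leaves the Stem as a 56×56 map and reaches the Head as a 7×7 map with 512 channels, which AvgPoolGlobal averages down to the 512 inputs that Linear<512, NUM_CLASSES> expects.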
I've also split Resnet18 and Resnet34 into two parts: the body, and the Head we just discussed. This will hopefully make it easier to do transfer learning and change the shape of the head later down the road. I've only copied the Resnet34 version here, since it's the one we're working on.
// Layer clusters are in groups of [3, 4, 6, 3]
pub type Resnet34Body = (
Stem,
(
BasicBlock<64>,
ReLU,
BasicBlock<64>,
ReLU,
BasicBlock<64>,
ReLU,
),
(
(Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
(BasicBlock<128>, ReLU, BasicBlock<128>, ReLU),
),
(
(
Downsample<128, 256>,
ReLU,
BasicBlock<256>,
ReLU,
BasicBlock<256>,
ReLU,
),
(
BasicBlock<256>,
ReLU,
BasicBlock<256>,
ReLU,
BasicBlock<256>,
ReLU,
),
),
(
Downsample<256, 512>,
ReLU,
BasicBlock<512>,
ReLU,
BasicBlock<512>,
ReLU,
),
);
pub type Resnet34<const NUM_CLASSES: usize> = (Resnet34Body, Head<NUM_CLASSES>);
There, that looks better.
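As a quick check on that [3, 4, 6, 3] layout: the "34" in ResNet-34 counts the layers with weights, meaning the stem convolution, the two 3×3 convolutions in each of the 16 residual blocks (counting each Downsample as a block), and the final Linear. The 1×1 shortcut convolutions inside the Downsample blocks aren't counted. Here's a sketch of that arithmetic; the constants and function are illustrative, not part of the project.

```rust
/// Residual blocks per stage (each stage's Downsample counts as one block).
pub const RESNET18_STAGES: [usize; 4] = [2, 2, 2, 2];
pub const RESNET34_STAGES: [usize; 4] = [3, 4, 6, 3];

/// Weighted layers in a BasicBlock-based ResNet:
/// 1 stem conv + 2 convs per block + 1 final Linear.
/// Shortcut (projection) convolutions are not counted, by convention.
pub fn weighted_layers(stages: &[usize]) -> usize {
    let blocks: usize = stages.iter().sum();
    1 + 2 * blocks + 1
}
```

With [2, 2, 2, 2] the same formula gives 18, which is where Resnet18 gets its name.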
What about those weights?
Now we can start getting to the ugly stuff. The dfdx interface for visiting and updating each tensor in a model doesn't expose the model's tree structure. It just iterates through the tensors one at a time and passes each one to the handler object you give it.
So I came up with the naive idea of creating a giant array of all the names in the weights file, manually sorted into the order in which our model's tensors are iterated. It's really not pretty, but it "works". (The scare quotes are there because it's almost certainly wrong somehow, and we'll have to wait until we try to classify something to see whether these are sane values.)
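Since the visitor consumes exactly one name per tensor or scalar, one cheap sanity check is whether the array has the right length. Here's a sketch that counts the expected entries from the architecture, assuming each convolution contributes only a weight (no bias), each BatchNorm2D is visited as four tensors plus its two scalars, each Downsample adds a shortcut convolution and batch norm on top of its two main-path pairs, and the head's Linear has a weight and a bias. The constants and function name are hypothetical.

```rust
const CONV: usize = 1; // weight only, no bias
const BATCH_NORM: usize = 4 + 2; // weight, bias, running_mean, running_var + epsilon, momentum
const CONV_BN: usize = CONV + BATCH_NORM;
const BASIC_BLOCK: usize = 2 * CONV_BN; // two conv/batch-norm pairs
const DOWNSAMPLE: usize = 3 * CONV_BN; // two main-path pairs plus the shortcut pair
const LINEAR: usize = 2; // weight and bias

/// Expected number of names, given the number of blocks in each stage.
/// Every stage except the first is assumed to start with a Downsample block.
pub fn expected_entries(stages: &[usize]) -> usize {
    let blocks: usize = stages.iter().sum();
    let downsamples = stages.len().saturating_sub(1);
    CONV_BN + downsamples * DOWNSAMPLE + (blocks - downsamples) * BASIC_BLOCK + LINEAR
}
```

For the [3, 4, 6, 3] layout that is 7 for the stem, 21 for each of the three Downsample blocks, 14 for each of the other 13 blocks, and 2 for the head: 254 in total.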
I'm not going to paste the whole array, because it's more than 200 lines and that feels like a giant waste of space. If you want to look at it, feel free to do so here.
pub(crate) const RESNET34_LAYERS: [&'static str; 254] = [
"resnet.embedder.embedder.convolution.weight",
"resnet.embedder.embedder.normalization.weight",
"resnet.embedder.embedder.normalization.bias",
"resnet.embedder.embedder.normalization.running_mean",
"resnet.embedder.embedder.normalization.running_var",
"resnet.embedder.embedder.normalization.num_batches_tracked",
"resnet.embedder.embedder.normalization.num_batches_tracked",
// [SNIP]
];
This snippet includes a part I wanted to point out. The dfdx BatchNorm2D type contains two scalar values, epsilon and momentum, but the safetensors file downloaded from Hugging Face does not. It does have one scalar per normalization layer, called num_batches_tracked. So, in the interest of expediency, I just listed that name twice. This is almost certainly not the correct thing to do, since it means epsilon and momentum both get overwritten with the value of num_batches_tracked, but the values I saw while debugging were all zeros, so I'm hoping it doesn't matter in the long run.
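If you'd rather not hand-write those duplicates, a small helper could expand a normalization prefix into the six names in the order the visitor expects: four tensors, then num_batches_tracked twice to stand in for epsilon and momentum. This is a hypothetical helper, not something in the repository, and it bakes in the same questionable duplication.

```rust
/// Expand a batch-norm prefix into the six names visited for one dfdx BatchNorm2D:
/// four tensors, then num_batches_tracked twice (standing in for epsilon and momentum).
pub fn batch_norm_names(prefix: &str) -> Vec<String> {
    [
        "weight",
        "bias",
        "running_mean",
        "running_var",
        "num_batches_tracked",
        "num_batches_tracked",
    ]
    .iter()
    .map(|suffix| format!("{prefix}.{suffix}"))
    .collect()
}
```

Generating the list this way at least guarantees every normalization layer gets the same six-name pattern, though it does nothing about whether the duplicated value is sensible.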
Now let's look at the struct that takes that list and applies the tensors to the model.
pub(crate) struct NamedTensorVisitor<'a> {
names: &'a [&'static str],
idx: usize,
tensors: &'a SafeTensors<'a>,
}
impl<'a> NamedTensorVisitor<'a> {
pub(crate) fn new(names: &'a [&'static str], tensors: &'a SafeTensors<'a>) -> Self {
Self {
names,
idx: 0,
tensors,
}
}
}
This struct is dead simple: every time we visit a tensor or scalar, we use the name at the current index to fetch the weights from the SafeTensors field, then increment the index.
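Stripped of the dfdx traits, the core idea is just a cursor over a list of names. Here's a std-only sketch of that pattern, where a HashMap stands in for the SafeTensors file and a slice of f32 for each tensor. NameCursor, LoadError, and the map are stand-ins for illustration, not dfdx or safetensors APIs; the length check is an addition the original visitor doesn't make.

```rust
use std::collections::HashMap;

#[derive(Debug, PartialEq)]
pub enum LoadError {
    NotEnoughNames,
    MissingTensor(String),
    ShapeMismatch { name: String, expected: usize, found: usize },
}

/// Walks a fixed list of names, one per visited tensor.
pub struct NameCursor<'a> {
    names: &'a [&'static str],
    idx: usize,
    tensors: &'a HashMap<String, Vec<f32>>,
}

impl<'a> NameCursor<'a> {
    pub fn new(names: &'a [&'static str], tensors: &'a HashMap<String, Vec<f32>>) -> Self {
        Self { names, idx: 0, tensors }
    }

    /// Fill `dst` from the tensor stored under the current name, then advance.
    pub fn visit(&mut self, dst: &mut [f32]) -> Result<(), LoadError> {
        let name = *self.names.get(self.idx).ok_or(LoadError::NotEnoughNames)?;
        let src = self
            .tensors
            .get(name)
            .ok_or_else(|| LoadError::MissingTensor(name.to_string()))?;
        // Refuse to load data of the wrong size instead of silently misaligning
        if src.len() != dst.len() {
            return Err(LoadError::ShapeMismatch {
                name: name.to_string(),
                expected: dst.len(),
                found: src.len(),
            });
        }
        dst.copy_from_slice(src);
        self.idx += 1;
        Ok(())
    }
}
```

A size check like this is cheap insurance: if the name list drifts out of order, you'll often hit a mismatch right away instead of loading the wrong weights, although two tensors of the same size (a batch norm's weight and bias, say) would still slip through.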
Implementing the TensorVisitor trait requires adding the num-traits crate to our package, because we need to reference the NumCast trait in the definition.
impl<'a, E: Dtype + SafeDtype, D: Device<E>> TensorVisitor<E, D> for NamedTensorVisitor<'a> {
type Viewer = ViewTensorMut;
type Err = Error;
type E2 = E;
type D2 = D;
fn visit<S: Shape>(
&mut self,
_opts: TensorOptions<S, E, D>,
t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>>,
) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err> {
log::debug!(
"Loading tensor shape: {:?}, {:?}",
t.shape(),
&self.names.get(self.idx)
);
// Load the tensor stored under the current name straight into the model's tensor
t.load_safetensor(
self.tensors,
self.names.get(self.idx).ok_or(Error::NotEnoughNames)?,
)?;
self.idx += 1;
Ok(None)
}
fn visit_scalar<N: NumCast>(
&mut self,
_opts: ScalarOptions<N>,
n: <Self::Viewer as TensorViewer>::View<'_, N>,
) -> Result<Option<N>, Self::Err> {
log::debug!("Loading scalar: {:?}", &self.names.get(self.idx));
let tensor = self
.tensors
.tensor(self.names.get(self.idx).ok_or(Error::NotEnoughNames)?)?;
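// Scalars are expected as 8 little-endian f64 bytes;
// copy_from_slice panics if the stored data isn't exactly 8 bytes long
let data = tensor.data();
let mut array = [0; 8];
array.copy_from_slice(data);
let val = f64::from_le_bytes(array);
// Convert the f64 into the scalar's own type (f32 for our model)
*n = N::from(val).ok_or(Error::NumberFormatException)?;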
self.idx += 1;
Ok(None)
}
}
This trait is a little complicated, so we'll start by going over the associated types.
Viewer is set to ViewTensorMut, a marker type meaning we receive a &mut Tensor<> in the visit method.
Err is just our crate's Error enum.
E2 and D2 are there in case you want to transform the tensors, either by changing their datatype or by moving them to a different device. Since we always return None, these could be almost anything, but I'll keep them the same as the inputs.
visit() is pretty straightforward: load the tensor for the current name and advance the index by one.
visit_scalar() was a more interesting one, and I copied it from the dfdx crate's implementation. That implementation expects scalars in SafeTensors files to be stored as little-endian f64 values, but the only way we can get at the data is as raw bytes, so we copy them into a [u8; 8] and convert that to an f64. Finally, we convert the f64 to N (f32 in our case), store it in the scalar, and advance the index by one.
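Here's a std-only sketch of that byte juggling with a fallible conversion instead of a panic: copy_from_slice panics when the slice isn't exactly eight bytes, while <[u8; 8]>::try_from returns an error. One caveat worth knowing: PyTorch's BatchNorm keeps num_batches_tracked as a 64-bit integer, so, assuming the Hugging Face file preserved that dtype, those eight bytes are really an i64. Zero is all zero bytes in both formats, which is why reading it as an f64 works here, but a nonzero count would decode to a tiny denormal. Both function names are hypothetical.

```rust
use std::convert::TryFrom;

/// Decode a little-endian f64 scalar from raw tensor bytes.
/// Returns None (instead of panicking like copy_from_slice) when the
/// slice is not exactly 8 bytes long.
pub fn decode_f64_scalar(data: &[u8]) -> Option<f64> {
    let bytes = <[u8; 8]>::try_from(data).ok()?;
    Some(f64::from_le_bytes(bytes))
}

/// Decode the same bytes as a little-endian i64, which is how a PyTorch
/// num_batches_tracked value would be laid out.
pub fn decode_i64_scalar(data: &[u8]) -> Option<i64> {
    let bytes = <[u8; 8]>::try_from(data).ok()?;
    Some(i64::from_le_bytes(bytes))
}
```

Returning None lets the caller map a malformed scalar to one of its own error variants, and checking the stored dtype before decoding would catch the integer case as well.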
Now we just have to use this struct to load the tensors in the download_model method we created in the previous part.
pub fn download_model(&mut self) -> Result<(), Error> {
log::info!("Downloading model from {}", ModelUrl::Resnet34.url());
let model_file = download_model(ModelUrl::Resnet34)?;
// TODO: Somehow make something like this work
// self.model.load_safetensors(&model_file)?;
let file = File::open(model_file).unwrap();
let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
let tensors = SafeTensors::deserialize(&buffer).unwrap();
let _ = <<Resnet34<N> as BuildOnDevice<AutoDevice, E>>::Built as TensorCollection<
E,
AutoDevice,
>>::iter_tensors(&mut RecursiveWalker {
m: &mut self.model,
f: &mut NamedTensorVisitor::new(&RESNET34_LAYERS, &tensors),
})?;
Ok(())
}
That complicated generic type after we load the tensors from the file really just boils down to interpreting the Built type, which we store as our model instance, as a TensorCollection<>, and calling the iter_tensors() associated function on that type.
iter_tensors() takes a ModuleVisitor. I'm using the included RecursiveWalker, which recurses through the model and calls our visitor's visit() and visit_scalar() for each tensor and scalar it finds.
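RecursiveWalker's job is easiest to see on a toy version. In this std-only sketch, a hypothetical Params trait stands in for TensorCollection: leaf layers hand their parameters to a callback, parameter-free layers do nothing, and tuples recurse into their elements left to right. Nested tuples like Resnet34Body therefore flatten into one depth-first sequence, which is exactly the order our name list has to follow. None of these names are dfdx's real API.

```rust
/// Toy stand-in for a parameter collection (not dfdx's TensorCollection).
pub trait Params {
    fn for_each_param(&mut self, f: &mut dyn FnMut(&mut Vec<f32>));
}

/// A layer with one parameter tensor.
pub struct Weight(pub Vec<f32>);

/// A layer without parameters, like ReLU or a pooling layer.
pub struct NoParams;

impl Params for Weight {
    fn for_each_param(&mut self, f: &mut dyn FnMut(&mut Vec<f32>)) {
        f(&mut self.0);
    }
}

impl Params for NoParams {
    fn for_each_param(&mut self, _f: &mut dyn FnMut(&mut Vec<f32>)) {}
}

// Tuples visit their elements in order, recursing into nested tuples.
impl<A: Params, B: Params> Params for (A, B) {
    fn for_each_param(&mut self, f: &mut dyn FnMut(&mut Vec<f32>)) {
        self.0.for_each_param(f);
        self.1.for_each_param(f);
    }
}

impl<A: Params, B: Params, C: Params> Params for (A, B, C) {
    fn for_each_param(&mut self, f: &mut dyn FnMut(&mut Vec<f32>)) {
        self.0.for_each_param(f);
        self.1.for_each_param(f);
        self.2.for_each_param(f);
    }
}
```

This sketch only handles pairs and triples; dfdx clearly covers longer tuples too, since Resnet34Body nests tuples of up to six layers.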
Actually, let's go ahead and clean that complicated generic type up a bit.
type Resnet34Built<const NUM_CLASSES: usize, E> =
<Resnet34<NUM_CLASSES> as BuildOnDevice<AutoDevice, E>>::Built;
This will let us clean up our model type definition.
pub struct Resnet34Model<const NUM_CLASSES: usize, E>
where
E: Dtype + SafeDtype,
Resnet34<NUM_CLASSES>: BuildOnDevice<AutoDevice, E>,
AutoDevice: Device<E>,
{
pub model: Resnet34Built<NUM_CLASSES, E>,
}
And then actually clean up the ugly iter_tensors() call:
let _ = <Resnet34Built<N, E> as TensorCollection<E, AutoDevice>>::iter_tensors(
&mut RecursiveWalker {
m: &mut self.model,
f: &mut NamedTensorVisitor::new(&RESNET34_LAYERS, &tensors),
},
)?;
That looks better. And we finally have the weights loaded into the model!
What can we do with it?
Let's try using the model as-is to classify some of our Pets images, and see whether the tensor it returns correlates with the cat or dog classes in the ImageNet dataset.
Applying the model to an image just means calling forward() on the image's tensor. So let's do just that.
model.download_model()?;
let (image, is_cat) = dataset.get(1)?;
let categories = model.model.forward(image);
log::info!("Is Cat? {}", is_cat);
log::trace!("Categories: {:#?}", categories.array());
let max_category = categories
.softmax()
.array()
.into_iter()
.map(ordered_float::OrderedFloat)
.enumerate()
.max_by_key(|t| t.1)
.unwrap();
log::info!("(Category, Weight): {:?}", max_category);
This fetches the 2nd image in our dataset which, for me at least, is a cat. Then we apply the model to the tensor.
I've added a log statement for the categories tensor that is returned. This is a Rank1<1000> tensor that should have its largest value at the index corresponding to the category of the object in the input image.
So let's figure out which index that is. I apply softmax() to the categories, which takes the raw output of the categories tensor and rescales the values so that they are all positive and add up to 1, while preserving their order, so the largest value stays the largest. Unfortunately, f32 in Rust doesn't implement the Ord trait, so we can't just call max() on the values directly; we have to convert them to a type that does implement Ord. That is where the ordered-float crate comes into play. OrderedFloat gives floats a total ordering (it sorts NaN above every other value), so as long as there are no NaNs in our data, the maximum it finds is the real one.
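For comparison, here's a std-only sketch of the same two steps without the ordered-float dependency. f32::total_cmp (stable since Rust 1.62) provides a total order directly, and the softmax subtracts the maximum logit before exponentiating, which gives the same result but keeps exp() from overflowing. Because softmax never changes the order of the values, the winning index is the same whether you take it from the raw logits or from the probabilities. The function names are illustrative.

```rust
/// Numerically stable softmax: subtracting the max logit doesn't change the
/// result but keeps exp() from overflowing on large inputs.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Index and value of the largest element, ordered with f32::total_cmp.
pub fn argmax(values: &[f32]) -> Option<(usize, f32)> {
    values
        .iter()
        .copied()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(&b.1))
}
```

If all you need is the category index, you could skip the softmax entirely and take the argmax of the logits; the probabilities are only useful for judging how confident the model is.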
And now if we run this, we get (89, OrderedFloat(0.0099451095)). Hmm, if the weight of the category is supposed to correlate with how certain the model is, just under 1% doesn't inspire much confidence. Let's check what that category is supposed to represent.
I found this JSON file that is supposed to list the categories the original ResNet models were trained on. If we look in there for index 89, we get... sulphur-crested_cockatoo...
Well, that's a little frustrating. I don't think that's a cat. Something tells me our model may not have been set up correctly. I'll try to get it working between articles, but this has taken a long time, and we may as well just train the model from scratch to see if we even have anything worth saving.
My plan for the next article is to use the model in a learner with an optimizer, like the penultimate line of the Python code we're targeting:
learn = vision_learner(dls, resnet34, metrics=error_rate)
If I get this model working with transfer learning in the meantime, I'll switch it up and talk about how I managed it.
Anyway, until then, the code is available on GitHub as always, and if you're following along in your own editor, you can check out the article-5 tag.
git pull origin
git checkout article-5
Stay tuned, and I hope to see you in the next article with a learner in tow.