Favil Orbedios
fast.ai Book in Rust - Chapter 2 - Part 1

Introduction

In Part 6, we finished Chapter 1. We didn't get to do everything that we'd wanted to because my beefy desktop was put out of commission. That's still the case. Luckily, in this chapter, we don't need to focus too much on the training aspect.

This chapter focuses on defining the DataLoader classes and a Bing Image Search downloader that is provided with the fastai library. We're not going to implement a Bing downloader; that is too much work for something that could be a crate on its own. Please feel free to write such a crate, though; the world could use one.

One thing I'm noticing in this chapter, however, that I clearly forgot about in the last one, is splitting the dataset into different sets: the training set, the validation set, and finally the test set. They are all important, and they all have separate uses, so I think now is a good time to go through our code and refactor it to split out the different sets.

I will also take this opportunity to clean things up and make the code more generic.

Code Clean Up

The first step is moving tardyai/src/datasets.rs to a new module. Since this is specifically a category-based module that is categorizing the input somehow, I created the category module and moved the file to tardyai/src/category/datasets.rs.

The next module I created was for the labels: tardyai/src/category/encoders.rs, because it is for encoding a category into a one-hot tensor, like we did for is_cat.

pub type LabelFn<Category> = dyn Fn(&Path) -> Category;

This is the same function we were storing in our DirectoryImageDataset; I just changed it to use a generic return type instead of hard-coding it to bool. Technically this isn't as generic as it could be, since we are assuming Paths, but for the foreseeable future the input is going to be coming directly from the hard drive. I don't have any plans to fetch inputs from a database or directly from the internet yet.
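To illustrate the flexibility that the generic return type buys us, here is a stdlib-only sketch of a label function that returns a hypothetical multi-class enum instead of bool. The Animal enum and label_by_dir function are mine, not part of tardyai; the type alias is restated so the sketch is self-contained. Note the Default bound on the category, which the builder later relies on for its fallback label function.

```rust
use std::path::Path;

// Same shape as the alias above, restated so this sketch compiles on its own.
type LabelFn<Category> = dyn Fn(&Path) -> Category;

// A hypothetical three-way category, not part of the actual crate.
#[derive(Debug, PartialEq, Default)]
enum Animal {
    #[default]
    Unknown,
    Cat,
    Dog,
}

// Label by the parent directory name, e.g. "data/cat/1.jpg" -> Animal::Cat.
fn label_by_dir(path: &Path) -> Animal {
    match path
        .parent()
        .and_then(|p| p.file_name())
        .and_then(|n| n.to_str())
    {
        Some("cat") => Animal::Cat,
        Some("dog") => Animal::Dog,
        _ => Animal::Unknown,
    }
}

fn main() {
    // The fn item coerces to the trait-object alias, just like a closure would.
    let f: &LabelFn<Animal> = &label_by_dir;
    assert_eq!(f(Path::new("data/cat/1.jpg")), Animal::Cat);
    assert_eq!(f(Path::new("data/dog/2.jpg")), Animal::Dog);
    assert_eq!(f(Path::new("other.jpg")), Animal::Unknown);
    println!("labels ok");
}
```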

Now I want to be able to define a trait that will convert the Category into a one-hot encoded tensor. And since we are currently using bool, we can define the trait for bool while we're at it.

pub trait IntoOneHot<const N: usize>: Default {
    fn into_one_hot(&self, dev: &AutoDevice) -> Tensor<Rank1<N>, f32, AutoDevice>;
}

impl IntoOneHot<2> for bool {
    fn into_one_hot(&self, dev: &AutoDevice) -> Tensor<Rank1<2>, f32, AutoDevice> {
        let mut t = dev.zeros::<(Const<2>,)>();
        t[[0]] = !*self as usize as f32;
        t[[1]] = *self as usize as f32;
        t
    }
}
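Stripped of the dfdx types, the encoding logic is just two slots that mirror the bool: slot 0 holds "false" and slot 1 holds "true". Here is the same arithmetic as a dependency-free sketch over a plain Vec, so you can see what the tensor version is doing:

```rust
// One-hot encode a bool into a plain Vec<f32>, mirroring the dfdx impl above:
// index 0 is the "false" slot, index 1 is the "true" slot.
fn bool_one_hot(value: bool) -> Vec<f32> {
    let mut t = vec![0.0_f32; 2];
    t[0] = !value as usize as f32;
    t[1] = value as usize as f32;
    t
}

fn main() {
    assert_eq!(bool_one_hot(false), vec![1.0, 0.0]);
    assert_eq!(bool_one_hot(true), vec![0.0, 1.0]);
    println!("one-hot ok");
}
```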

Now we need to create the idea of a DataLoader. The data loader in fastai is able to split the input data into separate training, validation, and test datasets. Let's define a DirectoryImageDataLoader.

// We need to thread N and Category through our types
pub struct DirectoryImageDataLoader<'fun, const N: usize, Category> {
    training: DirectoryImageDataset<'fun, N, Category>,
    validation: DirectoryImageDataset<'fun, N, Category>,
    test: DirectoryImageDataset<'fun, N, Category>,
}

impl<'fun, const N: usize, Category: IntoOneHot<N>> 
    DirectoryImageDataLoader<'fun, N, Category> 
{
    // Builder pattern again, it's just so good
    pub fn builder(
        parent: impl AsRef<Path>,
        dev: AutoDevice,
    ) -> image_data_loader::Builder<'fun, N, Category> {
        image_data_loader::Builder::new(parent.as_ref().to_owned(), dev)
    }

    // Accessor methods
    pub fn training(&self) -> &DirectoryImageDataset<'fun, N, Category> {
        &self.training
    }

    pub fn validation(&self) -> &DirectoryImageDataset<'fun, N, Category> {
        &self.validation
    }

    pub fn test(&self) -> &DirectoryImageDataset<'fun, N, Category> {
        &self.test
    }
}

You should be able to see how I'm using a builder there; let's look at how it's made.

pub mod image_data_loader {
    // [snip use statements] 

    pub struct Builder<'fun, const N: usize, Category> {
        parent: PathBuf,
        dev: AutoDevice,
        splitter: Option<Box<dyn Splitter<PathBuf>>>,
        label_fn: Option<&'fun LabelFn<Category>>,
    }

    impl<'fun, const N: usize, Category: IntoOneHot<N>> Builder<'fun, N, Category> {
        pub fn new(parent: PathBuf, dev: AutoDevice) -> Self {
            Self {
                parent,
                dev,
                splitter: None,
                label_fn: None,
            }
        }

        // I'll talk about Splitter soon, I promise
        pub fn with_splitter(mut self, splitter: impl Splitter<PathBuf> + 'static) -> Self {
            self.splitter = Some(Box::new(splitter));
            self
        }

        pub fn with_label_fn(mut self, label_fn: &'fun LabelFn<Category>) -> Self {
            self.label_fn = Some(label_fn);
            self
        }

        pub fn build(self) -> Result<DirectoryImageDataLoader<'fun, N, Category>, Error> {
            let exts = image_extensions();

            let mut splitter = self
                .splitter
                .unwrap_or_else(|| Box::new(RatioSplitter::default()));
            // By default the label will return the default value regardless of input
            // This isn't useful, but I didn't want to make label_fn a required argument
            let label_fn = self.label_fn.unwrap_or(&|_| Default::default());

            // Walk the directory here, instead of the Dataset constructor
            // so we can split the datasets out
            let walker = WalkDir::new(self.parent).follow_links(true).into_iter();
            let files: Vec<_> = walker
                .filter_map(|entry| {
                    let entry = entry.ok()?;
                    entry
                        .path()
                        .extension()
                        .and_then(|ext| Some(exts.contains(ext.to_str()?)))?
                        .then_some(entry)
                })
                .map(|entry| entry.path().to_owned())
                .collect();
            let (training, validation, test) = splitter.split(files);
            let training = DirectoryImageDataset::new(
                &training, self.dev.clone(), label_fn
            )?;
            let validation = DirectoryImageDataset::new(
                &validation, self.dev.clone(), label_fn
            )?;
            let test = DirectoryImageDataset::new(&test, self.dev, label_fn)?;

            Ok(DirectoryImageDataLoader {
                training,
                validation,
                test,
            })
        }
    }
}
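The densest part of build() is the Option-combinator chain that keeps only files with a recognized image extension. Here is that same chain extracted over a plain Vec of paths, so it can be exercised without walkdir; the function name and the HashSet of extensions are mine, standing in for whatever image_extensions() actually returns:

```rust
use std::collections::HashSet;
use std::path::PathBuf;

// Keep only the paths whose extension appears in the allowed set. This is the
// same filter logic as the walker closure in build(), minus the WalkDir entry
// handling.
fn filter_by_extension(paths: Vec<PathBuf>, exts: &HashSet<&str>) -> Vec<PathBuf> {
    paths
        .into_iter()
        .filter(|p| {
            p.extension()
                .and_then(|e| e.to_str())
                .map(|e| exts.contains(e))
                .unwrap_or(false) // no extension, or non-UTF-8: skip it
        })
        .collect()
}

fn main() {
    let exts: HashSet<&str> = ["jpg", "png"].into_iter().collect();
    let paths = vec![
        PathBuf::from("a.jpg"),
        PathBuf::from("notes.txt"),
        PathBuf::from("b.png"),
        PathBuf::from("no_extension"),
    ];
    let kept = filter_by_extension(paths, &exts);
    assert_eq!(kept, vec![PathBuf::from("a.jpg"), PathBuf::from("b.png")]);
    println!("filtered ok");
}
```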

This walks the directory, runs the Splitter on the files, and then constructs the three datasets for the user.

Now let's talk about that Splitter, like I promised.

pub trait Splitter<T> {
    fn split(&mut self, files: Vec<T>) -> (Vec<T>, Vec<T>, Vec<T>);
}
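The trait is small enough that any strategy drops in with a few lines. As a toy illustration (not part of the crate), here is a splitter that sends everything to training, with the trait restated so the sketch stands alone:

```rust
// Restated from the article so this example is self-contained.
trait Splitter<T> {
    fn split(&mut self, files: Vec<T>) -> (Vec<T>, Vec<T>, Vec<T>);
}

// A hypothetical strategy: no validation or test split at all.
struct TrainingOnly;

impl<T> Splitter<T> for TrainingOnly {
    fn split(&mut self, files: Vec<T>) -> (Vec<T>, Vec<T>, Vec<T>) {
        (files, Vec::new(), Vec::new())
    }
}

fn main() {
    let mut splitter = TrainingOnly;
    let (training, validation, test) = splitter.split(vec![1, 2, 3]);
    assert_eq!(training.len(), 3);
    assert!(validation.is_empty() && test.is_empty());
    println!("split ok");
}
```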

I've gone ahead and made it generic over the input type, since that is super straightforward to do. The first Splitter implementation we'll create is the RatioSplitter. It takes ratios for the sizes of the validation and test datasets and randomly (but deterministically) divides the input list into the various datasets.

pub struct RatioSplitter {
    rng: rand::rngs::StdRng,
    validation: f32,
    test: f32,
}

impl RatioSplitter {
    pub fn with_seed_validation_test(seed: u64, validation: f32, test: f32) -> Self {
        assert!(validation + test < 1.0);
        assert!(validation >= 0.0);
        assert!(test >= 0.0);
        let rng = rand::SeedableRng::seed_from_u64(seed);
        Self {
            rng,
            validation,
            test,
        }
    }

    pub fn with_seed_validation(seed: u64, validation: f32) -> Self {
        Self::with_seed_validation_test(seed, validation, 0.0)
    }

    // By default I'm going to assume your test data is somewhere else.
    pub fn with_seed(seed: u64) -> Self {
        Self::with_seed_validation_test(seed, 0.2, 0.0)
    }
}

// Very secure default seed, but it's good for reproducibility, and it's what dfdx uses
impl Default for RatioSplitter {
    fn default() -> Self {
        Self::with_seed(0)
    }
}

//       v--- Ensure that our types are sortable, for deterministic ordering
impl<T: Ord> Splitter<T> for RatioSplitter {
    fn split(&mut self, mut files: Vec<T>) -> (Vec<T>, Vec<T>, Vec<T>) {
        // Sort the files for deterministic ordering
        files.sort();
        // Shuffle with our deterministic Rng
        files.shuffle(&mut self.rng);

        let validation = (files.len() as f32 * self.validation) as usize;
        let test = (files.len() as f32 * self.test) as usize;

        let validation: Vec<T> = files.drain(..validation).collect();
        let test: Vec<T> = files.drain(..test).collect();
        let training: Vec<T> = files;

        (training, validation, test)
    }
}
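The split arithmetic is worth seeing in isolation: each ratio is turned into a count via truncation, the front of the sorted list is drained twice, and whatever remains is training. This stdlib-only sketch keeps the sort but omits the seeded shuffle (rand isn't pulled in here), so it is the RatioSplitter's bookkeeping, not its randomness:

```rust
// Sketch of the RatioSplitter arithmetic: sort for a deterministic order,
// then drain validation and test off the front by ratio. The real version
// also shuffles with a seeded StdRng between the sort and the drains.
fn ratio_split<T: Ord>(
    mut files: Vec<T>,
    validation: f32,
    test: f32,
) -> (Vec<T>, Vec<T>, Vec<T>) {
    assert!(validation + test < 1.0 && validation >= 0.0 && test >= 0.0);
    files.sort();

    // Ratios become element counts by truncation, exactly as in split().
    let n_validation = (files.len() as f32 * validation) as usize;
    let n_test = (files.len() as f32 * test) as usize;

    let validation: Vec<T> = files.drain(..n_validation).collect();
    let test: Vec<T> = files.drain(..n_test).collect();
    (files, validation, test) // the leftovers are the training set
}

fn main() {
    let files: Vec<u32> = (0..10).collect();
    let (training, validation, test) = ratio_split(files, 0.2, 0.1);
    assert_eq!((training.len(), validation.len(), test.len()), (7, 2, 1));
    println!("7/2/1 split ok");
}
```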

Now we just need to change the DirectoryDataset, which I've also renamed to DirectoryImageDataset, since it only works on images.

pub struct DirectoryImageDataset<'fun, const N: usize, Category> {
    files: Vec<PathBuf>,
    dev: AutoDevice,
    label_fn: &'fun LabelFn<Category>,
    tensors: DashMap<PathBuf, Tensor<Rank3<3, 224, 224>, f32, AutoDevice>>,
}

impl<'fun, const N: usize, Category> DirectoryImageDataset<'fun, N, Category> {
    fn new(
        files: &[PathBuf],
        dev: AutoDevice,
        label_fn: &'fun LabelFn<Category>,
    ) -> Result<Self, Error> {
        Ok(Self {
            files: files.to_owned(),
            dev,
            label_fn,
            tensors: Default::default(),
        })
    }

    pub fn files(&self) -> &[PathBuf] {
        &self.files
    }
}

And finally we just need to update the call site in chapter1/src/main.rs.

    let is_cat = |path: &Path| {
        path.file_name()
            .and_then(|n| n.to_str())
            .and_then(|n| n.chars().next().map(|c| c.is_uppercase()))
            .unwrap_or(false)
    };

    let dataset_loader = DirectoryImageDataLoader::builder(path, dev.clone())
        .with_label_fn(&is_cat)
        .build()?;
    let dataset = dataset_loader.training();
    log::info!("Found {} files", dataset.files().len());
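That is_cat closure encodes the naming convention from chapter 1's pet images, where cat breeds are capitalized and dog breeds are lowercase. Here it is as a standalone function so the rule is easy to check on a couple of example file names (the breed names below are just illustrations):

```rust
use std::path::Path;

// Same labeling rule as the closure in chapter1/src/main.rs: an uppercase
// first character in the file name means the image is a cat.
fn is_cat(path: &Path) -> bool {
    path.file_name()
        .and_then(|n| n.to_str())
        .and_then(|n| n.chars().next().map(|c| c.is_uppercase()))
        .unwrap_or(false)
}

fn main() {
    assert!(is_cat(Path::new("images/Abyssinian_1.jpg")));
    assert!(!is_cat(Path::new("images/beagle_3.jpg")));
    assert!(!is_cat(Path::new("/"))); // no file name at all: not a cat
    println!("is_cat ok");
}
```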

This code just uses the default Splitter, so 20% of the files will go to the validation set and the remaining 80% to the training set (the default test ratio is 0.0, so the test set will be empty).

Conclusion

We've refactored out a DataLoader and cleaned up the interface to support labels other than bool. So that's cool.

As per usual, the code is on GitHub, or you can check out the tag chapter-2-1.

Tune in next time, when we'll go over visualizing the datasets like fastai does in Jupyter. I've found a neat way to show images directly in the terminal.
