DEV Community

Favil Orbedios

Working through the fast.ai book in Rust - Part 2

Introduction

In Part 1 we introduced the dfdx crate, but we didn't get into actually implementing any of the fast.ai book projects.

In Part 2 we are going to see how far we can get into chapter 1 of the book. Since this isn't Python, and we don't have the fastai library, we are going to have to do everything ourselves.

If you want to follow along, and don't have a copy of the book, you can read it online for free here.

In particular, this is what the book wants us to write:

from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'

def is_cat(x): return x[0].isupper()
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2, seed=42,
    label_func=is_cat, item_tfms=Resize(224))

learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)

As you can see, it isn't much code. But the reason I don't like it, and the reason I'm writing this series, is that it's just a bunch of magic. It gets you on your feet quickly by hiding all the fun parts behind its façade.

In this sample we can see that it:

  1. Automatically downloads and extracts the images to an images folder.
  2. Defines a label function.
  3. Automatically loads the images from the path with ImageDataLoaders.
  4. Constructs a learner from a ready-made resnet34 model with pretrained weights.
  5. Runs a training algorithm on it for a single cycle.

This is too much to cover in a single article, so for Part 2 I'm going to focus on step 1.

Creating a new Rust package

I realized while writing this article that the structure of my code needs refinement. So I'm going to throw away the old code and build a repo with several crates in a Rust workspace.

If you want to follow along, I've created a git repo called tardyai, where I will be committing all my code.

To fetch the starting tag from the repo, use the following command:

git clone --branch START_HERE https://github.com/favilo/tardyai.git

That will download the repo and put you at the same starting point as me. Specifically, it contains a Rust workspace with two member crates, tardyai and chapter1, both of which are the default packages created by cargo new.

tardyai will be a small, incomplete port of the fastai library. For now it won't run any code itself; it just contains the logic for downloading images.

Let's add URLs

It would be very nice if we could take the same URLs that are in the Python library and do the same thing in Rust.

I'm envisioning an interface similar to the following, which I'm adding to our chapter1/src/main.rs file:

use std::path::PathBuf;

fn main() {
    let path: PathBuf = tardyai::untar_images(tardyai::Url::Pets)
        .join("images");
}

Now, to make that a reality, let's edit tardyai/src/lib.rs:

use std::path::PathBuf;

pub enum Url {
    Pets,
}

pub fn untar_images(url: Url) -> PathBuf {
    todo!()
}

This just panics, but at least everything compiles.

From here we need to convert the enum variant Url::Pets to an actual URL. In the fastai library this is https://s3.amazonaws.com/fast-ai-imageclas/oxford-iiit-pet.tgz. So let's add a method to the Url type that returns the URL.

const S3_BASE: &str = "https://s3.amazonaws.com/fast-ai-";
const S3_IMAGE: &str = "imageclas/";

// v-- I decided that we need to derive some sane traits by default.
#[derive(Debug, Clone, Copy)]
pub enum Url {
    Pets,
}

impl Url {
    pub fn url(self) -> String {
        match self {
            Self::Pets => {
                format!("{S3_BASE}{S3_IMAGE}oxford-iiit-pet.tgz")
            }
        }
    }
}

This defines the url() method. I also created two constants, S3_BASE and S3_IMAGE, to collect the common prefix. This will let us quickly add new datasets and their corresponding URLs.
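Because url() is a pure function, it is easy to pin down with a unit test. The sketch below repeats the constants and enum from above so it stands on its own, and adds a test module you could keep at the bottom of tardyai/src/lib.rs; the test name is only a suggestion. It checks that url() produces exactly the fastai address quoted earlier.

```rust
// Same constants and enum as in tardyai/src/lib.rs above.
const S3_BASE: &str = "https://s3.amazonaws.com/fast-ai-";
const S3_IMAGE: &str = "imageclas/";

#[derive(Debug, Clone, Copy)]
pub enum Url {
    Pets,
}

impl Url {
    pub fn url(self) -> String {
        match self {
            // Inline format arguments can capture constants by name.
            Self::Pets => format!("{S3_BASE}{S3_IMAGE}oxford-iiit-pet.tgz"),
        }
    }
}

// Only compiled for `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pets_url_matches_fastai() {
        assert_eq!(
            Url::Pets.url(),
            "https://s3.amazonaws.com/fast-ai-imageclas/oxford-iiit-pet.tgz"
        );
    }
}
```

Each new dataset would then be one more enum variant, one match arm, and one assertion.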

Actually download something, why don't you?

Now we need to actually connect to the internet and download our archive from S3. To do this I'm going to use the reqwest crate, the de facto standard crate for making HTTP requests in Rust. It offers both an async and a blocking API; we'll use the blocking API for now. (Maybe in a future article I'll convert everything over to async/await.)

➜   cargo add reqwest -p tardyai -F blocking
    Updating crates.io index
      Adding reqwest v0.11.22 to dependencies.
             Features:
             + __tls
             + blocking
             + default-tls
             + hyper-tls
             + native-tls-crate
             + tokio-native-tls
             38 deactivated features

This adds the latest version of reqwest with the blocking feature turned on.

Then we edit tardyai/src/lib.rs:

pub fn untar_images(url: Url) -> PathBuf {
    let response = reqwest::blocking::get(url.url()).expect("get failed");
    // ...
}

That .expect() looks pretty ugly. Let's clean that up with our own custom error type derived with the help of thiserror.

➜   cargo add -p tardyai thiserror
    Updating crates.io index
      Adding thiserror v1.0.50 to dependencies.

NOTE: I'm going to stop writing down the steps to add a crate. They are almost always the same. Instead I'll mention the crate and any features we need to add to get it to work for us.

thiserror will let us create an error type that is portable, and works with some nice error reporting crates that I'll talk about later.

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("reqwest error: {0}")]
    Reqwest(#[from] reqwest::Error),

    #[error("io error: {0}")]
    IO(#[from] std::io::Error),
}
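If the derive feels like more magic, here is a rough, hand-written approximation of what thiserror generates for the IO variant: a Display impl built from the #[error(...)] string, an std::error::Error impl whose source() returns the wrapped error (#[from] implies #[source]), and a From impl. That From impl is what lets ? turn an io::Error into our Error. The Reqwest variant is left out so the sketch needs only the standard library, and read_text is a hypothetical function that exists only to show ? at work.

```rust
use std::fmt;
use std::io;

#[derive(Debug)]
pub enum Error {
    IO(io::Error),
}

// Roughly what #[error("io error: {0}")] expands to.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::IO(e) => write!(f, "io error: {e}"),
        }
    }
}

// #[from] also marks the field as the error's source.
impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::IO(e) => Some(e),
        }
    }
}

// Roughly what #[from] generates; `?` calls this conversion for us.
impl From<io::Error> for Error {
    fn from(e: io::Error) -> Self {
        Error::IO(e)
    }
}

// Hypothetical example: the io::Error from read_to_string becomes Error::IO.
pub fn read_text(path: &str) -> Result<String, Error> {
    let text = std::fs::read_to_string(path)?;
    Ok(text)
}
```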

Then we can change the signature of untar_images to return a Result:

pub fn untar_images(url: Url) -> Result<PathBuf, Error> {
    let response = reqwest::blocking::get(url.url())?;
    // log::info! comes from the log crate, which tardyai also needs as a dependency.
    log::info!("response: {:?}", response);

    Ok(todo!())
}

So that fetches the file from the URL. Of course, this is useless to us as it stands, because we haven't saved anything to disk. It also won't download the whole archive yet: reqwest's get() returns once the response headers arrive, and the body is only pulled in as we read it.

Save it to the hard disk already

The fastai library fetches the archive files to ~/.fastai/archive/. I'm going to do the same thing, but in ~/.tardyai/archive/ instead.

First we need to make sure the directory exists, which means finding the user's home directory in a cross-platform way. For that I'm using the homedir crate.

fn ensure_dir(path: &PathBuf) -> Result<(), Error> {
    if !path.exists() {
        std::fs::create_dir_all(path)?;
    }
    Ok(())
}

pub fn untar_images(url: Url) -> Result<PathBuf, Error> {
    let dest_dir = homedir::get_my_home()?
        .expect("home directory needs to exist")
        .join(".tardyai")
        .join("archive");
    ensure_dir(&dest_dir)?;
    // ...
}
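As a side note, std::fs::create_dir_all already returns Ok(()) when the directory exists, so the exists() check in ensure_dir is optional. A slightly simpler sketch follows. It takes &Path, so a &PathBuf argument still works through deref coercion, and it returns io::Result rather than our Error so that it stands alone. With our Error type, the io::Error would convert automatically through ? thanks to #[from] on the IO variant.

```rust
use std::io;
use std::path::Path;

/// Sketch of a simpler ensure_dir: create_dir_all succeeds when the
/// directory (and all its parents) already exist, so no exists() check.
pub fn ensure_dir(path: &Path) -> io::Result<()> {
    std::fs::create_dir_all(path)
}
```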

This required creating a new variant for our Error enum. I called it Home.

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("reqwest error: {0}")]
    Reqwest(#[from] reqwest::Error),

    #[error("io error: {0}")]
    IO(#[from] std::io::Error),

    #[error("homedir error: {0}")]
    Home(#[from] homedir::GetHomeError),
}

To save the archive to disk, let's create a new function:

use std::fs::File;

fn download_archive(url: Url, dest_dir: &PathBuf) -> Result<PathBuf, Error> {
    let mut response = reqwest::blocking::get(url.url())?;
    let archive_name = response
        .url()
        .path_segments()
        .and_then(|s| s.last())
        .and_then(|name| if name.is_empty() { None } else { Some(name) })
        .unwrap_or("tmp.tar.gz");

    let archive_file = dest_dir.join(archive_name);

    // TODO: check if the archive is valid and exists
    if archive_file.exists() {
        log::info!("Archive already exists: {}", archive_file.display());
        return Ok(archive_file);
    }

    log::info!(
        "Downloading {} to archive: {}",
        url.url(),
        archive_file.display()
    );
    let mut dest = File::create(&archive_file)?;
    response.copy_to(&mut dest)?;
    Ok(archive_file)
}
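One risk with download_archive is that an interrupted download leaves a truncated file at archive_file, which the exists() check will treat as complete on the next run. A common pattern is to write into a temporary file next to the destination and rename it only after the copy finishes. The sketch below is a hypothetical helper, save_atomically, that accepts any std::io::Read; reqwest::blocking::Response implements Read, so the response could be passed in directly. The ".part" suffix is an arbitrary choice for this sketch.

```rust
use std::fs::{self, File};
use std::io::{self, Read};
use std::path::Path;

/// Hypothetical helper: stream `body` into `dest`, but only give the file
/// its final name once the whole body has been written.
pub fn save_atomically<R: Read>(mut body: R, dest: &Path) -> io::Result<u64> {
    // e.g. ".../oxford-iiit-pet.tgz" -> ".../oxford-iiit-pet.tgz.part"
    let mut part_name = dest.as_os_str().to_owned();
    part_name.push(".part");
    let part_path = Path::new(&part_name);

    let mut file = File::create(part_path)?;
    let written = io::copy(&mut body, &mut file)?;
    // Flush the data to disk before the rename makes it visible.
    file.sync_all()?;
    drop(file);

    // Same directory, so this is a rename rather than a copy
    // (atomic on Unix filesystems).
    fs::rename(part_path, dest)?;
    Ok(written)
}
```

In download_archive, the last lines could then become something like save_atomically(response, &archive_file)?; followed by Ok(archive_file). A leftover .part file from an interrupted run is simply truncated and rewritten the next time.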

We have the archive, now what?

Well, let's decompress and extract it of course. For decompression I'm going to use the flate2 crate, with the rust_backend feature. And for extracting the resulting tar file, I'll use the tar crate.

use std::fs::File;

use flate2::read::GzDecoder;
use tar::Archive;

fn extract_archive(archive_file: &PathBuf, dest_dir: &PathBuf) -> Result<(), Error> {
    let tar_gz = File::open(archive_file)?;
    let tar = GzDecoder::new(tar_gz);
    let mut archive = Archive::new(tar);

    log::info!(
        "Extracting archive {} to: {}",
        archive_file.display(),
        dest_dir.display()
    );
    archive.unpack(dest_dir)?;
    Ok(())
}

Very straightforward. However, this doesn't give us the path that the Python version does: untar_data returns the path of the extracted data. So we're going to have to do that next.

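use std::fs::File;
use std::io::{self, Seek}; // Seek brings .seek() into scope for the File below

use flate2::read::GzDecoder;
use tar::Archive;

fn extract_archive(archive_file: &PathBuf, dest_dir: &PathBuf) -> Result<PathBuf, Error> {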
    let tar_gz = File::open(archive_file)?;
    let tar = GzDecoder::new(tar_gz);
    let mut archive = Archive::new(tar);

    log::info!(
        "Extracting archive {} to: {}",
        archive_file.display(),
        dest_dir.display()
    );
    let dir = {
        let entry = &archive
            .entries()?
            .next()
            .ok_or(Error::TarEntry("No entries in archive"))??;
        entry.path()?.into_owned()
    };
    let archive_dir = dest_dir.join(dir);
    if archive_dir.exists() {
        log::info!("Archive already extracted to: {}", archive_dir.display());
        return Ok(archive_dir);
    }

    let tar = archive.into_inner();
    let mut tar_gz = tar.into_inner();
    tar_gz.seek(io::SeekFrom::Start(0))?;
    let tar = GzDecoder::new(tar_gz);
    let mut archive = Archive::new(tar);
    archive.unpack(dest_dir)?;

    Ok(archive_dir)
}

This is a hack to fetch the first entry in the tar archive, which is generally the top-level directory stored inside. Then I have to undo the reading I did by unwrapping the inner reader, seeking it back to 0, and reconstructing the archive.

If anyone knows of a more sane way to do this, please let me know in the comments.
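The hack assumes the first entry is the top-level directory itself, such as oxford-iiit-pet/. Some archives list a file like oxford-iiit-pet/images/... first, or prefix every entry with ./. A slightly more defensive option, sketched below as the hypothetical helper top_level_dir, is to keep only the first normal component of the first entry's path. It still assumes the archive has a single top-level directory.

```rust
use std::path::{Component, Path, PathBuf};

/// Hypothetical helper: reduce a tar entry path to its first normal
/// component, skipping leading "./" (CurDir) or "/" (RootDir).
/// "oxford-iiit-pet/" and "./oxford-iiit-pet/images/a.jpg" both
/// give "oxford-iiit-pet".
pub fn top_level_dir(entry_path: &Path) -> Option<PathBuf> {
    entry_path.components().find_map(|c| match c {
        Component::Normal(name) => Some(PathBuf::from(name)),
        _ => None,
    })
}
```

Inside extract_archive, entry.path()?.into_owned() could become top_level_dir(&entry.path()?).ok_or(Error::TarEntry("no top-level directory"))?. Separately, instead of unwinding with into_inner() and seek(), you could call File::open(archive_file) a second time and build a fresh GzDecoder and Archive for the unpack pass.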

This also required me to create another variant for our Error enum, TarEntry.

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("reqwest error: {0}")]
    Reqwest(#[from] reqwest::Error),

    #[error("io error: {0}")]
    IO(#[from] std::io::Error),

    #[error("homedir error: {0}")]
    Home(#[from] homedir::GetHomeError),

    #[error("tar entry error: {0}")]
    TarEntry(&'static str),
}

I also threw in a condition to return early if the archive has already been extracted. In the future we may want to change this to use SHA-1 hashes to verify that the data is the same as what was downloaded.
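Until hash checks exist, note that the exists() check can also be fooled by an interrupted extraction, which leaves a partial directory behind. One lightweight option is to write an empty marker file into archive_dir once unpack succeeds, then test for that marker instead of the directory. The helpers below are hypothetical, and the marker name is made up for this sketch.

```rust
use std::fs::File;
use std::io;
use std::path::{Path, PathBuf};

// Hypothetical marker file name, chosen only for this sketch.
const MARKER: &str = ".tardyai-extracted";

fn marker_path(archive_dir: &Path) -> PathBuf {
    archive_dir.join(MARKER)
}

/// True only if a previous extraction ran to completion.
pub fn is_extracted(archive_dir: &Path) -> bool {
    marker_path(archive_dir).is_file()
}

/// Call right after unpack has succeeded.
pub fn mark_extracted(archive_dir: &Path) -> io::Result<()> {
    File::create(marker_path(archive_dir))?;
    Ok(())
}
```

In extract_archive, the early return would then check is_extracted(&archive_dir), and mark_extracted(&archive_dir)? would go right after archive.unpack(dest_dir)?.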

Conclusion

Well, so far we've managed to download and extract our dataset to a centralized location. This is a good first step, and the first line of our program now looks very similar to that of the Python version:

use std::path::PathBuf;

use color_eyre::eyre::{Context, Result};
use tardyai::{untar_images, Url};

fn main() -> Result<()> {
    env_logger::Builder::new()
        .filter_level(log::LevelFilter::Info)
        .init();
    color_eyre::install()?;

    let path: PathBuf = untar_images(Url::Pets)
        .context("downloading Pets")?
        .join("images");
    log::info!("Images are in: {}", path.display());
    Ok(())
}

In Part 3, we will figure out how to turn our images on disk into an ExactSizeDataset that can provide the images as Tensor structs, with their associated labels, and enable batching and other useful functions.

And if you want to see the code from this stage, you can either fetch the article-2 tag from git with

git checkout article-2

or browse it on GitHub.
