<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Sam G.</title>
    <description>The latest articles on DEV Community by Sam G. (@sxg).</description>
    <link>https://dev.to/sxg</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F522546%2Fb6c9d716-5c46-429f-905e-e5de9a6ad91c.jpeg</url>
      <title>DEV Community: Sam G.</title>
      <link>https://dev.to/sxg</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/sxg"/>
    <language>en</language>
    <item>
      <title>Creating an MNIST Classifier with PyTorch Lightning</title>
      <dc:creator>Sam G.</dc:creator>
      <pubDate>Mon, 23 Oct 2023 15:20:00 +0000</pubDate>
      <link>https://dev.to/sxg/creating-an-mnist-classifier-with-pytorch-lightning-44ld</link>
      <guid>https://dev.to/sxg/creating-an-mnist-classifier-with-pytorch-lightning-44ld</guid>
      <description>&lt;p&gt;Basic image classification using the MNIST handwritten digit dataset is a solved problem in 2023, but this makes it perfect for learning some new techniques like PyTorch Lightning, which promises to standardize, simplify, and accelerate the way we create models with PyTorch. The MNIST dataset consists of handwritten digits from 0 through 9, and we should expect pretty high accuracy (&amp;gt;90%) despite training a new model from scratch. We’ll use this problem to implement PyTorch Lightning. I’ll walk through my super simple implementation, &lt;a href="https://github.com/sxg/PyTorch-Lightning-MNIST-Classifier"&gt;which is on GitHub&lt;/a&gt;. All of the code we write will be contained in just two files: one for setting up the data and one for setting up and training the model.&lt;/p&gt;

&lt;h2&gt;
  
  
  Setting Up the Environment
&lt;/h2&gt;

&lt;p&gt;Our core dependencies will by PyTorch and PyTorch Lightning, but we’ll also use TorchMetrics and TensorBoard for calculating and visualizing metrics, images, text and other data as we train and test our model.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;pip &lt;span class="nb"&gt;install &lt;/span&gt;torch torchvision  &lt;span class="c"&gt;# PyTorch&lt;/span&gt;
pip &lt;span class="nb"&gt;install &lt;/span&gt;lightning  &lt;span class="c"&gt;# PyTorch Lightning&lt;/span&gt;
pip &lt;span class="nb"&gt;install &lt;/span&gt;torchmetrics  &lt;span class="c"&gt;# TorchMetrics&lt;/span&gt;
pip &lt;span class="nb"&gt;install &lt;/span&gt;tensorboard  &lt;span class="c"&gt;# TensorBoard&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;h2&gt;
  
  
  MNIST Data
&lt;/h2&gt;

&lt;p&gt;The MNIST handwritten digit dataset consists of 60,000 training images and 10,000 test images. Each handwritten digit is a grayscale 28 px × 28 px image (i.e., 784 total pixels) and is directly available from &lt;code&gt;torchvision.datasets&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://res.cloudinary.com/practicaldev/image/fetch/s--lpbuKbF6--/c_limit%2Cf_auto%2Cfl_progressive%2Cq_auto%2Cw_800/https://dev-to-uploads.s3.amazonaws.com/uploads/articles/l7ziurus4mzd80yv1eyb.jpeg" class="article-body-image-wrapper"&gt;&lt;img src="https://res.cloudinary.com/practicaldev/image/fetch/s--lpbuKbF6--/c_limit%2Cf_auto%2Cfl_progressive%2Cq_auto%2Cw_800/https://dev-to-uploads.s3.amazonaws.com/uploads/articles/l7ziurus4mzd80yv1eyb.jpeg" alt="Sample images of handwritten digits from the MNIST database." width="557" height="327"&gt;&lt;/a&gt; Sample images of handwritten digits from the MNIST database. &lt;a href="https://en.wikipedia.org/wiki/MNIST_database#/media/File:MnistExamplesModified.png"&gt;Image courtesy of Wikipedia&lt;/a&gt;.&lt;/p&gt;
&lt;h3&gt;
  
  
  &lt;code&gt;LightningDataModule&lt;/code&gt;
&lt;/h3&gt;

&lt;p&gt;All logic related to setting up the data is in &lt;a href="https://github.com/sxg/PyTorch-Lightning-MNIST-Classifier/blob/main/DigitDataModule.py"&gt;DigitDataModule.py&lt;/a&gt;, which contains our custom &lt;code&gt;LightningDataModule&lt;/code&gt; subclass. Let’s start from the top!&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;lightning.pytorch&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pl&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.utils.data&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;torchvision.datasets&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;MNIST&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torchvision.transforms&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;transforms&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;DigitDataModule&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pl&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LightningDataModule&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="n"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;
        &lt;span class="c1"&gt;# Setup the transforms
&lt;/span&gt;        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transform&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Compose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="p"&gt;[&lt;/span&gt;
                &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ToTensor&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
                &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Normalize&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,)),&lt;/span&gt;
            &lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="p"&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;After all the imports, we have our &lt;code&gt;DigitDataModule&lt;/code&gt; subclass of &lt;code&gt;LightningDataModule&lt;/code&gt; along with the &lt;code&gt;__init__()&lt;/code&gt; method. We’ll initialize the data module with a default batch size of 64 and a couple of basic transforms. Because the MNIST handwritten digit classification is fairly simple, we don’t need to do any data augmentation or substantial image pre-processing. We’ll combine two transforms with the &lt;code&gt;transforms.Compose()&lt;/code&gt; function, and we’ll start with &lt;code&gt;transforms.ToTensor()&lt;/code&gt;, which converts the PIL (Python Image Library) &lt;code&gt;Image&lt;/code&gt; to a PyTorch &lt;code&gt;Tensor&lt;/code&gt; with values from 0 through 1. The second transform, &lt;code&gt;transforms.Normalize((0.5,) (0.5,))&lt;/code&gt;, normalizes the grayscale values in the &lt;code&gt;Tensor&lt;/code&gt; with a mean of 0.5 and standard deviation of 0.5 using this function:&lt;/p&gt;

&lt;p&gt;

&lt;/p&gt;
&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;Normalized Value=Unnormalized Value−MeanStandard Deviation
\text{Normalized Value} = \frac{\text{Unnormalized Value} - \text{Mean}}{\text{Standard Deviation}}
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;Normalized Value&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;Standard Deviation&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;Unnormalized Value&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;−&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;Mean&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;



&lt;p&gt;This will change the values in the image tensor from 0 through 1 to −1 through 1, which can often help with model convergence. However, because our task is simple and uses 1-channel grayscale values, I didn’t find a huge difference between including and excluding the normalization transformation.&lt;/p&gt;

&lt;p&gt;Up next, we have the &lt;code&gt;prepare_data()&lt;/code&gt; and &lt;code&gt;setup()&lt;/code&gt; methods:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="p"&gt;...&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;prepare_data&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;MNIST&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s"&gt;"MNIST"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;download&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;MNIST&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s"&gt;"MNIST"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;download&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;setup&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stage&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;stage&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s"&gt;"fit"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;full_set&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MNIST&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;root&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s"&gt;"MNIST"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;train_set_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;full_set&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;val_set_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;full_set&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;train_set_size&lt;/span&gt;
        &lt;span class="n"&gt;seed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Generator&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="n"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_set&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;val_set&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random_split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;  &lt;span class="c1"&gt;# Split train/val datasets
&lt;/span&gt;            &lt;span class="n"&gt;full_set&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;train_set_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_set_size&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;elif&lt;/span&gt; &lt;span class="n"&gt;stage&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s"&gt;"test"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;test_set&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MNIST&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;root&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s"&gt;"MNIST"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="p"&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;code&gt;prepare_data()&lt;/code&gt; is an appropriate place to download your data as PyTorch Lightning only calls this on a single process on the CPU. We’ll use &lt;code&gt;torchvision.datasets.MNIST&lt;/code&gt; to download the training and test data. After that, &lt;code&gt;setup()&lt;/code&gt; will create our datasets on the GPU (if available). For the fitting stage (i.e., training and validation), we’ll load the MNIST data that we downloaded into the &lt;code&gt;full_set&lt;/code&gt; variable, then we’ll calculate the sizes of our training and validation datasets using a standard 80%/20% split between them. We’ll use PyTorch’s &lt;code&gt;torch.utils.data.random_split()&lt;/code&gt; function to divide the full dataset into the training and validation datasets. For the test stage, we’ll simply load the dataset—no need to split any data here.&lt;/p&gt;

&lt;p&gt;Finally, we have the data loaders:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="p"&gt;...&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;train_dataloader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_set&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_workers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;val_dataloader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;val_set&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_workers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_dataloader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;test_set&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_workers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="p"&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Here, we simply create &lt;code&gt;torch.utils.data.DataLoader&lt;/code&gt; instances for the training, validation, and test datasets. PyTorch Lightning will call the appropriate data loader function for us when we’re fitting and testing. Note that 10 workers is appropriate for my computer (M1 Max MacBook Pro using MPS) but will probably not be optimal for you. Usually, PyTorch will give you a warning and tell you the appropriate number of workers your computer can handle.&lt;/p&gt;

&lt;h2&gt;
  
  
  Implementing the Model
&lt;/h2&gt;

&lt;h3&gt;
  
  
  &lt;code&gt;LightningModule&lt;/code&gt;
&lt;/h3&gt;

&lt;p&gt;All logic related to setting up the model is in &lt;a href="https://github.com/sxg/PyTorch-Lightning-MNIST-Classifier/blob/main/DigitModule.py"&gt;DigitModule.py&lt;/a&gt; which contains our custom &lt;code&gt;LightningModule&lt;/code&gt; subclass. Let’s start with the model, loss function, and optimizer:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;optim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;lightning.pytorch&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pl&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;torchmetrics&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;torchvision.utils&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;make_grid&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;DigitDataModule&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;DigitDataModule&lt;/span&gt;

&lt;span class="c1"&gt;# The LightningModule
&lt;/span&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;DigitModule&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pl&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LightningModule&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="n"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;784&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loss_fn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;CrossEntropyLoss&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_logged_images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;valid_logged_images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;test_logged_images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;configure_optimizers&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;optim&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Adam&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;x_reshaped&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;784&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_reshaped&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="p"&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;After the imports, we have our &lt;code&gt;__init__()&lt;/code&gt; method, which sets up our model and loss function. The model consists of four linear layers, starting with a 784 neuron layer that accepts a flattened version of our MNIST images (remember our images are 28 px × 28 px, which is 784 total pixels). We’ll have two hidden layers with 128 and 64 neurons, respectively. Last, we’ll have our output layer with 10 neurons, which represent each of the 10 possible digit classes (i.e., 0 through 9).&lt;/p&gt;

&lt;p&gt;We’ll use the cross-entropy loss function to convert the raw values from the 10-neuron output layer of the model into probabilities and then a predicted digit, which we compare to the actual digit. Then we’ll setup a few flags for logging that I’ll describe more below. Next, &lt;code&gt;configure_optimizers()&lt;/code&gt; sets up an Adam optimizer. Last, we’ll implement the &lt;code&gt;forward()&lt;/code&gt; method, which we use to flatten our images before feeding them to our model since our model expects a flat (1, 784) &lt;code&gt;Tensor&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;Next, we have our implementations of the training, validation, and test steps: all we do is provide the logic for &lt;code&gt;training_step()&lt;/code&gt;, &lt;code&gt;validation_step()&lt;/code&gt;, and &lt;code&gt;test_step()&lt;/code&gt; and PyTorch Lightning does the rest! The implementations are nearly identical, so I’ll explain &lt;code&gt;training_step()&lt;/code&gt; and highlight the additional lines in &lt;code&gt;validation_step()&lt;/code&gt; and &lt;code&gt;test_step()&lt;/code&gt;. We'll start here:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="p"&gt;...&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;training_step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_idx&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;
    &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="p"&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We start by extracting the input handwritten digit image as a &lt;code&gt;Tensor&lt;/code&gt; in &lt;code&gt;x&lt;/code&gt; and the true digit value in &lt;code&gt;y&lt;/code&gt; from &lt;code&gt;batch&lt;/code&gt;. Then, we’ll apply the model to &lt;code&gt;x&lt;/code&gt; and store the output. Next, we’ll calculate the loss using the cross-entropy loss function we defined above. After this, we’ll log the loss so we can visualize it over time in TensorBoard. TensorBoard will group logged data based on the text preceding the “/”. So the following line:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="p"&gt;...&lt;/span&gt;
&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s"&gt;"train/loss"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Log to TensorBoard's "train" section
&lt;/span&gt;&lt;span class="p"&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;will appear in a section titled “train” in TensorBoard. Similarly, we’ll log not only the loss but also the accuracy in the validation and test steps:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="p"&gt;...&lt;/span&gt;
&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s"&gt;"val/loss"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Log to TensorBoard's "val" section
&lt;/span&gt;&lt;span class="n"&gt;acc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;accuracy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;task&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s"&gt;"multiclass"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s"&gt;"val/acc"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;acc&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Log to Tensorboard's "val" section
&lt;/span&gt;&lt;span class="p"&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We’ll use &lt;code&gt;torchmetrics.functional.accuracy()&lt;/code&gt; to calculate the accuracy of our model. The function will take the raw values from our model, and we define the task as “multiclass” with 10 possible classes as that best describes our problem. Multiclass means there are more than two possible classes of which our handwritten digits can only belong to one, which makes sense since we’re looking at digits from 0 through 9 and each image can only belong to one of the 10 classes. In contrast, a binary problem would simply have two mutually exclusive classes. A multilabel problem (like multiclass) will have more than two possible classes, but these are not mutually exclusive. A classic example of a multilabel problem is applying genres to movies: you can have three potential genres (e.g., action, comedy, and horror) and any movie may belong to none, some, or all of the genres.&lt;/p&gt;

&lt;p&gt;Last, we’re also going to log the actual MNIST handwritten digit images to TensorBoard:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="p"&gt;...&lt;/span&gt;
&lt;span class="c1"&gt;# Log images to Tensorboard
&lt;/span&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_logged_images&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;preds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;img_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;make_grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logger&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;experiment&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add_image&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="s"&gt;"train/inputs"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;img_grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;current_epoch&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logger&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;experiment&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="s"&gt;"train/targets"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;str&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;current_epoch&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logger&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;experiment&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="s"&gt;"train/preds"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;str&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;preds&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;current_epoch&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_logged_images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
&lt;span class="p"&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We’ll use the &lt;code&gt;self.train_logged_images&lt;/code&gt; flag to ensure we only log one batch of images and labels. Then we have to manually calculate the predicted digits from the model’s raw output by using &lt;code&gt;torch.argmax()&lt;/code&gt;. We’ll create a grid of the handwritten digit images using &lt;code&gt;torchvision.utils.make_grid()&lt;/code&gt;, and then we log the images, actual labels, and predicted labels as we did before. Finally, we update the flag, which we reset using callback methods: &lt;code&gt;on_training_epoch_end()&lt;/code&gt;, &lt;code&gt;on_validation_epoch_end()&lt;/code&gt;, and &lt;code&gt;on_test_epoch_end()&lt;/code&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  Putting It All Together
&lt;/h3&gt;

&lt;p&gt;The last bit of code is our &lt;code&gt;main()&lt;/code&gt; function, which is only 5 lines of code thanks to PyTorch Lightning!&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="p"&gt;...&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;main&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="c1"&gt;# The model
&lt;/span&gt;    &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;DigitModule&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="c1"&gt;# The data
&lt;/span&gt;    &lt;span class="n"&gt;dm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;DigitDataModule&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Train the model
&lt;/span&gt;    &lt;span class="n"&gt;trainer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pl&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Trainer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;trainer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;datamodule&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dm&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Test the model
&lt;/span&gt;    &lt;span class="n"&gt;trainer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;test&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;datamodule&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dm&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s"&gt;"__main__"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;main&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We simply instantiate the model and our data, which we feed to the PyTorch Lightning &lt;code&gt;Trainer&lt;/code&gt;. The &lt;code&gt;Trainer&lt;/code&gt; figures out which data to use and how to apply the model based on all the methods we defined earlier. We call &lt;code&gt;fit()&lt;/code&gt; to train the model on 5 epochs (arbitrarily selected), and then &lt;code&gt;test()&lt;/code&gt; to run the model against our test data. Let’s take a look at the results!&lt;/p&gt;

&lt;h2&gt;
  
  
  TensorBoard Results
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Metrics
&lt;/h3&gt;

&lt;p&gt;Let’s turn over to TensorBoard, which you can start by running this in the folder with your training code:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;tensorboard &lt;span class="nt"&gt;--logdir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;lightning_logs/
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://res.cloudinary.com/practicaldev/image/fetch/s--B1pv8922--/c_limit%2Cf_auto%2Cfl_progressive%2Cq_auto%2Cw_800/https://dev-to-uploads.s3.amazonaws.com/uploads/articles/u243npqpqlgfcnt1no85.jpeg" class="article-body-image-wrapper"&gt;&lt;img src="https://res.cloudinary.com/practicaldev/image/fetch/s--B1pv8922--/c_limit%2Cf_auto%2Cfl_progressive%2Cq_auto%2Cw_800/https://dev-to-uploads.s3.amazonaws.com/uploads/articles/u243npqpqlgfcnt1no85.jpeg" alt="TensorBoard showing the data we’ve logged. I’ve pinned the validation loss and accuracy, the test accuracy, and the training loss to the top." width="800" height="547"&gt;&lt;/a&gt; TensorBoard showing the data we’ve logged. I’ve pinned the validation loss and accuracy, the test accuracy, and the training loss to the top.&lt;/p&gt;

&lt;p&gt;You can open TensorBoard in your browser by navigating to &lt;code&gt;localhost:6006&lt;/code&gt;, which will open a screen similar to the figure above. I’ve pinned four cards to the top, and let’s walk through them. At the bottom right, we have the training loss, which is generally decreasing as we perform more iterations through our data, which is good—our model (might) be learning! Taking a look at the validation loss and accuracy on the top row, we see that both metrics are steadily improving as we perform more iterations through our data, and the validation accuracy peaks at 95.7%! Given the consistent improvements in both the validation loss and accuracy, we could have continued training for more than 5 epochs, but I’ll leave things as is for now.&lt;/p&gt;

&lt;p&gt;Now—most importantly—we need to check our test accuracy in the bottom left card to ensure our model is able to generalize to data its never seen before. Amazingly the accuracy is 96.9%! This is in line with our validation accuracy and suggests that our model has truly learned how to recognize handwritten digits from the MNIST data!&lt;/p&gt;

&lt;h3&gt;
  
  
  Data Visualization
&lt;/h3&gt;

&lt;p&gt;Let’s take a look at some of the images we’ve logged:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://res.cloudinary.com/practicaldev/image/fetch/s--y2aYi0WC--/c_limit%2Cf_auto%2Cfl_progressive%2Cq_auto%2Cw_800/https://dev-to-uploads.s3.amazonaws.com/uploads/articles/szntx1c8c9unvsuveyt1.jpeg" class="article-body-image-wrapper"&gt;&lt;img src="https://res.cloudinary.com/practicaldev/image/fetch/s--y2aYi0WC--/c_limit%2Cf_auto%2Cfl_progressive%2Cq_auto%2Cw_800/https://dev-to-uploads.s3.amazonaws.com/uploads/articles/szntx1c8c9unvsuveyt1.jpeg" alt="64 sample handwritten digit images from the MNIST data that we logged during the test step." width="800" height="458"&gt;&lt;/a&gt; 64 sample handwritten digit images from the MNIST data that we logged during the test step.&lt;/p&gt;

&lt;p&gt;By navigating over to the “IMAGES” tab along the top navigation bar in TensorBoard, we can see the sample handwritten digit images we logged using the &lt;code&gt;torchvision.utils.make_grid()&lt;/code&gt; function. Nothing too exciting in this case, but this can be a great way to find underlying issues with your data if you’re having problems. Let’s take a look at the predicted and actual labels for these images:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://res.cloudinary.com/practicaldev/image/fetch/s--FynNEXyI--/c_limit%2Cf_auto%2Cfl_progressive%2Cq_auto%2Cw_800/https://dev-to-uploads.s3.amazonaws.com/uploads/articles/5pu6zlpxo0zq7690kop5.jpeg" class="article-body-image-wrapper"&gt;&lt;img src="https://res.cloudinary.com/practicaldev/image/fetch/s--FynNEXyI--/c_limit%2Cf_auto%2Cfl_progressive%2Cq_auto%2Cw_800/https://dev-to-uploads.s3.amazonaws.com/uploads/articles/5pu6zlpxo0zq7690kop5.jpeg" alt="The 64 labels our model predicted for the images are in the top box, and the 64 actual ground truth labels from the MNIST data are in the bottom box." width="800" height="376"&gt;&lt;/a&gt; The 64 labels our model predicted for the images are in the top box, and the 64 actual ground truth labels from the MNIST data are in the bottom box.&lt;/p&gt;

&lt;p&gt;Navigating over to the “TEXT” tab in TensorBoard, we can find the predicted and actual ground truth labels for the 64 handwritten digit images we just saw. The top box shows our model’s predictions, and the bottom box shows the actual ground truth labels, which are in perfect agreement for this batch of data! Again, visualizing your data is a great way to find underlying issues if you’re running into problems training your model.&lt;/p&gt;

&lt;h2&gt;
  
  
  Conclusion
&lt;/h2&gt;

&lt;p&gt;We trained and tested a neural network from scratch using PyTorch Lightning on the MNIST handwritten digit dataset using just two files, and we achieved a test accuracy of 96.9%! I’ll continue building on this foundation by using new datasets (e.g., &lt;a href="https://stanfordmlgroup.github.io/competitions/chexpert/"&gt;CheXpert&lt;/a&gt; and &lt;a href="https://physionet.org/content/mimic-cxr/2.0.0/"&gt;MIMIC-CXR&lt;/a&gt; since I’m a radiologist) and new techniques (e.g., &lt;a href="https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html"&gt;Lightning CLI&lt;/a&gt;).&lt;/p&gt;

</description>
      <category>pytorch</category>
      <category>lightning</category>
      <category>mnist</category>
      <category>machinelearning</category>
    </item>
  </channel>
</rss>
