<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Tyler Hu</title>
    <description>The latest articles on DEV Community by Tyler Hu (@huuh).</description>
    <link>https://dev.to/huuh</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3930682%2Fc0110f19-d767-4093-a685-25d058888dc6.jpg</url>
      <title>DEV Community: Tyler Hu</title>
      <link>https://dev.to/huuh</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/huuh"/>
    <language>en</language>
    <item>
      <title>Kaggle Getting Started Competition -- Petals to the Metal</title>
      <dc:creator>Tyler Hu</dc:creator>
      <pubDate>Tue, 19 May 2026 08:41:57 +0000</pubDate>
      <link>https://dev.to/huuh/kaggle-getting-started-competition-petals-to-the-metal-mec</link>
      <guid>https://dev.to/huuh/kaggle-getting-started-competition-petals-to-the-metal-mec</guid>
      <description>&lt;p&gt;&lt;em&gt;The experimental environment for the project is Kaggle notebook&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;&lt;em&gt;competition link : &lt;a href="https://www.kaggle.com/competitions/tpu-getting-started" rel="noopener noreferrer"&gt;https://www.kaggle.com/competitions/tpu-getting-started&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsdo7uir60zl8n04eqhyy.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsdo7uir60zl8n04eqhyy.png" alt=" " width="800" height="484"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Firstly, we should observe the data. There are 193 files of data provided in this competition, including 192 tfrec files and one csv file. The csv file is the sample file for submission, for reference only.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;data/
├── tfrecords-jpeg-192x192/
├── tfrecords-jpeg-224x224/
│   ├── test/
│   │   ├── 00-224x224-462.tfrec
│   │   ├── 01-224x224-462.tfrec
│   │   ├── ...
│   │   └── 15-224x224-452.tfrec
│   ├── train/
│   │   ├── 00-224x224-798.tfrec
│   │   ├── 01-224x224-798.tfrec
│   │   ├── ...
│   │   └── 15-224x224-783.tfrec
│   └── val/
│       ├── 00-224x224-232.tfrec
│       ├── 01-224x224-232.tfrec
│       ├── ...
│       └── 15-224x224-232.tfrec
├── tfrecords-jpeg-331x331/
├── tfrecords-jpeg-512x512/
└── sample_submission.csv
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We notice the data files are all tfrec files. The type of file is a professional data file for TensorFlow and this experimental architecture is PyTorch so we need data processing to transform tfrec files to Tensor format.&lt;/p&gt;

&lt;p&gt;Let's take a look at the libraries needed for this experiment.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;io&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;timm&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torchvision.transforms&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;transforms&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.utils.data&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Dataset&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.utils.data&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;DataLoader&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;pathlib&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Path&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;PIL&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Image&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;tfrecord&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;glob&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  &lt;strong&gt;Reading Files&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;We should to know what the data looks like. The tfrec file actually is a binary file so we can't look images and labels directly. We need to  parse it.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;train_files&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;sorted&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;glob&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;glob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;/kaggle/input/competitions/tpu-getting-started/tfrecords-jpeg-224x224/train/*.tfrec&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_files&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;   &lt;span class="c1"&gt;# output: 16
&lt;/span&gt;&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_files&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Let's store all training data from &lt;code&gt;tfrecords-jpeg-224x224&lt;/code&gt; folder into the &lt;code&gt;train_files&lt;/code&gt; list, then print its length to verify. It should be 16 because there are 16 files in &lt;code&gt;train&lt;/code&gt; folder, which means the training data is split into 16 slices, and each slice contains 798 data points except for the last one, which contains 783 data points.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;reader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfrecord&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tfrecord_loader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;data_path&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;train_files&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
    &lt;span class="n"&gt;index_path&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;description&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;{}&lt;/span&gt;   &lt;span class="c1"&gt;# empty dict, reads out raw field name
&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;record&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;reader&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;record&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
    &lt;span class="k"&gt;break&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Next, let's explain the three parameters of &lt;code&gt;tfrecord_loader&lt;/code&gt; method:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;We only use the first slices, which is the &lt;code&gt;00-224x224-798.tfrec&lt;/code&gt; file, so we assign the first element of the &lt;code&gt;train_files&lt;/code&gt; list to the &lt;code&gt;data_path&lt;/code&gt;.&lt;/li&gt;
&lt;li&gt;TFRecord allows us to use a index file to accelerate random access, but we don't have it so just set it up at None.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;description&lt;/code&gt; tell us that we need set what field and what type of field. We need to set this parameter because the tfrec file stores binary data and it has not type information, so the library doesn't know how to parse it. We need to tell the library which fields to read and their types. It is a good way that typing &lt;code&gt;{ }&lt;/code&gt;. It can read all of field so that we can know what field there are, then typing it.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;We can see the following output after printing it.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;dict_keys(['id', 'class', 'image'])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now, let's formally parse a file:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;description&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;image&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;byte&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;class&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;int&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;id&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;byte&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="n"&gt;reader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfrecord&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tfrecord_loader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;data_path&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;train_files&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
    &lt;span class="n"&gt;index_path&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;description&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;description&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;record&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;reader&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;record&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;class&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;type&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;record&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;image&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;record&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;image&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
    &lt;span class="k"&gt;break&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now, we know the field name so we can formally define &lt;code&gt;description&lt;/code&gt; dictionary. But how about the type? Simple, just set everything to &lt;code&gt;byte&lt;/code&gt; and see whether it throws an error. The error message will tell you the right type. The output is as follows:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;[57]
&amp;lt;class 'bytes'&amp;gt;
25512
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  &lt;strong&gt;Dataset&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;OK, we have understood the data structure and now we can define the dataset class.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PetalsDataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Dataset&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;root_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;type&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;type&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;type&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;root_dir&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;root_dir&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;type_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Path&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;root_dir&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="nb"&gt;type&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;files_list&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;sorted&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;glob&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;glob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;type_path&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;/*.tfrec&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transform&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;transform&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
        &lt;span class="n"&gt;description&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;image&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;byte&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;class&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;int&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;id&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;byte&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;type&lt;/span&gt;&lt;span class="o"&gt;!=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;test&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;image&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;byte&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;id&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;byte&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;path&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;files_list&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;reader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfrecord&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tfrecord_loader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;description&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;record&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;reader&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;samples&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;record&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__len__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__getitem__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;record&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="nb"&gt;dict&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;samples&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;img&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;io&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;BytesIO&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;record&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;image&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])).&lt;/span&gt;&lt;span class="nf"&gt;convert&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;RGB&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;img&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;img&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;type&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;test&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;img&lt;/span&gt;
        &lt;span class="n"&gt;label&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;record&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;class&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;img&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Let's pay attention to the part of data processing. As I said before, the training data is split into 16 partitions and partitions are stored in 16 &lt;code&gt;tfrec&lt;/code&gt; files respectively. So the total training data should include &lt;code&gt;16x798=12,768&lt;/code&gt; samples. But the last one just includes 783 samples so the actual training data includes 12,753 samples. My code contains two nested loops. The first one read each &lt;code&gt;tfrec&lt;/code&gt; files and assign the &lt;code&gt;tfrec&lt;/code&gt; files parsed to &lt;code&gt;reader&lt;/code&gt;. The second one iterate over the &lt;code&gt;reader&lt;/code&gt; object by &lt;code&gt;record&lt;/code&gt; and store each sample into &lt;code&gt;samples&lt;/code&gt; list, so the length of &lt;code&gt;samples&lt;/code&gt; list should be 12,753.&lt;/p&gt;

&lt;p&gt;At the same time, I define a inputting parameter, &lt;code&gt;type&lt;/code&gt;, which is used to specify whether the dataset is training, validation or test. Note that the test dataset don't have &lt;code&gt;class&lt;/code&gt; field, so we should drop the field while we define it.&lt;/p&gt;

&lt;p&gt;Then, we define &lt;code&gt;transfrom&lt;/code&gt; to transform images into Tensor format for model input.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;train_transform&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Compose&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;RandomResizedCrop&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;224&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
    &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;RandomHorizontalFlip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;RandomVerticalFlip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;RandomRotation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;45&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ColorJitter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hue&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;RandomGrayscale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;


    &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ToTensor&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
    &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normalize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.485&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.456&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.406&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
                        &lt;span class="n"&gt;std&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.229&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.224&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.225&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;

&lt;span class="n"&gt;val_transform&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Compose&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ToTensor&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
    &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normalize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.485&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.456&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.406&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
                        &lt;span class="n"&gt;std&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.229&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.224&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.225&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;I want to explain the series of reinforcement in &lt;code&gt;transform&lt;/code&gt;:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;RandomResizedCrop&lt;/code&gt;: This randomly crops a portion of image and resize it to 224x224. &lt;code&gt;scale&lt;/code&gt; means that the cropped portion is between 50% and 100% of the original image. （ This encourages model to recognize flowers from local feature rather than relying on the full image.）&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;RandomHorizontalFlip&lt;/code&gt;: This randomly flips the image horizontally with a 50% probability. ( This helps the model become invariant to orientation. )&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;RandomVerticalFlip&lt;/code&gt;: This randomly flips the image vertically with a 50% probability. ( This helps the model become invariant to orientation. )&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;RandomRotation&lt;/code&gt;: This randomly rotates the image within a range of -45 to 45 degrees. ( This helps the model become invariant to orientation. )&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;ColorJitter&lt;/code&gt;: This randomly changes image color. The parameter means that &lt;code&gt;(brightness, contrast, saturation, hue)&lt;/code&gt;.  ( This prevents the model from relying on fixed colors to identify flowers, since the same species can appear differently under various lighting conditions and shooting angles.  )&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;RandomGrayscale&lt;/code&gt;: This convers the image into grayscale image within a 10% probability. ( This encourages model identify flowers not only by color, but also by shape and texture )&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Next, we just need to define the dataset object and dataloader object.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;root_dir&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;/kaggle/input/competitions/tpu-getting-started/tfrecords-jpeg-224x224&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
&lt;span class="n"&gt;train_dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PetalsDataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;root_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;type&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;train_transform&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;val_dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PetalsDataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;root_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;type&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;val&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;val_transform&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;train_dataloader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;val_dataloader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;val_dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  &lt;strong&gt;Model&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;We can define the model after we have defined the dataset and dataloader object. Let's try to build a CNN model at first:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;ConvBlock&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;in_channels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;out_channels&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;block&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Conv2d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;in_channels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;in_channels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;out_channels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;out_channels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kernel_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;padding&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;MaxPool2d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kernel_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;block&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;CNN&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feature&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="nc"&gt;ConvBlock&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="nc"&gt;ConvBlock&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="nc"&gt;ConvBlock&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;classification&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;28&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;28&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;104&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;feature&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;classification&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;

&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;CNN&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Actually,  the model is based on the CNN code from Andrew Ng's PyTorch basics course. &lt;/p&gt;

&lt;h2&gt;
  
  
  &lt;strong&gt;Model Training&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;We should define the device, loss function and optimizer for model training.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;cuda&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;is_available&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;cpu&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;loss_function&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;CrossEntropyLoss&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optim&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;AdamW&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;weight_decay&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;scheduler&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optim&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lr_scheduler&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;CosineAnnealingLR&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;T_max&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Then we code the training loop.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;train_epoch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataloader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss_function&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;running_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.0&lt;/span&gt;
    &lt;span class="n"&gt;current&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;total&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;batch_idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataloader&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;loss_function&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="n"&gt;running_loss&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;total&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;current&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;batch_idx&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="n"&gt;batch_idx&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;avg_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;running_loss&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;
            &lt;span class="n"&gt;accuracy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;100.&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;current&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;total&lt;/span&gt;
            &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;[&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;batch_idx&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; / &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="mi"&gt;399&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;] &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
                  &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;LOSS : &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;avg_loss&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | accuracy : &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;accuracy&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;%&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;running_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.0&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Note that I print accumulated running loss every 20 batches. When working with other datasets, we should adjust the number according to the total batches in our dataloader.  This &lt;code&gt;train_dataloader&lt;/code&gt; contains 399 batches, so we print the accumulated running loss every 20 batches.&lt;/p&gt;

&lt;p&gt;Next, we define evaluate loop.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;test_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;correct&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;total&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;test_loader&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;total&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;correct&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="mf"&gt;100.&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;correct&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;total&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now, we can start training model.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;num_epochs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;
&lt;span class="n"&gt;best_accuracy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_epochs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s"&gt;Epoch: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nf"&gt;train_epoch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataloader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss_function&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;accuracy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_dataloader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;scheduler&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Test Accuracy:&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;accuracy&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;%&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;accuracy&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;best_accuracy&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;best_accuracy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;accuracy&lt;/span&gt;
        &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;save&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;state_dict&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;/kaggle/working/best_model.pth&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;  save the best model with acc=&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;accuracy&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;%&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Running the code above, the model with the highest accuracy rate will be saved. But the highest accuracy just is 26% probability approximately.  I was quite dispirited. But I have realized it is normal after some research. This is just a simple model I threw together without any hyperparameter tuning. So it is just a foundation of neural network.&lt;/p&gt;

&lt;h2&gt;
  
  
  &lt;strong&gt;Enhancement&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;If we want higher accuracy, we need to fine-tune a pretrained model. This CNN model we built only contains three convolutional layers, which is far from sufficient for complex image classification. The pretrained model is common in Kaggle competition. Next, we use the &lt;code&gt;convnext_base&lt;/code&gt; model, ConvNeXt is an enhanced version of CNN, proposed by Meta.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;timm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;create_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;convnext_base&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pretrained&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;104&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;drop_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;drop_path_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Pretrained model has learned many feature about image, so it achieve a much high accuracy soon. Here I decrease the number of epochs to 10 times.&lt;/p&gt;

&lt;p&gt;Running the code with the new model again. The accuracy will increase to 91% probability, and the competition score can reach 0.90.&lt;/p&gt;

&lt;p&gt;But now we have another question, which is after we ran code, despite the accuracy is much high in training set even achieve 100%, the accuracy in validation set is just 90% probability. The gap is 10% probability. This is a classic case of  overfitting. Let's address it. I used a common and useful way -- &lt;code&gt;Mixup&lt;/code&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torchvision.transforms.v2&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;MixUp&lt;/span&gt;

&lt;span class="n"&gt;mixup&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;MixUp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;104&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The principle is to blend two images together like a semi-transparent overlay, and mix the labels proportionally as well. Without &lt;code&gt;Mixup&lt;/code&gt;,  model tend to memorize each images rather than learning general feature.&lt;/p&gt;

&lt;p&gt;At the same time, we need to add the following code into loop:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;train_epoch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataloader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss_function&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;running_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.0&lt;/span&gt;
    &lt;span class="n"&gt;current&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;total&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;batch_idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataloader&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;mixup&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;      &lt;span class="c1"&gt;#&amp;lt;---
&lt;/span&gt;    &lt;span class="c1"&gt;#......
&lt;/span&gt;        &lt;span class="n"&gt;total&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;current&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;#Here should be target.argmax(1)
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is because after &lt;code&gt;Mixup&lt;/code&gt;, the target is no longer a class label (eg. &lt;code&gt;[48, 77, 0, ...]&lt;/code&gt;), but a probability distribution over  classes (eg. &lt;code&gt;[[0,0,...,1,...,0], ...]&lt;/code&gt;)&lt;/p&gt;

&lt;p&gt;Now, the competition score can reach 0.96 ! It surpass 95% and achieve my expected perfectly.&lt;/p&gt;

&lt;h2&gt;
  
  
  &lt;strong&gt;Submission&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;Finally, let's take a look at the code for submission.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;all_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="n"&gt;all_preds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;img_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;test_dataloader&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;test_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

            &lt;span class="n"&gt;all_ids&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;extend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;img_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;all_preds&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;extend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;predicted&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cpu&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;numpy&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

&lt;span class="n"&gt;df&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;DataFrame&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;id&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;all_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;label&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;all_preds&lt;/span&gt;
&lt;span class="p"&gt;})&lt;/span&gt;
&lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to_csv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;submission.csv&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;head&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Well, if you want to use the model saved, use the code follows:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;test_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;timm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;create_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;convnext_base&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pretrained&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;104&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;drop_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;drop_path_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;state_dict&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;/kaggle/input/models/huuhgodona/convnext-model/pytorch/default/1/best_convnext.pth&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;map_location&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;cuda&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;test_model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load_state_dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_dict&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;test_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;test_model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;test_model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;A &lt;code&gt;pth&lt;/code&gt; file just save the parameter model need, so we should have defined the model we need, which matches the &lt;code&gt;pth&lt;/code&gt; file. Then we load the parameter into it, and the same applies to our own custom CNN.&lt;/p&gt;

</description>
      <category>ai</category>
      <category>beginners</category>
      <category>learning</category>
      <category>deeplearning</category>
    </item>
    <item>
      <title>My first Post of Kaggle</title>
      <dc:creator>Tyler Hu</dc:creator>
      <pubDate>Fri, 15 May 2026 05:39:07 +0000</pubDate>
      <link>https://dev.to/huuh/my-first-post-of-kaggle-39em</link>
      <guid>https://dev.to/huuh/my-first-post-of-kaggle-39em</guid>
      <description>&lt;p&gt;Hi, I'm a data science undergradates from China.🤗 I'm very happy to be here on DEV.TO to share my own thread. &lt;/p&gt;

&lt;p&gt;I'm taking AI-related course on DeepLearning.AI and I really looking forward to connecting some study partners! As I said before, I'm from China so I'm still learning and improving English. If my writing isn't natural enough, I hope you won't mind and will point out any problems I've made, I'm really appreciate the feedback.❤️&lt;/p&gt;

&lt;p&gt;Also, if you curious about anything regarding China, feel free to reach out.&lt;/p&gt;

&lt;p&gt;Anyway, back to the point. I'll be publishing a series of posts covering Kaggle Competition on DEV to document my learning process. Stay tuned!&lt;/p&gt;

</description>
      <category>ai</category>
      <category>beginners</category>
      <category>learning</category>
    </item>
  </channel>
</rss>
