{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/htpu/barryhan.net/blob/main/usaaio/notebooks/02_dl_cifar10_pytorch.ipynb)\n",
    "\n",
    "# CIFAR-10 Image Classification -- PyTorch\n",
    "\n",
    "A small CNN trained on CIFAR-10. We do two passes:\n",
    "\n",
    "1. **Baseline CNN** -- no augmentation, fast sanity check.\n",
    "2. **Augmented CNN** -- RandomCrop + HorizontalFlip + Cutout-style erasing, longer schedule.\n",
    "\n",
    "**Runtime.** ~3-5 minutes on a Colab T4/L4 GPU. Switch the runtime to GPU via\n",
    "`Runtime -> Change runtime type -> T4 GPU` before running.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Setup"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED); np.random.seed(SEED)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else\n",
    "                      ('mps' if torch.backends.mps.is_available() else 'cpu'))\n",
    "print('Using device:', device)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data -- CIFAR-10\n",
    "\n",
    "60,000 32x32 color images in 10 classes. `torchvision` downloads the binary."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "CIFAR_MEAN = (0.4914, 0.4822, 0.4465)\n",
    "CIFAR_STD  = (0.2470, 0.2435, 0.2616)\n",
    "\n",
    "basic_tf = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),\n",
    "])\n",
    "\n",
    "train_ds = datasets.CIFAR10(root='./data', train=True,  download=True, transform=basic_tf)\n",
    "test_ds  = datasets.CIFAR10(root='./data', train=False, download=True, transform=basic_tf)\n",
    "\n",
    "CLASSES = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')\n",
    "print(f\"train={len(train_ds)}  test={len(test_ds)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Peek at a few training images.\n",
    "def unnormalize(img):\n",
    "    img = img.clone()\n",
    "    for c, (m, s) in enumerate(zip(CIFAR_MEAN, CIFAR_STD)):\n",
    "        img[c] = img[c] * s + m\n",
    "    return img.clamp(0, 1)\n",
    "\n",
    "fig, axes = plt.subplots(2, 6, figsize=(10, 4))\n",
    "for ax in axes.flat:\n",
    "    i = np.random.randint(len(train_ds))\n",
    "    img, lbl = train_ds[i]\n",
    "    ax.imshow(unnormalize(img).permute(1, 2, 0).numpy())\n",
    "    ax.set_title(CLASSES[lbl], fontsize=9)\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(); plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "BATCH = 128\n",
    "train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=2, pin_memory=True)\n",
    "test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=2, pin_memory=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. A small CNN\n",
    "\n",
    "3 conv blocks -> global avg pool -> linear. ~200K params, fits easily on T4."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "class SmallCNN(nn.Module):\n",
    "    def __init__(self, num_classes=10):\n",
    "        super().__init__()\n",
    "        def block(in_c, out_c):\n",
    "            return nn.Sequential(\n",
    "                nn.Conv2d(in_c, out_c, 3, padding=1, bias=False),\n",
    "                nn.BatchNorm2d(out_c),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(out_c, out_c, 3, padding=1, bias=False),\n",
    "                nn.BatchNorm2d(out_c),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.MaxPool2d(2),\n",
    "            )\n",
    "        self.b1 = block(3, 32)    # 32 -> 16\n",
    "        self.b2 = block(32, 64)   # 16 -> 8\n",
    "        self.b3 = block(64, 128)  # 8  -> 4\n",
    "        self.pool = nn.AdaptiveAvgPool2d(1)\n",
    "        self.fc = nn.Linear(128, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.b1(x); x = self.b2(x); x = self.b3(x)\n",
    "        x = self.pool(x).flatten(1)\n",
    "        return self.fc(x)\n",
    "\n",
    "model = SmallCNN().to(device)\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"params: {n_params:,}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Training loop with LR schedule"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def run_eval(model, loader):\n",
    "    model.train(False)\n",
    "    correct = total = 0\n",
    "    loss_sum = 0.0\n",
    "    crit = nn.CrossEntropyLoss(reduction='sum')\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            out = model(x)\n",
    "            loss_sum += crit(out, y).item()\n",
    "            correct += (out.argmax(1) == y).sum().item()\n",
    "            total += y.size(0)\n",
    "    return loss_sum / total, correct / total\n",
    "\n",
    "\n",
    "def train_one_epoch(model, loader, optimizer, scheduler):\n",
    "    model.train()\n",
    "    crit = nn.CrossEntropyLoss()\n",
    "    running = 0.0\n",
    "    for x, y in loader:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        out = model(x)\n",
    "        loss = crit(out, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "        running += loss.item() * y.size(0)\n",
    "    return running / len(loader.dataset)\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "EPOCHS = 8\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)\n",
    "scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
    "    optimizer, max_lr=0.1,\n",
    "    steps_per_epoch=len(train_loader), epochs=EPOCHS,\n",
    ")\n",
    "history = {'train_loss': [], 'val_loss': [], 'val_acc': []}\n",
    "\n",
    "t0 = time.time()\n",
    "for epoch in range(1, EPOCHS + 1):\n",
    "    tl = train_one_epoch(model, train_loader, optimizer, scheduler)\n",
    "    vl, va = run_eval(model, test_loader)\n",
    "    history['train_loss'].append(tl); history['val_loss'].append(vl); history['val_acc'].append(va)\n",
    "    print(f\"epoch {epoch:2d}  train_loss={tl:.3f}  val_loss={vl:.3f}  val_acc={va*100:5.2f}%  \"\n",
    "          f\"elapsed={time.time()-t0:5.1f}s\")\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(11, 4))\n",
    "axes[0].plot(history['train_loss'], label='train'); axes[0].plot(history['val_loss'], label='val')\n",
    "axes[0].set_xlabel('epoch'); axes[0].set_ylabel('loss'); axes[0].legend(); axes[0].set_title('Loss')\n",
    "axes[1].plot([a*100 for a in history['val_acc']]); axes[1].set_xlabel('epoch')\n",
    "axes[1].set_ylabel('val accuracy (%)'); axes[1].set_title('Validation accuracy')\n",
    "plt.tight_layout(); plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Save & load a checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "ckpt = {'model_state': model.state_dict(), 'val_acc': history['val_acc'][-1]}\n",
    "torch.save(ckpt, 'cifar_small_cnn.pt')\n",
    "\n",
    "# Reload sanity check.\n",
    "fresh = SmallCNN().to(device)\n",
    "fresh.load_state_dict(torch.load('cifar_small_cnn.pt')['model_state'])\n",
    "print('reload val_acc:', run_eval(fresh, test_loader)[1])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Upgrade -- data augmentation\n",
    "\n",
    "`RandomCrop(32, padding=4)` + `RandomHorizontalFlip` is the standard CIFAR recipe,\n",
    "plus `RandomErasing` as a cheap stand-in for Cutout. We re-train from scratch\n",
    "and expect ~3-5% accuracy gain.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "aug_tf = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),\n",
    "    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2)),  # Cutout-like\n",
    "])\n",
    "aug_train = datasets.CIFAR10(root='./data', train=True, download=False, transform=aug_tf)\n",
    "aug_loader = DataLoader(aug_train, batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "model2 = SmallCNN().to(device)\n",
    "EPOCHS2 = 12\n",
    "optimizer = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)\n",
    "scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
    "    optimizer, max_lr=0.1,\n",
    "    steps_per_epoch=len(aug_loader), epochs=EPOCHS2,\n",
    ")\n",
    "history2 = {'train_loss': [], 'val_loss': [], 'val_acc': []}\n",
    "t0 = time.time()\n",
    "for epoch in range(1, EPOCHS2 + 1):\n",
    "    tl = train_one_epoch(model2, aug_loader, optimizer, scheduler)\n",
    "    vl, va = run_eval(model2, test_loader)\n",
    "    history2['train_loss'].append(tl); history2['val_loss'].append(vl); history2['val_acc'].append(va)\n",
    "    print(f\"epoch {epoch:2d}  train_loss={tl:.3f}  val_acc={va*100:5.2f}%  elapsed={time.time()-t0:5.1f}s\")\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "print(f\"Baseline final  val_acc = {history['val_acc'][-1]*100:.2f}%\")\n",
    "print(f\"Augmented final val_acc = {history2['val_acc'][-1]*100:.2f}%\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Confusion matrix on the test set"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "model2.train(False)\n",
    "preds, trues = [], []\n",
    "with torch.no_grad():\n",
    "    for x, y in test_loader:\n",
    "        out = model2(x.to(device))\n",
    "        preds.extend(out.argmax(1).cpu().tolist()); trues.extend(y.tolist())\n",
    "cm = pd.crosstab(pd.Series(trues, name='true'), pd.Series(preds, name='pred'))\n",
    "cm.index = CLASSES; cm.columns = CLASSES\n",
    "plt.figure(figsize=(8, 6))\n",
    "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')\n",
    "plt.title('Test-set confusion matrix (augmented model)')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What to try next\n",
    "\n",
    "- Bigger model (ResNet-20, ~270K params) -- usually +5%.\n",
    "- MixUp / CutMix augmentation -- strong regularizer.\n",
    "- Cosine LR with warmup, longer training (50+ epochs).\n",
    "- Test-time augmentation (TTA): average predictions over flipped inputs.\n",
    "\n",
    "Back to [Deep Learning](../dl.html) on the USAAIO site.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
