{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/htpu/barryhan.net/blob/main/usaaio/notebooks/03_transformers_imdb_finetune.ipynb)\n",
    "\n",
    "# Fine-tune DistilBERT on IMDB Sentiment\n",
    "\n",
    "End-to-end Hugging Face workflow: load IMDB, tokenize, fine-tune\n",
    "`distilbert-base-uncased`, evaluate, run inference on free-form text, save the model.\n",
    "\n",
    "**Runtime.** ~8-12 minutes on a Colab T4/L4 GPU. We subsample IMDB to keep\n",
    "training short while still getting clearly >90% accuracy.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Install + import"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Colab usually has transformers; datasets/evaluate may need install.\n",
    "import sys, subprocess\n",
    "def pip(*args):\n",
    "    subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', *args], check=True)\n",
    "try:\n",
    "    import datasets, evaluate, transformers  # noqa\n",
    "except ImportError:\n",
    "    pip('transformers', 'datasets', 'evaluate', 'accelerate')\n",
    "    import datasets, evaluate, transformers  # noqa\n",
    "print('transformers', transformers.__version__)\n",
    "print('datasets    ', datasets.__version__)\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import (\n",
    "    AutoTokenizer, AutoModelForSequenceClassification,\n",
    "    TrainingArguments, Trainer, DataCollatorWithPadding,\n",
    ")\n",
    "import evaluate\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print('device:', device)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load IMDB and subsample"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "raw = load_dataset('imdb')\n",
    "print(raw)\n",
    "\n",
    "# Subsample to ~5k train / 2k test for a fast Colab run.\n",
    "SEED = 42\n",
    "small_train = raw['train'].shuffle(seed=SEED).select(range(5000))\n",
    "small_test  = raw['test'].shuffle(seed=SEED).select(range(2000))\n",
    "print(small_train)\n",
    "small_train[0]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Tokenize"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "MODEL = 'distilbert-base-uncased'\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
    "\n",
    "def tok(batch):\n",
    "    return tokenizer(batch['text'], truncation=True, max_length=256)\n",
    "\n",
    "train_tok = small_train.map(tok, batched=True, remove_columns=['text'])\n",
    "test_tok  = small_test.map(tok,  batched=True, remove_columns=['text'])\n",
    "train_tok = train_tok.rename_column('label', 'labels')\n",
    "test_tok  = test_tok.rename_column('label', 'labels')\n",
    "train_tok.set_format('torch'); test_tok.set_format('torch')\n",
    "print(train_tok.column_names)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Model"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=2)\n",
    "n = sum(p.numel() for p in model.parameters())\n",
    "print(f\"params: {n/1e6:.1f}M\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Metrics + Trainer"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "acc_metric = evaluate.load('accuracy')\n",
    "f1_metric  = evaluate.load('f1')\n",
    "\n",
    "def compute_metrics(pred_pair):\n",
    "    logits, labels = pred_pair\n",
    "    preds = np.argmax(logits, axis=-1)\n",
    "    return {\n",
    "        'accuracy': acc_metric.compute(predictions=preds, references=labels)['accuracy'],\n",
    "        'f1':       f1_metric.compute(predictions=preds, references=labels, average='binary')['f1'],\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
    "\n",
    "args = TrainingArguments(\n",
    "    output_dir='distilbert-imdb',\n",
    "    num_train_epochs=2,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=64,\n",
    "    learning_rate=2e-5,\n",
    "    weight_decay=0.01,\n",
    "    eval_strategy='epoch',\n",
    "    save_strategy='epoch',\n",
    "    logging_steps=50,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model='accuracy',\n",
    "    report_to='none',\n",
    "    fp16=torch.cuda.is_available(),\n",
    "    seed=SEED,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=train_tok,\n",
    "    eval_dataset=test_tok,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Train"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "trainer.train()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Final evaluation"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "metrics = trainer.evaluate()\n",
    "metrics\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Inference on custom strings"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "clf = pipeline('text-classification', model=trainer.model, tokenizer=tokenizer,\n",
    "               device=0 if torch.cuda.is_available() else -1)\n",
    "LABELS = {'LABEL_0': 'negative', 'LABEL_1': 'positive'}\n",
    "\n",
    "samples = [\n",
    "    \"An absolute triumph -- gorgeous cinematography and a heartfelt script.\",\n",
    "    \"I wanted my two hours back. Dull, derivative, and the acting was wooden.\",\n",
    "    \"It was fine. Not great, not terrible, just there.\",\n",
    "    \"A masterclass in pacing and tension; one of the best films of the decade.\",\n",
    "]\n",
    "for s in samples:\n",
    "    out = clf(s, truncation=True)[0]\n",
    "    print(f\"[{LABELS.get(out['label'], out['label']):8s}  {out['score']:.3f}]  {s}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Save the model"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "SAVE_DIR = 'distilbert-imdb-final'\n",
    "trainer.save_model(SAVE_DIR)\n",
    "tokenizer.save_pretrained(SAVE_DIR)\n",
    "print('saved to', SAVE_DIR)\n",
    "import os\n",
    "for f in os.listdir(SAVE_DIR):\n",
    "    print(' ', f)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What to try next\n",
    "\n",
    "- Train on the full 25k/25k IMDB split (10-15 min on T4) -- typically 93-94% accuracy.\n",
    "- Swap `distilbert-base-uncased` for `roberta-base` or `bert-base-uncased`.\n",
    "- Try LoRA / PEFT to fine-tune only a small adapter -- fits a larger base model in the same memory.\n",
    "- Build a calibration plot -- softmax outputs aren't always well-calibrated probabilities.\n",
    "\n",
    "Back to [Transformers](../transformers.html) on the USAAIO site.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
