An adaptation of the Introduction to PyTorch* Lightning tutorial using Intel® Gaudi® AI processors.
In this tutorial, we go over the basics of Lightning by preparing models to train on the MNIST Handwritten Digits dataset.
Setup
This tutorial requires some packages besides pytorch-lightning.
! pip install --quiet "torchvision" "torchmetrics"
import os
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.plugins import HPUPrecisionPlugin
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
from pytorch_lightning import LightningDataModule
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256
Simplest Example
Here’s the simplest, most minimal example, with just a training loop (no validation, no testing).
class MNISTModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, 10)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_nb):
x, y = batch
loss = F.cross_entropy(self(x), y)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
By using the Trainer you automatically get:
- TensorBoard* logging
- Model checkpointing
- Training and validation loop
- Early-stopping
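Checkpointing and early stopping from the list above are driven by callbacks. The following is only a minimal sketch (the monitored metric "val_loss" and the settings below are illustrative assumptions, not part of this tutorial's code); the minimal model above logs no validation metric, so this pattern applies to modules like the LitMNIST example further below:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Stop training when the monitored metric stops improving (assumes a logged "val_loss").
early_stopping = EarlyStopping(monitor="val_loss", patience=3, mode="min")

# Keep only the best checkpoint according to the same metric.
checkpointing = ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min")

trainer = Trainer(
    accelerator="hpu", devices=1,
    max_epochs=3,
    callbacks=[early_stopping, checkpointing],
)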
To enable PyTorch Lightning to use the HPU accelerator, simply pass the accelerator="hpu" parameter to the Trainer class.
# Init our model
mnist_model = MNISTModel()
# Init DataLoader from MNIST Dataset
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
# Initialize a trainer; ensure that the accelerator is set to 'hpu' to run on Gaudi HPU
trainer = Trainer(
accelerator='hpu', devices=1, precision=16,
max_epochs=3,
callbacks=[TQDMProgressBar(refresh_rate=20)],
)
# Train the model
trainer.fit(mnist_model, train_loader)
Output (abridged): the MNIST archives are downloaded and extracted to ./MNIST/raw. Lightning reports GPU available: False, TPU available: False, IPU available: False, HPU available: True, using: 1 HPUs, with Habana mixed precision at opt_level O2. The model summary shows a single Linear layer with 7.9 K trainable parameters (0.016 MB total estimated size). A PossibleUserWarning notes that train_dataloader has few workers and suggests increasing num_workers. Training runs for three epochs (Epoch 2: 235/235 batches, ~50 it/s, loss=0.416) and Trainer.fit stops once max_epochs=3 is reached.
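A side note on mixed precision: passing precision=16 to the Trainer is what enables Habana mixed precision (the hmp/opt_level O2 lines in the output above). The HPUPrecisionPlugin imported at the top of this script is an alternative way to opt in explicitly; the following is only a sketch, assuming the plugin constructor accepts a precision argument as in the HPU-enabled Lightning releases, so check your installed version's signature:
# Roughly equivalent to precision=16 on HPU, passed as an explicit plugin
# (constructor arguments are an assumption; verify against your Lightning version).
trainer = Trainer(
    accelerator="hpu", devices=1,
    plugins=[HPUPrecisionPlugin(precision=16)],
    max_epochs=3,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)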
A More Complete MNIST Lightning Module Example
Let's dive in a bit deeper and write a more complete LightningModule for MNIST.
This time, we bake all the dataset-specific pieces directly into the LightningModule. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.
Note what the following built-in functions are doing:

prepare_data() 💾
- This is where we can download the dataset. We point to our desired dataset and ask Torchvision's MNIST dataset class to download it if the dataset isn't found there.
- Note we do not make any state assignments in this function (that is, self.something = ...).

setup(stage) ⚙️
- Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).
- setup expects a 'stage' arg, which is used to separate logic for 'fit' and 'test'.
- If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit'-related setup and 'test'-related setup to run whenever None is passed to stage (or ignore it altogether and exclude any conditionals).
- Note this runs across all devices and it is safe to make state assignments here.

train_dataloader(), val_dataloader(), and test_dataloader() ♻️
- These all return PyTorch DataLoader instances that are created by wrapping the respective datasets that we prepared in setup().
class LitMNIST(LightningModule):
def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
super().__init__()
# Set our init args as class attributes
self.data_dir = data_dir
self.hidden_size = hidden_size
self.learning_rate = learning_rate
# Hardcode some dataset specific attributes
self.num_classes = 10
self.dims = (1, 28, 28)
channels, width, height = self.dims
self.transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)
# Define PyTorch model
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(channels * width * height, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, self.num_classes),
)
self.accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
def forward(self, x):
x = self.model(x)
return F.log_softmax(x, dim=1)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
self.accuracy(preds, y)
# Calling self.log will surface up scalars for you in TensorBoard
self.log("val_loss", loss, prog_bar=True)
self.log("val_acc", self.accuracy, prog_bar=True)
return loss
def test_step(self, batch, batch_idx):
# Here we just reuse the validation_step for testing
return self.validation_step(batch, batch_idx)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer
####################
# DATA RELATED HOOKS
####################
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
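As an aside, the data-related hooks above (prepare_data, setup, and the dataloaders) can also be factored out of the model into a standalone LightningDataModule, which is already imported at the top of this script. The following is only a sketch reusing the same transform and BATCH_SIZE defined earlier; this tutorial keeps everything inside LitMNIST instead:
class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir=PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        # Download only; no state assignments here.
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Safe to assign state here; this runs on every device.
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
With a data module you would call trainer.fit(model, datamodule=MNISTDataModule()) rather than relying on the dataloaders defined inside the model.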
Train the Model on Intel Gaudi Accelerators
Remember to enable PyTorch Lightning to use the HPU accelerator: simply pass the accelerator="hpu" parameter to the Trainer class.
model = LitMNIST()
# Initialize a trainer; ensure that the accelerator is set to 'hpu' to run on Gaudi HPU
trainer = Trainer(
accelerator='hpu', devices=1, precision=16,
max_epochs=3,
callbacks=[TQDMProgressBar(refresh_rate=20)],
)
trainer.fit(model)
Output (abridged): Lightning again reports HPU available: True, using: 1 HPUs, with Habana mixed precision at opt_level O2. The model summary shows 55.1 K trainable parameters (0.110 MB total estimated size). Training runs for three epochs with a validation pass after each; the progress bar reports val_loss falling from 0.420 to 0.267 and val_acc rising from 0.886 to 0.919, and Trainer.fit stops once max_epochs=3 is reached.
Testing
To test a model, call trainer.test(model).
Or, if you have just trained a model, you can simply call trainer.test() and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss).
trainer.test()
Output (abridged): because no model or checkpoint path was passed, Lightning warns that the best model of the previous fit call will be used (pass ckpt_path='best' or ckpt_path='last' to silence the warning) and restores the checkpoint from /root/gaudi-tutorials/Lightning/Introduction/lightning_logs/version_1/checkpoints/epoch=2-step=645.ckpt. The test run reports:

Test metric   DataLoader 0
val_acc       0.9214999675750732
val_loss      0.25735974311828613

[{'val_loss': 0.25735974311828613, 'val_acc': 0.9214999675750732}]
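As the warning above suggests, you can point trainer.test() at a specific checkpoint to silence it:
# Explicitly test against the best checkpoint saved during the previous fit.
trainer.test(ckpt_path="best")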
You can use the TensorBoard magic function to view the logs that Lightning has created for you!
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
Copyright (c) 2022 Habana Labs, Ltd. an Intel Company.
All rights reserved.
License
Licensed under a CC BY SA 4.0 license.
A derivative of Introduction to PyTorch Lightning by the PyTorch Lightning team.