Olympic PyTorch Documentation¶
Olympic implements a Keras-like API for PyTorch.
The goal of Olympic is to combine the joy of PyTorch's dynamic graph execution with the joy of Keras's high-level abstractions for training. Concretely, Olympic contains:
- The olympic.fit() function. This implements a very similar API to Keras's model.fit and model.fit_generator methods in a more functional and less object-oriented fashion, and spares you the effort of "hand-rolling" your own training loop.
- Callback objects that perform functionality common to most deep learning training pipelines, such as learning rate scheduling, model checkpointing and CSV logging. These integrate into olympic.fit() and spare you the effort of writing boilerplate code.
- Some helpful utility functions, such as common metrics and some convenience layers from Keras that are missing in PyTorch.
About Olympic PyTorch¶
My first foray into deep learning code was TensorFlow. I (and many others) found TensorFlow to be powerful but unwieldy. Next I moved on to Keras, which is a brilliant library that makes deep learning very accessible, as it strips away most of the boilerplate code.
As I started to want more control and to implement research architectures, I turned to PyTorch, as its dynamic graph and clean interface made it not only relatively easy to use but also fun. However, I missed some of the abstractions and utilities of Keras.
There are other libraries similar to this one (notably ignite and torchsample) but they weren't quite what I wanted, so I decided to make what I wanted myself. And by make I mean copy and paste from Keras (MIT license), because don't fix what ain't broken.
Future development¶
I intend to update this library only enough to keep it compatible with the latest PyTorch and to maintain feature parity with Keras Callbacks. I will not be adding any features beyond what already exists.
Quickstart¶
This quickstart guide gives a minimal code example using Olympic. The example is also available as a Jupyter notebook at olympic-pytorch/notebooks/Quickstart.ipynb.
First, make all of the necessary imports:
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms, datasets
from multiprocessing import cpu_count
import olympic
Create the datasets:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train = datasets.MNIST('', train=True, transform=transform, download=True)
val = datasets.MNIST('', train=False, transform=transform, download=True)
train_loader = DataLoader(train, batch_size=128, num_workers=cpu_count())
val_loader = DataLoader(val, batch_size=128, num_workers=cpu_count())
Define the network:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
Instantiate the network, loss and optimiser:
model = Net()
optimiser = optim.SGD(model.parameters(), lr=0.1)
# The network outputs log-probabilities (log_softmax), so NLLLoss is the matching loss
loss_fn = nn.NLLLoss()
Create the desired callbacks:
callbacks = [
# Evaluates every epoch on val_loader
olympic.callbacks.Evaluate(val_loader),
# Saves model with best val_accuracy
olympic.callbacks.ModelCheckpoint('model.pt', save_best_only=True, monitor='val_accuracy'),
# Logs all metrics
olympic.callbacks.CSVLogger('log.csv')
]
Call olympic.fit:
olympic.fit(
model,
optimiser,
loss_fn,
dataloader=train_loader,
epochs=10,
metrics=['accuracy'],
callbacks=callbacks
)
You should see output like this:
Begin training...
Epoch 1: 26%|██▌ | 122/469 [00:03<00:09, 35.70it/s, loss=0.515, accuracy=0.867]
The network will train for 10 epochs. Afterwards the current directory will contain both model.pt and log.csv, which should look something like this:
epoch,accuracy,loss,val_accuracy,val_loss
1,0.7888348436389482,0.6585237751605668,0.9437,0.1692712503015995
2,0.9093039267945985,0.3049919113421491,0.9712,0.08768766190297901
3,0.9272832267235251,0.24685336495322713,0.9745,0.07711423026025295
4,0.9375388681592041,0.21396846514044285,0.9777,0.06789233392337338
5,0.9416588930348259,0.19915449465595203,0.9815,0.0603904211839661
6,0.9476168265813789,0.18155415136136735,0.9822,0.05375468297088519
7,0.9493048152096659,0.1694526430894571,0.984,0.04907846948835067
8,0.953008395522388,0.16376275851377356,0.9852,0.04469430861719884
9,0.9561122956645345,0.15457178367329621,0.9859,0.043301032841484996
10,0.9554237739872068,0.1532330308109522,0.9869,0.0410145413863007
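Since log.csv is a plain CSV file, you can load it with pandas to inspect or plot the training history, for example:
import pandas as pd

log = pd.read_csv('log.csv')
print(log[['epoch', 'accuracy', 'val_accuracy']])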
Differences between Olympic and Keras¶
Olympic has a few key differences from Keras: fit is a function rather than a method, and evaluation is a callback rather than having its own API.
fit function, not fit method¶
This is mostly personal preference, as I find this cleaner than creating a trainer object, “compiling” it and then calling trainer.fit(model). (The torchsample library does this in order to more closely resemble Keras, in which you must make a model.compile call.)
Evaluation is just another Callback¶
In Keras, the evaluation data is passed directly to the fit or fit_generator method of a model. However, I find it more consistent for evaluation on another dataset to be implemented as a Callback.
fit()¶
The olympic.fit function is the heart of this library and where all the good stuff happens. The aim of this function is to avoid “hand-rolling” your own training loops and hence present a much cleaner interface, like Keras or scikit-learn.
The pseudocode for fit is very simple:
def fit(model, optimiser, loss_fn, epochs, dataloader, metrics, callbacks, update_fn, update_fn_kwargs):
    callbacks.on_train_begin()
    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = dict()
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index)
            callbacks.on_batch_begin(batch_index, batch_logs)
            # Split the batch into inputs and targets (and move them to the right device)
            x, y = prepare_batch(batch)
            # Perform a single parameter update (gradient descent by default)
            loss, y_pred = update_fn(model, optimiser, loss_fn, x, y, epoch, **update_fn_kwargs)
            batch_logs['loss'] = loss.item()
            # Loop through all metrics and record them in the batch logs
            batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)
            callbacks.on_batch_end(batch_index, batch_logs)
        callbacks.on_epoch_end(epoch, epoch_logs)
    callbacks.on_train_end()
The default update_fn is just a regular gradient descent step (see below) but any callable with the right signature can be passed. Alternative update_fn implementations could be more involved, such as adversarial training or the Model-Agnostic Meta-Learning algorithm. For an example see fit/usage.
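For reference, a gradient descent update_fn consistent with the pseudocode above might look like the following. This is a sketch of the expected interface rather than the library's actual implementation:
def gradient_step(model, optimiser, loss_fn, x, y, epoch):
    """A single step of vanilla gradient descent."""
    model.train()
    optimiser.zero_grad()
    y_pred = model(x)          # Forward pass
    loss = loss_fn(y_pred, y)  # Compute loss
    loss.backward()            # Backpropagate gradients
    optimiser.step()           # Update parameters
    return loss, y_pred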
Using your own update_fn¶
TBC. See the repo oscarknagg/few-shot for some examples.
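In the meantime, here is a sketch of a custom update_fn that adds gradient clipping to the plain gradient descent step. The signature and the update_fn_kwargs mechanism are taken from the pseudocode above, so treat them as assumptions rather than a definitive reference:
import torch

def clipped_gradient_step(model, optimiser, loss_fn, x, y, epoch, max_norm=1.0):
    """Gradient descent step with gradient norm clipping."""
    model.train()
    optimiser.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    # Clip the gradient norm before the optimiser step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimiser.step()
    return loss, y_pred
This would then be passed to fit as update_fn=clipped_gradient_step, with update_fn_kwargs={'max_norm': 1.0}.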
evaluate()¶
The evaluate function is a convenience for evaluating the performance of a model on a particular dataset with different metrics. It can be incorporated into the training loop using the Evaluate callback.
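Standalone usage might look something like the following; the exact signature here is an assumption for illustration, so check the API reference:
# Hypothetical signature, for illustration only
logs = olympic.evaluate(model, val_loader, metrics=['accuracy'])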
Callbacks¶
A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view on the internal states and statistics of the model during training. You can pass a list of callbacks (as the keyword argument callbacks) to the olympic.fit() function. The relevant methods of the callbacks will then be called at each stage of training.
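The hook names in the fit pseudocode above (on_train_begin, on_epoch_begin, on_batch_end and so on) indicate what a custom callback can override. A minimal sketch, assuming a Keras-style Callback base class in olympic.callbacks:
class PrintEpochLogs(olympic.callbacks.Callback):
    """Prints the logs dict at the end of every epoch."""
    def on_epoch_end(self, epoch, logs=None):
        print('Epoch {}: {}'.format(epoch, logs))
An instance of this class would then be included in the callbacks list passed to olympic.fit().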
Layers¶
This module contains some convenient layers that exist in Keras but do not exist (in such a convenient and readable form) in PyTorch.
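As an illustration of the kind of layer this means, here is a Keras-style global max pooling layer written as a PyTorch module. The class name is hypothetical and not necessarily one this module actually provides:
from torch import nn

class GlobalMaxPool2d(nn.Module):
    """Global max pooling over the spatial dimensions, as in Keras's GlobalMaxPooling2D."""
    def forward(self, x):
        # x has shape (batch, channels, height, width)
        return x.max(dim=-1)[0].max(dim=-1)[0]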
Metrics¶
A metric is a function that is used to judge the performance of your model. Metric functions are supplied to the olympic.fit() function at training time.
A metric function is similar to a loss function, except that the results from evaluating a metric are not used when training the model.
You can either pass the name of an existing metric, or pass a PyTorch function.
Custom Metrics¶
Custom metrics can also be passed to olympic.fit. Custom metrics must take (y_true, y_pred) as arguments and return a single float as output. You should be able to pass any PyTorch loss function as a custom metric.
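For example, a categorical accuracy metric with the (y_true, y_pred) signature described above might be written as follows; the assumption here is that y_pred contains one score per class:
import torch

def categorical_accuracy(y_true, y_pred):
    """Fraction of samples where the highest-scoring class matches the true label."""
    return torch.eq(y_pred.argmax(dim=-1), y_true).float().mean().item()
This could then be passed to olympic.fit as metrics=[categorical_accuracy].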