%reload_ext autoreload
%autoreload 2
%matplotlib inline
An infinitely customizable training loop
In this section we're going to create a highly customizable training loop using callbacks.
Get some data and put it in some Datasets
#export
from exp.nb_03 import *
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import ipywidgets as widgets
x_train, y_train, x_valid, y_valid = get_data()
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)
loss_func = F.cross_entropy
nh = 50
bs = 64
c = 10
Our previous fit()
was overloaded with components.
Let's package some of those components using factory methods and clean up our training loop.
First, we'll bunch the data together into a DataBunch
class. This will enable us to hold everything we need data-wise for training in one object - data
- that can then be pulled from depending on the mode (training or validation).
#export
class DataBunch():
def __init__(self, train_dl, valid_dl, c=None):
self.train_dl, self.valid_dl, self.c = train_dl, valid_dl, c
@property
def valid_ds(self): return self.valid_dl.dataset
@property
def train_ds(self): return self.train_dl.dataset
We can now define a data
object from which we can grab batches of training or validation data.
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)
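A quick sanity check (just an assertion sketch, assuming get_dls wraps each Dataset in a standard DataLoader that keeps a .dataset reference): the properties should hand back exactly the datasets we passed in.
assert data.train_ds is train_ds
assert data.valid_ds is valid_ds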
The @property
decorator lets us define getter (and, with a setter, assignment) behavior for an attribute; here it simply exposes each DataLoader's underlying dataset as a convenient read-only attribute.
xb, yb = next(iter(data.train_dl))
xb.shape
class person:
def __init__(self, name="Guest"):
self.__name=name
def setname(self, name):
self.__name=name
def getname(self):
return self.__name
p1=person()
p1.setname('Bill')
p1.getname()
class person:
def __init__(self):
self.__name=''
def setname(self, name):
print('setname() called')
self.__name=name
def getname(self):
print('getname() called')
return self.__name
name=property(getname, setname)
p1=person()
p1.name="Steve"
p1.name
class person:
def __init__(self):
self.__name= 'Not Implemented'
@property
def name(self):
return self.__name
@name.setter
def name(self, value):
self.__name=value
p1 = person()
p1.name
p1.name = 'Ted'
p1.name
#export
def get_model(data, lr=0.5, nh=50):
features = data.train_ds.x.shape[1]
model = nn.Sequential(nn.Linear(features, nh), nn.ReLU(), nn.Linear(nh, data.c))
return model, optim.SGD(model.parameters(), lr=lr)
The Learner
is basically our highest level abstraction.
It is just a container - like DataBunch
for the data - with no logic of its own, holding everything we need in our loop.
#export
class Learner():
def __init__(self, model, opt, loss_func, data):
self.model = model
self.opt = opt
self.loss_func = loss_func
self.data = data
learn = Learner(*get_model(data), loss_func, data)
Now let's see what our fit looks like now that everything is encapsulated in one object.
def fit(learner, epochs):
for epoch in range(epochs):
learner.model.train()
for xb,yb in learner.data.train_dl:
loss = learner.loss_func(learner.model(xb), yb)
loss.backward()
learner.opt.step()
learner.opt.zero_grad()
learner.model.eval()
with torch.no_grad():
tot_loss,tot_acc = 0.,0.
for xb,yb in learner.data.valid_dl:
pred = learner.model(xb)
tot_loss += learner.loss_func(pred, yb)
tot_acc += accuracy(pred, yb)
nv = len(learner.data.valid_dl)
print(epoch, tot_loss/nv, tot_acc/nv)
return tot_loss/nv, tot_acc/nv
loss, acc = fit(learn, 1)
Let's abstract away a bit more of the training loop and insert some callbacks!
We'll again strip away as much as we can from the fit
function. For each epoch in epochs we want to make two calls: process all batches in the training set and all batches in the validation set.
Those calls will pass along the appropriate data and then they will be processed by all_batches
and then one_batch
.
This is a classic extract-function refactoring.
Without the callbacks it will look something like this:
def one_batch(xb,yb):
pred = model(xb)
loss = loss_func(pred, yb)
loss.backward()
opt.step()
opt.zero_grad()
def all_batches(dl):
for b in dl:
one_batch(*b)
def fit():
for epoch in range(epochs):
all_batches(learn.data.train_dl)
with torch.no_grad():
all_batches(learn.data.valid_dl)
def one_batch(xb, yb, cb):
if not cb.begin_batch(xb, yb): return # if returns False stop
loss = cb.learn.loss_func(cb.learn.model(xb), yb)
if not cb.after_loss(loss): return
loss.backward()
if cb.after_backward(): cb.learn.opt.step()
if cb.after_step(): cb.learn.opt.zero_grad()
def all_batches(dl, cb):
for xb,yb in dl:
one_batch(xb, yb, cb)
if cb.do_stop(): return # if do_stop is True stop the loop
def fit(epochs, learn, cb):
if not cb.begin_fit(learn): return # the callbackhandler gets passed the learner
for epoch in range(epochs):
if not cb.begin_epoch(epoch): continue
all_batches(learn.data.train_dl, cb)
if cb.begin_validate():
with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
if cb.do_stop() or not cb.after_epoch(): break
cb.after_fit()
Now the key to this refactoring is that there is this cb
object that gets passed to fit
and is used everywhere. This is going to be the CallbackHandler.
Callbacks are used frequently in backend development. Web frameworks built on Node.js are often filled with callbacks.
They are less often used in Python ML development. So let's make some simple examples to demonstrate the idea.
def f(o): print('hi')
w = widgets.Button(description="click me")
If we simply instantiate a Button object, it does not have any functionality.
w
w.on_click(f)
What we need to do is give the Button object an action or function to call when the click event occurs.
The click will then call back to the function we gave it with the on_click
method.
w
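To see the mechanics without a real widget, here is a minimal stand-in (FakeButton is hypothetical, purely for illustration): it stores whatever handler we register and calls it back when the "event" fires, passing itself along just as ipywidgets does.
class FakeButton():
    def __init__(self): self.handler = None
    def on_click(self, handler): self.handler = handler
    def click(self):
        # fire the "event": call back into the registered handler
        if self.handler: self.handler(self)
b = FakeButton()
b.on_click(f)
b.click()  # prints 'hi'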
Let's start by making a process that takes a short amount of time to execute - kind of like a training loop.
It iterates through something and while that is happening we can do nothing but wait. We don't see any output until the entire process is over.
from time import sleep
def slow_calculator():
res = 0
for i in range(5):
res += i*i
sleep(i)
return res
slow_calculator()
Now let's give the calculator function an optional arg that is called as a function during the process:
def slow_calculator(cb=None):
res = 0
for i in range(5):
res += i*i
sleep(i)
if cb: cb(i)
return res
And we need to write some sort of callback function that will be called:
def show_progress(num): print(f"Epoch {num}: Another one down.")
slow_calculator(cb=show_progress)
We can write the above expression in a number of different ways.
The first is using a lambda expression:
slow_calculator(cb=lambda num: print(f"Epoch {num}"))
And we can even call another function from inside our lambda:
def show_progress(epoch, exclamation): print(f"Epoch {epoch} {exclamation}")
shout = "nice, very nice."
slow_calculator(lambda x: show_progress(x, shout))
Or a simpler way of achieving the same end would be:
def show_progress(exclamation):
_inner = lambda x: print(f"Epoch {x} {exclamation}")
return _inner
slow_calculator(show_progress("Yes Yes Yes"))
And again, this time without a lambda:
def show_progress(exclamation):
def _inner(x): print(f"Epoch {x} {exclamation}")
return _inner
slow_calculator(show_progress("Inner def"))
And we can just save this to a variable and pass it:
var = show_progress("clean and concise")
slow_calculator(var)
Partials are another very handy tool for working with closures:
from functools import partial
def make_show_progress(exclamation, epoch):
print(f"{exclamation}! We've finished epoch {epoch}!")
slow_calculator(partial(make_show_progress, "partially now"))
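partial binds arguments left to right: it fixes exclamation here, and the remaining positional slot (epoch) gets filled when slow_calculator calls the callback with i. Spelled out (f2 is just a throwaway name):
f2 = partial(make_show_progress, "partially now")
f2(3)  # equivalent to make_show_progress("partially now", 3)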
It's easy to imagine how we could make a callback class which has the same functionality as above:
class ShowProgressCallback():
def __init__(self, callout):
self.callout = callout
def __call__(self, epoch):
print(f'Epoch number {epoch} {self.callout}')
cb = ShowProgressCallback("NOW A CLASS")
slow_calculator(cb=cb)
*args and **kwargs
If we tweak our slow_calculator a bit we can add callbacks to different stages of the iterator (a quick refresher on *args and **kwargs follows the example below):
def slow_calculator(cb=None):
res = 0
for i in range(5):
if cb: cb.before_calc(i, res)
res += i * i
if cb: cb.after_calc(i, res)
sleep(1)
return res
class BeforeAfterCallback():
def __init__(self, before, after): self.before, self.after = before, after
def before_calc(self, *args): print(f"Here we are epoch {args[0]} about to calc and we're at {args[1]}: {self.before}")
def after_calc(self, *args): print(f"Boom! Finished epoch {args[0]} now at {args[1]}: {self.after}")
bacall = BeforeAfterCallback("Going to Calc!", "Done")
slow_calculator(bacall)
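And the promised refresher: *args packs any extra positional arguments into a tuple - which is how before_calc and after_calc above accept (epoch, res) without naming them - while **kwargs packs extra keyword arguments into a dict (demo is just a throwaway name):
def demo(*args, **kwargs):
    print(f"args={args} kwargs={kwargs}")
demo(1, 2, lr=0.1)  # args=(1, 2) kwargs={'lr': 0.1}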
Now we can have as much control over the iteration as we want by adding callbacks that make their own calculations and shape the calculator's behavior:
def slow_calculation(cb=None):
res = 0
for i in range(5):
if cb and hasattr(cb, "before_calc"): cb.before_calc(i)
res += i*i
sleep(1)
if cb and hasattr(cb, "after_calc"):
if cb.after_calc(i, res):
print("Early stopping: returned True")
break
return res
By checking whether a callback exists and using hasattr
in our loop, our custom callback can now omit before_calc
entirely and the process will continue without throwing an error.
class PrintAfterCallback():
def after_calc(self, epoch, val):
print(f"Epoch {epoch} : {val}")
if val > 12:
return True
PC = PrintAfterCallback()
slow_calculation(PC)
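We can confirm the guard is doing the work: PrintAfterCallback defines no before_calc, yet the loop ran without error.
hasattr(PC, 'before_calc'), hasattr(PC, 'after_calc')  # (False, True)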
The next stage of abstraction is to build the slow_calculator
as a class unto itself, with a callback method that looks up the appropriate handler method at each stage.
class SlowCalculator():
def __init__(self, cb=None):
self.cb = cb
self.res = 0
def callback(self, name, *args):
if not self.cb: return
callback = getattr(self.cb, name, None)
if callback: return callback(*args)
def calc(self):
for i in range(5):
self.callback("before_calc", i, self.res)
self.res += i*i
sleep(1)
if self.callback("after_calc", i, self.res):
print('Early Stopping')
break
return self.res
class ShowProgress():
def before_calc(self, epoch, res, *args):
print(f"Epoch: {epoch} Res: {res}")
def after_calc(self, epoch, *args):
if epoch == 3:
return True
SP = ShowProgress()
calc = SlowCalculator(cb=SP)
calc.calc()
Every custom callback we write must inherit from a Callback
parent class, which defines a default method for each step in the loop where callbacks are invoked.
#export
class Callback():
def begin_fit(self, learn):
self.learn = learn
return True
def after_fit(self):
return True
def begin_epoch(self, epoch):
self.epoch = epoch
return True
def begin_validate(self):
return True
def after_epoch(self):
return True
def begin_batch(self, xb, yb):
self.xb, self.yb = xb, yb
return True
def after_loss(self, loss):
self.loss = loss
return True
def after_backward(self):
return True
def after_step(self):
return True
Now for our main container: the CallbackHandler.
The job of the Handler is to hold everything and to be called at every step in the fit process.
fit
starts with cb.begin_fit(learn)
: the learner object (which is basically a container for all of our components - model, data, optimizer, and loss function) is made an attribute of the CallbackHandler.
And everything is called from this cb.learn
, therefore, it is the core of the process.
#export
class CallbackHandler():
def __init__(self, cbs=None):
self.cbs = cbs if cbs else []
def begin_fit(self, learn):
self.learn = learn
self.in_train = True
learn.stop = False
res = True
for cb in self.cbs: res = res and cb.begin_fit(learn)
return res
def after_fit(self):
res = not self.in_train
for cb in self.cbs: res = res and cb.after_fit()
return res
def begin_epoch(self, epoch):
self.learn.model.train()
self.in_train = True
res = True
for cb in self.cbs: res = res and cb.begin_epoch(epoch)
return res
def begin_validate(self):
    self.learn.model.eval()
    self.in_train = False
    res = True
    for cb in self.cbs: res = res and cb.begin_validate()
    return res
def after_epoch(self):
res = True
for cb in self.cbs: res = res and cb.after_epoch()
return res
def begin_batch(self, xb, yb):
res = True
for cb in self.cbs: res = res and cb.begin_batch(xb,yb)
return res
def after_loss(self, loss):
res = self.in_train
for cb in self.cbs: res = res and cb.after_loss(loss)
return res
def after_backward(self):
res = True
for cb in self.cbs: res = res and cb.after_backward()
return res
def after_step(self):
res = True
for cb in self.cbs: res = res and cb.after_step()
return res
def do_stop(self):
try: return self.learn.stop
finally: self.learn.stop = False
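One subtlety worth a sketch: in do_stop, the finally clause runs after the return value has been captured, so a single call both reads the stop flag and resets it. A minimal illustration (StopFlag is hypothetical):
class StopFlag():
    stop = True
    def do_stop(self):
        try: return self.stop        # the current value is captured first...
        finally: self.stop = False   # ...then the flag is reset
s = StopFlag()
s.do_stop()  # True
s.do_stop()  # False - reading it reset it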
Let's make a TestCallback that stops training after a given number of iterations:
class TestCallback(Callback):
def begin_fit(self, learn):
super().begin_fit(learn)
self.n_iters = 0
return True
def after_step(self):
self.n_iters += 1
print(self.n_iters)
if self.n_iters>=10: self.learn.stop = True
return True
fit(1, learn, cb=CallbackHandler([TestCallback()]))
We need to reformulate this interaction between callbacks and our learner.
The CallbackHandler approach leaves a lot to be desired:
#export
import re
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
s1 = re.sub(_camel_re1, r'\1_\2', name)
return re.sub(_camel_re2, r'\1_\2', s1).lower()
camel2snake("CallbackHandler")
Our new callback class is simplified.
It does 4 things:
- it has an _order attribute that pertains to the order in which callbacks are called - this is important later
- it stores the runner as self.run when the set_runner method is called on it
- it delegates unknown attribute lookups via __getattr__ to self.run
- it exposes a name property that converts the callback's class name to snake_case
#export
class Callback():
_order = 0
def set_runner(self, run):
self.run = run
def __getattr__(self, k):
return getattr(self.run, k)
@property
def name(self):
name = re.sub(r'Callback$', '', self.__class__.__name__) # removes Callback from custom callback class name
return camel2snake(name or "callback")
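Note that __getattr__ is only consulted when normal attribute lookup fails, so anything the callback doesn't define itself is transparently fetched from the runner. A minimal sketch (DummyRun is hypothetical):
class DummyRun(): n_iter = 7
cb = Callback()
cb.set_runner(DummyRun())
cb.name, cb.n_iter  # ('callback', 7) - n_iter comes from the runner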
The new runner process is managed by callbacks in a more essential way, giving them more control and thus giving the user more ability to customize training:
#export
class TrainEvalCallback(Callback):
def begin_fit(self):
self.run.n_epochs = 0.
self.run.n_iter = 0
def after_batch(self):
if not self.in_train: return
self.run.n_epochs += 1./self.iters
self.run.n_iter += 1
def begin_epoch(self):
self.run.n_epochs = self.epoch
self.model.train()
self.run.in_train = True
def begin_validate(self):
self.model.eval()
self.run.in_train=False
cbname = 'TrainEvalCallback'
camel2snake(cbname)
TrainEvalCallback().name
#export
from typing import *
def listify(o):
if o is None: return []
if isinstance(o, list): return o
if isinstance(o, str): return [o]
if isinstance(o, Iterable): return list(o)
return [o]
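A few calls make the contract clear:
listify(None), listify('acc'), listify([1, 2]), listify(range(3))
# ([], ['acc'], [1, 2], [0, 1, 2])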
The Runner is the main piece of refactoring. It is the object that runs everything, and callbacks can easily be added to it to change its behavior. Again, we're attempting to encapsulate our components effectively so as to end up with an efficient and versatile training loop.
#export
class Runner():
def __init__(self, cbs=None, cb_funcs=None):
cbs = listify(cbs)
for cbf in listify(cb_funcs):
cb = cbf()  # cb_funcs are callables that construct callbacks, so instantiate here
setattr(self, cb.name, cb)
cbs.append(cb)
self.stop = False
self.cbs = [TrainEvalCallback()]+cbs
@property
def opt(self):
return self.learn.opt
@property
def model(self):
return self.learn.model
@property
def loss_func(self):
return self.learn.loss_func
@property
def data(self):
return self.learn.data
def one_batch(self, xb, yb):
self.xb, self.yb = xb, yb
if self('begin_batch'): return
self.pred = self.model(self.xb)
if self('after_pred'): return
self.loss = self.loss_func(self.pred, self.yb)
if self('after_loss') or not self.in_train: return # allows us to go between training and valid
self.loss.backward()
if self('after_backward'): return
self.opt.step()
if self('after_step'): return
self.opt.zero_grad()
def all_batches(self, dl):
self.iters = len(dl)
for xb, yb in dl:
if self.stop:
print('self.stop')
break
self.one_batch(xb, yb)
self('after_batch')
self.stop=False
def fit(self, epochs, learn):
self.epochs = epochs
self.learn = learn
try:
for cb in self.cbs: cb.set_runner(self) # passes self as the runner object to each callback
if self("begin_fit"): return
for epoch in range(epochs):
self.epoch = epoch
if not self('begin_epoch'):
self.all_batches(self.data.train_dl)
with torch.no_grad():
if not self('begin_validate'):
self.all_batches(self.data.valid_dl)
if self('after_epoch'): break
finally:
self('after_fit')
self.learn = None
def __call__(self, cb_name):
for cb in sorted(self.cbs, key=lambda x: x._order):
f = getattr(cb, cb_name, None)
if f and f(): return True
return False
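The sorted(self.cbs, key=lambda x: x._order) in __call__ is where _order earns its keep: lower values run first, regardless of registration order. A sketch (both callbacks are hypothetical, and epochs=0 so only the fit-level events fire):
class FirstCallback(Callback):
    _order = 0
    def begin_fit(self): print('first')
class SecondCallback(Callback):
    _order = 10
    def begin_fit(self): print('second')
run = Runner(cbs=[SecondCallback(), FirstCallback()])
run.fit(0, learn)  # prints 'first' then 'second' despite registration order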
Here is the TestCallback again. It will wait until there have been 10 batches/iterations then set self.run.stop
to True, which will stop any more batches from being processed.
class TestCallback(Callback):
def after_step(self):
print(f"Batch: {self.n_iter}")
if self.n_iter >= 10:
print(f'Early Stopping')
self.run.stop = True
def begin_validate(self):
self.run.stop = True
run = Runner(cbs=TestCallback())
learn = Learner(*get_model(data), loss_func, data)
run.fit(1, learn)
It would be handy if we could just pass our runner metric functions that are then given the predictions and true labels for each batch, to be accumulated and printed at the end of every epoch.
We can do this with a StatsCallback:
#export
class AvgStats():
def __init__(self, metrics, in_train):
self.metrics = listify(metrics)
self.in_train = in_train
def reset(self):
self.tot_loss = 0.
self.count = 0
self.tot_mets = [0.]*len(self.metrics)
@property
def all_stats(self): return [self.tot_loss.item()] + self.tot_mets
@property
def avg_stats(self): return [o/self.count for o in self.all_stats]
def __repr__(self):
if not self.count: return ""
return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"
def accumulate(self, run):
bn = run.xb.shape[0]
self.tot_loss += run.loss * bn
self.count += bn
for i, m in enumerate(self.metrics):
self.tot_mets[i] += m(run.pred, run.yb) * bn
#export
class AvgStatsCallback(Callback):
def __init__(self, metrics):
self.train_stats = AvgStats(metrics, True)
self.valid_stats = AvgStats(metrics, False)
def begin_epoch(self):
self.train_stats.reset()
self.valid_stats.reset()
def after_loss(self):
stats = self.train_stats if self.in_train else self.valid_stats
with torch.no_grad(): stats.accumulate(self.run)
def after_epoch(self):
print(self.train_stats)
print(self.valid_stats)
learn = Learner(*get_model(data), loss_func, data)
run = Runner(cbs=AvgStatsCallback([accuracy]))
run.fit(2, learn)
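Because AvgStatsCallback needs an argument, the cb_funcs route plus partial is a natural fit: the Runner calls the factory, instantiates the callback, and attaches it under its snake_case name. A sketch reusing pieces we already have:
acc_cbf = partial(AvgStatsCallback, accuracy)
learn = Learner(*get_model(data), loss_func, data)
run = Runner(cb_funcs=acc_cbf)
run.fit(1, learn)
run.avg_stats.valid_stats.avg_stats  # the callback is reachable by name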
!python notebook2script.py 04_databunch_learner_runner.ipynb