In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

DataBunch, Learner and Runner

An infinitely customizable training loop

In this section we're going to create a highly customizable training loop using callbacks.

Get Data

Get some data and put it in some Datasets

In [2]:
from exp.nb_03 import *
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

import ipywidgets as widgets
In [3]:
x_train, y_train, x_valid, y_valid = get_data()
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)
loss_func = F.cross_entropy
In [4]:
nh = 50
bs = 64
c = 10


Our previous fit() was overloaded with components.

Let's package some of those components using factory methods and clean up our training loop.

First, we'll bunch the data together into a DataBunch class. This will enable us to hold everything we need data-wise for training in one object - data - that can then be pulled from depending on the mode (training or validation).

In [5]:
class DataBunch():
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl, self.valid_dl, self.c = train_dl, valid_dl, c
    @property
    def train_ds(self): return self.train_dl.dataset
    @property
    def valid_ds(self): return self.valid_dl.dataset

We can now define a data object from which we can pull batches of training or validation data.

In [6]:
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

The @property decorator lets us expose the datasets behind the dataloaders as read-only attributes (data.train_ds, data.valid_ds) instead of method calls. It is purely a convenience.

In [7]:
xb, yb = next(iter(data.train_dl))
In [8]:
xb.shape
torch.Size([64, 784])


In [9]:
class person:
    def __init__(self, name="Guest"):
        self.__name = name
    def setname(self, name):
        self.__name = name
    def getname(self):
        return self.__name
In [10]:
In [11]:
In [12]:
class person:
    def __init__(self):
        self.__name = ''
    def setname(self, name):
        print('setname() called')
        self.__name = name
    def getname(self):
        print('getname() called')
        return self.__name
    name = property(getname, setname)  # attribute access now routes through getname/setname
In [13]:
setname() called
In [14]:
getname() called
In [15]:
class person:
    def __init__(self):
        self.__name = 'Not Implemented'
    @property
    def name(self):
        return self.__name
    @name.setter
    def name(self, value):
        self.__name = value
In [16]:
p1 = person()
p1.name
'Not Implemented'
In [17]:
p1.name = 'Ted'


In [18]:
def get_model(data, lr=0.5, nh=50):
    features = data.train_ds.x.shape[1]
    model = nn.Sequential(nn.Linear(features, nh), nn.ReLU(), nn.Linear(nh, data.c))
    return model, optim.SGD(model.parameters(), lr=lr)

The Learner is basically our highest level abstraction.

It is just a container - like DataBunch for the data - with no logic, for everything we need in our loop.

In [19]:
class Learner():
    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data
In [20]:
learn = Learner(*get_model(data), loss_func, data)

Now let's see what our fit looks like now that everything is encapsulated in one object.

In [21]:
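def fit(learner, epochs):
    for epoch in range(epochs):
        learner.model.train()
        for xb,yb in learner.data.train_dl:
            loss = learner.loss_func(learner.model(xb), yb)
            loss.backward()
            learner.opt.step()
            learner.opt.zero_grad()
        learner.model.eval()
        with torch.no_grad():
            tot_loss,tot_acc = 0.,0.
            for xb,yb in learner.data.valid_dl:
                pred = learner.model(xb)
                tot_loss += learner.loss_func(pred, yb)
                tot_acc  += accuracy(pred, yb)
        nv = len(learner.data.valid_dl)  # number of validation batches
        print(epoch, tot_loss/nv, tot_acc/nv)
    return tot_loss/nv, tot_acc/nv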
In [22]:
loss, acc = fit(learn, 1)
0 tensor(0.1716) tensor(0.9492)

Callback Handler version

Let's abstract away a bit more of the training loop and insert some callbacks!

We'll again strip away as much as we can from the fit function. For each epoch in epochs we want to make two calls: process all batches in the training set and all batches in the validation set.

Those calls will pass along the appropriate data and then they will be processed by all_batches and then one_batch.

This is Factory Method refactoring.

Without the callbacks it will look something like this:

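def one_batch(xb,yb):
    pred = model(xb)
    loss = loss_func(pred, yb)
    loss.backward()
    opt.step()
    opt.zero_grad()

def all_batches(dl):
    for b in dl: one_batch(*b)

def fit():
    for epoch in range(epochs):
        all_batches(data.train_dl)      # training pass
        with torch.no_grad():
            all_batches(data.valid_dl)  # validation pass (sketch only: the real version skips backward/step here)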
In [23]:
def one_batch(xb, yb, cb):
    if not cb.begin_batch(xb, yb): return # if returns False stop
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    if not cb.after_loss(loss): return # False during validation, so no backward pass
    loss.backward()
    if cb.after_backward(): cb.learn.opt.step()
    if cb.after_step(): cb.learn.opt.zero_grad()

def all_batches(dl, cb):
    for xb,yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop(): return # if do_stop is True stop the loop

def fit(epochs, learn, cb):
    if not cb.begin_fit(learn): return # the CallbackHandler gets passed the learner
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue
        all_batches(learn.data.train_dl, cb)
        if cb.begin_validate():
            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch(): break
    cb.after_fit()

The key to this refactoring is the cb object that gets passed to fit and is used everywhere. This is going to be the CallbackHandler.

Primer on Callbacks

Callbacks are used frequently in backend development. Web frameworks built on Node.js are often filled with callbacks.

They are less often used in Python ML development. So let's make some simple examples to demonstrate the idea.

Click callback

In [24]:
def f(o): print('hi')
In [25]:
w = widgets.Button(description="click me")

If we simply instantiate a Button object, it does not have any functionality.

In [26]:
In [27]:

What we need to do is give the Button object an action or function to call when the click event occurs.

The click will then call back to the function we gave it with the on_click method.

In [28]:
w.on_click(f)
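
To make the mechanism concrete without needing a widget front end, here is a minimal sketch of what a button might do internally. FakeButton, its click method, and record are hypothetical names invented for this illustration; they are not part of ipywidgets.

```python
class FakeButton:
    """Hypothetical stand-in for a widget; not part of ipywidgets."""
    def __init__(self, description):
        self.description = description
        self._handlers = []           # functions registered via on_click

    def on_click(self, handler):
        # Only store the function; nothing is called yet.
        self._handlers.append(handler)

    def click(self):
        # Simulate the click event: call back every registered handler.
        for handler in self._handlers:
            handler(self)

clicks = []
def record(button): clicks.append(button.description)

b = FakeButton("click me")
b.click()            # no handler registered yet: nothing happens
b.on_click(record)
b.click()            # the handler is called back with the button
print(clicks)        # ['click me']
```

The button never needs to know what record does; it only promises to call it when the event happens.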

Slow Calculator

Let's start by making a process that takes a short amount of time to execute - kind of like a training loop.

It iterates through something and while that is happening we can do nothing but wait. We don't see any output until the entire process is over.

In [29]:
from time import sleep

def slow_calculator():
    res = 0
    for i in range(5):
        res += i*i
        sleep(1)  # simulate slow work; the 1-second delay is an assumed value
    return res
In [30]:

Now let's give the calculator function an optional arg that is called as a function during the process:

In [31]:
def slow_calculator(cb=None):
    res = 0
    for i in range(5):
        res += i*i
        if cb: cb(i)
    return res

And we need to write some sort of callback function that will be called:

In [32]:
def show_progress(num): print(f"Epoch {num}: Another one down.")
In [33]:
slow_calculator(show_progress)
Epoch 0: Another one down.
Epoch 1: Another one down.
Epoch 2: Another one down.
Epoch 3: Another one down.
Epoch 4: Another one down.

Lambdas and Partials

We can write the above expression in a number of different ways.

The first is using a lambda expression:

In [34]:
slow_calculator(cb=lambda num: print(f"Epoch {num}"))
Epoch 0
Epoch 1
Epoch 2
Epoch 3
Epoch 4

And we can even embed another function inside our lambda:

In [35]:
def show_progress(epoch, exclamation): print(f"Epoch {epoch} {exclamation}")
In [36]:
shout = "nice, very nice."
slow_calculator(lambda x: show_progress(x, shout))
Epoch 0 nice, very nice.
Epoch 1 nice, very nice.
Epoch 2 nice, very nice.
Epoch 3 nice, very nice.
Epoch 4 nice, very nice.

Or a simpler way of achieving the same end would be:

In [37]:
def show_progress(exclamation):
    _inner = lambda x: print(f"Epoch {x} {exclamation}")
    return _inner
In [38]:
slow_calculator(show_progress("Yes Yes Yes"))
Epoch 0 Yes Yes Yes
Epoch 1 Yes Yes Yes
Epoch 2 Yes Yes Yes
Epoch 3 Yes Yes Yes
Epoch 4 Yes Yes Yes

And again, this time without a lambda:

In [39]:
def show_progress(exclamation):
    def _inner(x): print(f"Epoch {x} {exclamation}")
    return _inner
In [40]:
slow_calculator(show_progress("Inner def"))
Epoch 0 Inner def
Epoch 1 Inner def
Epoch 2 Inner def
Epoch 3 Inner def
Epoch 4 Inner def

And we can just save this to a variable and pass it:

In [41]:
var = show_progress("clean and concise")
slow_calculator(var)
Epoch 0 clean and concise
Epoch 1 clean and concise
Epoch 2 clean and concise
Epoch 3 clean and concise
Epoch 4 clean and concise

Func Tools and Partial

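functools.partial is another handy tool here: it fixes some of a function's arguments ahead of time and returns a new callable, so we don't have to write the closure ourselves.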

In [42]:
from functools import partial
In [43]:
def make_show_progress(exclamation, epoch):
    print(f"{exclamation}! We've finished epoch {epoch}!")
In [44]:
slow_calculator(partial(make_show_progress, "partially now"))
partially now! We've finished epoch 0!
partially now! We've finished epoch 1!
partially now! We've finished epoch 2!
partially now! We've finished epoch 3!
partially now! We've finished epoch 4!
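
As a quick check that the two approaches are interchangeable, the sketch below builds the same callback once as a closure and once with partial. The names report and make_report are made up for this example; the function returns a string instead of printing so the results can be compared.

```python
from functools import partial

def report(exclamation, epoch):
    return f"{exclamation}! We've finished epoch {epoch}!"

# Closure version: the inner function remembers `exclamation`.
def make_report(exclamation):
    def _inner(epoch): return report(exclamation, epoch)
    return _inner

via_closure = make_report("done")
via_partial = partial(report, "done")   # binds the first positional argument

print(via_closure(3))   # done! We've finished epoch 3!
print(via_partial(3))   # done! We've finished epoch 3!
```

Note that partial binds positional arguments from the left, which is why exclamation comes before epoch in the signature.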

Callbacks as Classes

It's easy to imagine how we could make a callback class which has the same functionality as above:

In [45]:
class ShowProgressCallback():
    def __init__(self, callout):
        self.callout = callout
    def __call__(self, epoch):
        print(f'Epoch number {epoch} {self.callout}')
In [46]:
cb = ShowProgressCallback("NOW A CLASS")
In [47]:
slow_calculator(cb)
Epoch number 0 NOW A CLASS
Epoch number 1 NOW A CLASS
Epoch number 2 NOW A CLASS
Epoch number 3 NOW A CLASS
Epoch number 4 NOW A CLASS

Multiple callback funcs; *args and **kwargs

If we tweak our slow_calculator a bit we can add callbacks at different stages of the iteration:

In [48]:
def slow_calculator(cb=None):
    res = 0
    for i in range(5):
        if cb: cb.before_calc(i, res)
        res += i * i
        if cb: cb.after_calc(i, res)
    return res
In [49]:
class BeforeAfterCallback():
    def __init__(self, before, after): self.before, self.after = before, after
    def before_calc(self, *args): print(f"Here we are epoch {args[0]} about to calc and we're at {args[1]}: {self.before}")
    def after_calc(self, *args): print(f"Boom! Finished epoch {args[0]} now at {args[1]}: {self.after}")
In [50]:
bacall = BeforeAfterCallback("Going to Calc!", "Done")
In [51]:
slow_calculator(bacall)
Here we are epoch 0 about to calc and we're at 0: Going to Calc!
Boom! Finished epoch 0 now at 0: Done
Here we are epoch 1 about to calc and we're at 0: Going to Calc!
Boom! Finished epoch 1 now at 1: Done
Here we are epoch 2 about to calc and we're at 1: Going to Calc!
Boom! Finished epoch 2 now at 5: Done
Here we are epoch 3 about to calc and we're at 5: Going to Calc!
Boom! Finished epoch 3 now at 14: Done
Here we are epoch 4 about to calc and we're at 14: Going to Calc!
Boom! Finished epoch 4 now at 30: Done

Modifying behavior

Now we can have as much control over the iteration as we want by adding callbacks into the process and having them make calculations and shape the behavior of the calculator:

In [52]:
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        if cb and hasattr(cb, "before_calc"): cb.before_calc(i)
        res += i*i
        if cb and hasattr(cb, "after_calc"):
            if cb.after_calc(i, res):
                print("Early stopping: returned True")
                break  # stop iterating once the callback asks us to
    return res

Because the loop checks that a callback exists and uses hasattr before each call, a custom callback can leave out before_calc (or after_calc) and the process will continue without throwing an error.

In [53]:
class PrintAfterCallback():
    def after_calc(self, epoch, val):
        print(f"Epoch {epoch} : {val}")
        if val > 12:
            return True
In [54]:
PC = PrintAfterCallback()
In [55]:
slow_calculation(PC)
Epoch 0 : 0
Epoch 1 : 1
Epoch 2 : 5
Epoch 3 : 14
Early stopping: returned True

The next stage of abstraction is to build the slow_calculator as a class unto itself, with a callback method that looks up the appropriate method on the callback object at each stage.

In [56]:
class SlowCalculator():
    def __init__(self, cb=None):
        self.cb = cb
        self.res = 0
    def callback(self, name, *args):
        if not self.cb: return
        callback = getattr(self.cb, name, None) 
        if callback: return callback(*args)
    def calc(self):
        for i in range(5):
            self.callback("before_calc", i, self.res)
            self.res += i*i
            if self.callback("after_calc", i, self.res):
                print('Early Stopping')
                break  # the callback returned True, so stop the loop
        return self.res
In [57]:
class ShowProgress():
    def before_calc(self, epoch, res, *args):
        print(f"Epoch: {epoch} Res: {res}")
    def after_calc(self, epoch, *args):
        if epoch == 3:
            return True
In [58]:
SP = ShowProgress()
In [59]:
calc = SlowCalculator(cb=SP)
In [60]:
calc.calc()
Epoch: 0 Res: 0
Epoch: 1 Res: 0
Epoch: 2 Res: 1
Epoch: 3 Res: 5
Early Stopping

Callback Parent Class

Every custom callback we write inherits from a Callback parent class, which defines a default method for each step of the loop at which the callbacks are invoked. Each default returns True, meaning "carry on".

In [61]:
class Callback():
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def after_fit(self):
        return True
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    def begin_validate(self):
        return True
    def after_epoch(self):
        return True
    def begin_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_backward(self):
        return True
    def after_step(self):
        return True

CallbackHandler Class

Now for our main container, the CallbackHandler.

The job of the handler is to hold all the callbacks and to be called at every step of the fit process.

fit starts with cb.begin_fit(learn): the learner object (which is basically a container for all of our components - model, data, optimizer, and loss function) is made an attribute of the CallbackHandler.

Everything is then accessed through cb.learn, which makes it the core of the process.

In [62]:
class CallbackHandler():
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []
    def begin_fit(self, learn):
        self.learn = learn
        self.in_train = True
        learn.stop = False
        res = True
        for cb in self.cbs: res = res and cb.begin_fit(learn)
        return res
    def after_fit(self):
        res = not self.in_train
        for cb in self.cbs: res = res and cb.after_fit()
        return res
    def begin_epoch(self, epoch):
        self.learn.model.train()
        self.in_train = True
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res
    def begin_validate(self):
        self.learn.model.eval()
        self.in_train = False
        res = True
        for cb in self.cbs: res = res and cb.begin_validate()
        return res
    def after_epoch(self):
        res = True
        for cb in self.cbs: res = res and cb.after_epoch()
        return res
    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs: res = res and cb.begin_batch(xb,yb)
        return res
    def after_loss(self, loss):
        res = self.in_train  # False during validation, so one_batch skips the backward pass
        for cb in self.cbs: res = res and cb.after_loss(loss)
        return res
    def after_backward(self):
        res = True
        for cb in self.cbs: res = res and cb.after_backward()
        return res
    def after_step(self):
        res = True
        for cb in self.cbs: res = res and cb.after_step()
        return res
    def do_stop(self):
        # return the current flag, then reset it so the next loop starts clean
        try: return self.learn.stop
        finally: self.learn.stop = False
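
Two small Python idioms carry most of the handler's behavior, and both are easy to misread. The sketch below isolates them: the "res = res and cb()" chain, which stops calling callbacks as soon as one returns False, and the try/finally in do_stop, which returns the current flag and then resets it. Flag, run_all, ok and veto are hypothetical names used only for this illustration.

```python
class Flag:
    """Hypothetical holder for a one-shot stop flag."""
    def __init__(self): self.stop = False
    def do_stop(self):
        # The return value is computed first; `finally` then resets the flag.
        try: return self.stop
        finally: self.stop = False

def run_all(callbacks):
    # Same pattern as the handler: `and` short-circuits, so once res is
    # False the remaining callbacks are not called at all.
    res = True
    for cb in callbacks: res = res and cb()
    return res

calls = []
def ok():   calls.append("ok");   return True
def veto(): calls.append("veto"); return False

result = run_all([ok, veto, ok])
print(result, calls)     # False ['ok', 'veto']

flag = Flag()
flag.stop = True
first, second = flag.do_stop(), flag.do_stop()
print(first, second)     # True False
```

One consequence worth keeping in mind: a callback placed after one that returns False never sees that event, so callback order matters even in this first version.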


Let's make a TestCallback that stops training after a given number of iterations:

In [63]:
class TestCallback(Callback):
    def begin_fit(self, learn):
        super().begin_fit(learn)  # stores self.learn, which after_step needs
        self.n_iters = 0
        return True
    def after_step(self):
        self.n_iters += 1
        if self.n_iters>=10: self.learn.stop = True
        return True
In [64]:
fit(1, learn, cb= CallbackHandler([TestCallback()]))

Runner v1.0

We need to reformulate this interaction between callbacks and our learner.

The CallbackHandler method leaves a lot to be desired:

  • there is a lot of repeated code
  • every callback has to inherit from the parent class
  • the CallbackHandler is passed around and called repeatedly
  • it's kind of bulky and unintuitive
In [65]:
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()
In [66]:

New Callback class

Our new callback class is simplified.

It does 4 things:

  1. Has a class attribute called _order that determines the order in which callbacks are called - this is important later
  2. Sets a self.run attribute when the set_runner method is called on it
  3. Diverts __getattr__ to the runner (self.run), so attributes not found on the callback are looked up on the runner
  4. Converts the name of the callback to snake case
In [67]:
class Callback():
    _order = 0
    def set_runner(self, run): self.run = run
    def __getattr__(self, k):
        # only called when normal lookup fails, so it falls back to the runner
        return getattr(self.run, k)
    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__) # removes Callback from custom callback class name
        return camel2snake(name or "callback")
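
The __getattr__ delegation has one subtlety worth seeing in isolation: reading a missing attribute falls through to the runner, but assigning to it creates a new attribute on the callback instead. The sketch below repeats the Callback class so it runs on its own; DummyRunner and CountCallback are hypothetical names invented for the demonstration.

```python
import re

def camel2snake(name):
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

class Callback():
    _order = 0
    def set_runner(self, run): self.run = run
    def __getattr__(self, k):
        # Only reached when normal lookup on the callback fails.
        return getattr(self.run, k)
    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or "callback")

class DummyRunner:                  # hypothetical stand-in for the Runner
    def __init__(self): self.n_iter = 0

class CountCallback(Callback):
    def bump_right(self): self.run.n_iter = self.n_iter + 1  # updates the runner
    def bump_wrong(self): self.n_iter = self.n_iter + 1      # creates a callback attribute

run, cb = DummyRunner(), CountCallback()
cb.set_runner(run)

cb.bump_right()
after_right = (run.n_iter, cb.n_iter)
cb.bump_wrong()
after_wrong = (run.n_iter, cb.n_iter)

print(cb.name)        # count
print(after_right)    # (1, 1)
print(after_wrong)    # (1, 2)
```

After bump_wrong the callback carries its own n_iter that shadows the runner's, which is why state shared between callbacks should be written through self.run.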

TrainEval Callback

In the new runner, callbacks manage more of the process themselves. Even switching between training and evaluation mode is handled by a callback, which gives the user more room to customize:

In [68]:
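class TrainEvalCallback(Callback):
    def begin_fit(self):
        self.run.n_epochs, self.run.n_iter = 0., 0
    def after_batch(self):
        if not self.in_train: return
        self.run.n_epochs += 1./self.iters  # fractional epoch progress
        self.run.n_iter += 1
    def begin_epoch(self):
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True
    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False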
In [69]:
cbname = 'TrainEvalCallback'
In [70]:
In [71]:
from typing import *

def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]
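
A few calls show what each branch of listify is for. This sketch repeats the function so it is self-contained and, as an assumption, imports Iterable from collections.abc rather than relying on the typing star import.

```python
from collections.abc import Iterable   # assumption: used here instead of typing.Iterable

def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]          # a string is one item, not a list of characters
    if isinstance(o, Iterable): return list(o)
    return [o]

existing = [1, 2]
print(listify(None))              # []
print(listify(existing) is existing)  # True: lists are returned as-is, not copied
print(listify("accuracy"))        # ['accuracy']
print(listify((1, 2)))            # [1, 2]
print(listify(3))                 # [3]
```

The str check has to come before the Iterable check, because strings are themselves iterable.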

Runner Class

The Runner is the main piece of this refactoring. It is the object that runs everything. Callbacks can be easily added to it to change its behavior. Again we're attempting to encapsulate our components effectively to develop an efficient and versatile training loop.

In [72]:
class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()  # cb_funcs are factories that build a callback
            setattr(self, cb.name, cb)
            cbs.append(cb)
        self.stop = False
        self.cbs = [TrainEvalCallback()]+cbs
    @property
    def opt(self):
        return self.learn.opt
    @property
    def model(self):
        return self.learn.model
    @property
    def loss_func(self):
        return self.learn.loss_func
    @property
    def data(self):
        return self.learn.data
    def one_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        if self('begin_batch'): return
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train: return # allows us to go between training and valid
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return
        self.opt.zero_grad()
    def all_batches(self, dl):
        self.iters = len(dl)
        for xb, yb in dl:
            if self.stop: break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop = False
    def fit(self, epochs, learn):
        self.epochs = epochs
        self.learn = learn
        try:
            for cb in self.cbs: cb.set_runner(self) # passes self as the runner object to each callback
            if self("begin_fit"): return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)
                with torch.no_grad():
                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)
                if self('after_epoch'): break
        finally:
            self('after_fit')
            self.learn = None
    def __call__(self, cb_name):
        # a callback returning True means "stop / skip this step"
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, cb_name, None)
            if f and f(): return True
        return False
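
Note that the convention has flipped compared with the CallbackHandler: here a callback returns True to interrupt, and returning nothing means "carry on". The sketch below isolates the dispatch logic of __call__ to show both the _order sorting and the early exit. MiniRunner, Early, Late and Skip are hypothetical names for this illustration only.

```python
class MiniRunner:                      # hypothetical, stripped-down dispatcher
    def __init__(self, cbs): self.cbs = cbs
    def __call__(self, cb_name):
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, cb_name, None)   # callbacks may omit any event
            if f and f(): return True        # a truthy return means "stop here"
        return False

log = []
class Late:
    _order = 10
    def begin_batch(self): log.append("late")
class Early:
    _order = 0
    def begin_batch(self): log.append("early")
class Skip:
    _order = 5
    def begin_batch(self): log.append("skip"); return True

stopped = MiniRunner([Late(), Early()])('begin_batch')
first_log = list(log)
print(stopped, first_log)      # False ['early', 'late']

log.clear()
stopped2 = MiniRunner([Late(), Skip(), Early()])('begin_batch')
second_log = list(log)
print(stopped2, second_log)    # True ['early', 'skip']
```

Because Skip returns True, Late never runs for that event, regardless of the order in which the callbacks were passed in.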

Early Stopping

Here is the TestCallback again. It waits until there have been 10 batches/iterations, then sets the runner's stop flag to True, which stops any more batches from being processed.

In [73]:
class TestCallback(Callback):
    def after_step(self):
        print(f"Batch: {self.n_iter}")
        if self.n_iter >= 10:
            print('Early Stopping')
            self.run.stop = True
    def begin_validate(self):
        self.run.stop = True  # skip the validation batches as well
In [74]:
run = Runner(cbs=TestCallback())
learn = Learner(*get_model(data), loss_func, data)
In [75]:
run.fit(1, learn)
Batch: 0
Batch: 1
Batch: 2
Batch: 3
Batch: 4
Batch: 5
Batch: 6
Batch: 7
Batch: 8
Batch: 9
Batch: 10
Early Stopping


It would be handy if we could pass metric functions to our runner and have them called with the predictions and true labels for each batch, with the results accumulated and printed at the end of every epoch.

We can do this with an AvgStatsCallback:

In [76]:
class AvgStats():
    def __init__(self, metrics, in_train):
        self.metrics = listify(metrics)
        self.in_train = in_train
    def reset(self):
        self.tot_loss = 0.
        self.count = 0
        self.tot_mets = [0.]*len(self.metrics)
    @property
    def all_stats(self): return [self.tot_loss.item()] + self.tot_mets
    @property
    def avg_stats(self): return [o/self.count for o in self.all_stats]
    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"
    def accumulate(self, run):
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i, m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn
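
The reason accumulate multiplies by the batch size bn is that the last batch of an epoch is often smaller than the rest, so a plain mean of per-batch values would over-weight it. A small worked example with made-up numbers (the losses and sizes below are hypothetical):

```python
# Hypothetical per-batch mean losses and batch sizes; the last batch is smaller.
batch_losses = [0.5, 0.5, 2.0]
batch_sizes  = [64, 64, 8]

# What AvgStats does: sum(loss * bn) / sum(bn)
tot_loss = sum(l * n for l, n in zip(batch_losses, batch_sizes))
count = sum(batch_sizes)
weighted = tot_loss / count

# Naive alternative: average the batch means directly
naive = sum(batch_losses) / len(batch_losses)

print(tot_loss, count)   # 80.0 136
print(naive)             # 1.0
print(weighted < naive)  # True
```

The weighted value, 80/136, is the true per-sample average, while the naive mean lets the 8-sample batch count as much as each 64-sample batch.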
In [77]:
class AvgStatsCallback(Callback):
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)
    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()
    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad(): stats.accumulate(self.run)
    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)
In [78]:
learn = Learner(*get_model(data), loss_func, data)
In [79]:
run = Runner(cbs=AvgStatsCallback([accuracy]))
In [80]:
run.fit(2, learn)
train: [0.30912541015625, tensor(0.9046)]
valid: [0.256246044921875, tensor(0.9214)]
train: [0.141904765625, tensor(0.9574)]
valid: [0.204003759765625, tensor(0.9402)]
In [81]:
!python 04_databunch_learner_runner.ipynb
Converted 04_databunch_learner_runner.ipynb to exp\