In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

DataBunch, Learner and Runner

An infinitely customizable training loop

In this section we're going to create a highly customizable training loop using callbacks.

Get Data

Get some data and put it in some Datasets

In [2]:
#export
from exp.nb_03 import *
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

import ipywidgets as widgets
In [3]:
x_train, y_train, x_valid, y_valid = get_data()
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)
loss_func = F.cross_entropy
In [4]:
nh = 50
bs = 64
c = 10

DataBunch

Our previous fit() was overloaded with components.

Let's package some of those components using factory methods and clean up our training loop.

First, we'll bunch the data together into a DataBunch class. This lets us hold everything we need data-wise for training in one object - data - from which we can pull the right DataLoader depending on the mode (training or validation).

In [5]:
#export
class DataBunch():
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl, self.valid_dl, self.c = train_dl, valid_dl, c
    
    @property
    def valid_ds(self): return self.valid_dl.dataset
    
    @property
    def train_ds(self): return self.train_dl.dataset

We can now define a data object from which we can grab batches of training or validation data.

In [6]:
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

The @property decorator lets us define getter (and setter) behavior for an attribute; here it exposes each DataLoader's underlying Dataset as a convenient read-only attribute. The short detour below shows how @property works.

In [7]:
xb, yb = next(iter(data.train_dl))
In [8]:
xb.shape
Out[8]:
torch.Size([64, 784])

@property

In [9]:
class person:
    def __init__(self, name="Guest"):
        self.__name=name
    def setname(self, name):
        self.__name=name
    def getname(self):
        return self.__name
In [10]:
p1=person()
In [11]:
p1.setname('Bill')
p1.getname()
Out[11]:
'Bill'
In [12]:
class person:
    def __init__(self):
        self.__name=''
    def setname(self, name):
        print('setname() called')
        self.__name=name
    def getname(self):
        print('getname() called')
        return self.__name
    
    name=property(getname, setname)
In [13]:
p1=person()
p1.name="Steve"
setname() called
In [14]:
p1.name
getname() called
Out[14]:
'Steve'
In [15]:
class person:
    def __init__(self):
        self.__name= 'Not Implemented'
    
    @property
    def name(self):
        return self.__name
    
    @name.setter
    def name(self, value):
        self.__name=value
In [16]:
p1 = person()
p1.name
Out[16]:
'Not Implemented'
In [17]:
p1.name = 'Ted'
p1.name
Out[17]:
'Ted'

Learner

In [18]:
#export
def get_model(data, lr=0.5, nh=50):
    features = data.train_ds.x.shape[1]
    model = nn.Sequential(nn.Linear(features, nh), nn.ReLU(), nn.Linear(nh, data.c))
    return model, optim.SGD(model.parameters(), lr=lr)

The Learner is basically our highest level abstraction.

Like DataBunch is for the data, it is just a container - with no logic of its own - for everything we need in our training loop: the model, the optimizer, the loss function, and the data.

In [19]:
#export
class Learner():
    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data
In [20]:
learn = Learner(*get_model(data), loss_func, data)

Now let's see what our fit looks like now that everything is encapsulated in one object.

In [21]:
def fit(learner, epochs):
    for epoch in range(epochs):
        learner.model.train()
        for xb,yb in learner.data.train_dl:
            loss = learner.loss_func(learner.model(xb), yb)
            loss.backward()
            learner.opt.step()
            learner.opt.zero_grad()
            
        learner.model.eval()
        with torch.no_grad():
            tot_loss,tot_acc = 0.,0.
            for xb,yb in learner.data.valid_dl:
                pred = learner.model(xb)
                tot_loss += learner.loss_func(pred, yb)
                tot_acc  += accuracy (pred,yb)
        nv = len(learner.data.valid_dl)
        print(epoch, tot_loss/nv, tot_acc/nv)
    return tot_loss/nv, tot_acc/nv
In [22]:
loss, acc = fit(learn, 1)
0 tensor(0.1716) tensor(0.9492)

Callback Handler version

Let's abstract away a bit more of the training loop and insert some callbacks!

We'll again strip away as much as we can from the fit function. For each epoch in epochs we want to make two calls: process all batches in the training set and all batches in the validation set.

Those calls will pass along the appropriate data, which is then processed by all_batches and, within it, one_batch.

This is a classic Extract Method refactoring.

Without the callbacks it will look something like this:

def one_batch(xb,yb):
    pred = model(xb)
    loss = loss_func(pred, yb)
    loss.backward()
    opt.step()
    opt.zero_grad()


def all_batches(dl):
    for b in dl: 
        one_batch(*b)


def fit():
    for epoch in range(epochs):
        all_batches(learn.data.train_dl)

        with torch.no_grad():
            all_batches(learn.data.valid_dl)
In [23]:
def one_batch(xb, yb, cb):
    if not cb.begin_batch(xb, yb): return # if returns False stop
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    if not cb.after_loss(loss): return
    loss.backward()
    if cb.after_backward(): cb.learn.opt.step()
    if cb.after_step(): cb.learn.opt.zero_grad()
    
def all_batches(dl, cb):
    for xb,yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop(): return # if do_stop is True stop the loop

def fit(epochs, learn, cb):
    if not cb.begin_fit(learn): return # the callbackhandler gets passed the learner 
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue 
        all_batches(learn.data.train_dl, cb)
        
        if cb.begin_validate():
            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch(): break
    cb.after_fit()

The key to this refactoring is the cb object that gets passed to fit and used everywhere: this is going to be the CallbackHandler.
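
Concretely, the cb object just needs one method per step of the loop, each returning True/False to tell fit whether to keep going. Here is a minimal sketch of that interface (the real CallbackHandler below additionally manages a whole list of callbacks):

class MinimalCallbackHandler():
    "Bare-bones handler: just enough state for the fit() above to run."
    def begin_fit(self, learn):
        self.learn, self.in_train = learn, True
        learn.stop = False
        return True
    def begin_epoch(self, epoch):
        self.learn.model.train(); self.in_train = True
        return True
    def begin_validate(self):
        self.learn.model.eval(); self.in_train = False
        return True
    def begin_batch(self, xb, yb): return True
    def after_loss(self, loss):    return self.in_train  # False during validation: skip backward/step
    def after_backward(self):      return True
    def after_step(self):          return True
    def after_epoch(self):         return True
    def after_fit(self):           return True
    def do_stop(self):             return self.learn.stop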

Primer on Callbacks

Callbacks are used frequently in backend development; web frameworks built on Node.js, for example, are often full of them.

They are used less often in Python ML development, so let's make some simple examples to demonstrate the idea.

Click callback

In [24]:
def f(o): print('hi')
In [25]:
w = widgets.Button(description="click me")

If we simply instantiate a Button object, it does not have any functionality attached yet.

In [26]:
w
In [27]:
w.on_click(f)

What we need to do is give the Button object an action or function to call when the click event occurs.

The click event will then call back to the function we registered with the on_click method.

In [28]:
w

Slow Calculator

Let's start by making a process that takes a while to execute - kind of like a training loop.

It iterates through something, and while that is happening we can do nothing but wait; we don't see any output until the entire process is over.

In [29]:
from time import sleep

def slow_calculator():
    res = 0
    for i in range(5):
        res += i*i
        sleep(i)
    return res
In [30]:
slow_calculator()
Out[30]:
30

Now let's give the calculator an optional callback argument that is called once per iteration:

In [31]:
def slow_calculator(cb=None):
    res = 0
    for i in range(5):
        res += i*i
        sleep(i)
        if cb: cb(i)
    return res

And we need to write some sort of callback function that will be called:

In [32]:
def show_progress(num): print(f"Epoch {num}: Another one down.")
In [33]:
slow_calculator(cb=show_progress)
Epoch 0: Another one down.
Epoch 1: Another one down.
Epoch 2: Another one down.
Epoch 3: Another one down.
Epoch 4: Another one down.
Out[33]:
30

Lambdas and Partials

We can write the above expression in a number of different ways.

The first is using a lambda expression:

In [34]:
slow_calculator(cb=lambda num: print(f"Epoch {num}"))
Epoch 0
Epoch 1
Epoch 2
Epoch 3
Epoch 4
Out[34]:
30

And we can even embed another function inside our lambda:

In [35]:
def show_progress(epoch, exclamation): print(f"Epoch {epoch} {exclamation}")
In [36]:
shout = "nice, very nice."
slow_calculator(lambda x: show_progress(x, shout))
Epoch 0 nice, very nice.
Epoch 1 nice, very nice.
Epoch 2 nice, very nice.
Epoch 3 nice, very nice.
Epoch 4 nice, very nice.
Out[36]:
30

Or a simpler way of achieving the same end would be:

In [37]:
def show_progress(exclamation):
    _inner = lambda x: print(f"Epoch {x} {exclamation}")
    return _inner
In [38]:
slow_calculator(show_progress("Yes Yes Yes"))
Epoch 0 Yes Yes Yes
Epoch 1 Yes Yes Yes
Epoch 2 Yes Yes Yes
Epoch 3 Yes Yes Yes
Epoch 4 Yes Yes Yes
Out[38]:
30

And again, this time without a lambda:

In [39]:
def show_progress(exclamation):
    def _inner(x): print(f"Epoch {x} {exclamation}")
    return _inner
In [40]:
slow_calculator(show_progress("Inner def"))
Epoch 0 Inner def
Epoch 1 Inner def
Epoch 2 Inner def
Epoch 3 Inner def
Epoch 4 Inner def
Out[40]:
30

And we can just save this to a variable and pass it:

In [41]:
var = show_progress("clean and concise")
slow_calculator(var)
Epoch 0 clean and concise
Epoch 1 clean and concise
Epoch 2 clean and concise
Epoch 3 clean and concise
Epoch 4 clean and concise
Out[41]:
30

functools and partial

Partials are another very handy tool for building closures like these.

In [42]:
from functools import partial
In [43]:
def make_show_progress(exclamation, epoch):
    print(f"{exclamation}! We've finished epoch {epoch}!")
In [44]:
slow_calculator(partial(make_show_progress, "partially now"))
partially now! We've finished epoch 0!
partially now! We've finished epoch 1!
partially now! We've finished epoch 2!
partially now! We've finished epoch 3!
partially now! We've finished epoch 4!
Out[44]:
30
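
Under the hood, partial simply pre-binds the leading arguments, so the call above is equivalent to the lambda-based closure we wrote earlier:

cb1 = partial(make_show_progress, "partially now")
cb2 = lambda epoch: make_show_progress("partially now", epoch)

cb1(3)   # partially now! We've finished epoch 3!
cb2(3)   # partially now! We've finished epoch 3!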

Callbacks as Classes

It's easy to imagine how we could make a callback class which has the same functionality as above:

In [45]:
class ShowProgressCallback():
    def __init__(self, callout):
        self.callout = callout
    def __call__(self, epoch):
        print(f'Epoch number {epoch} {self.callout}')
In [46]:
cb = ShowProgressCallback("NOW A CLASS")
In [47]:
slow_calculator(cb=cb)
Epoch number 0 NOW A CLASS
Epoch number 1 NOW A CLASS
Epoch number 2 NOW A CLASS
Epoch number 3 NOW A CLASS
Epoch number 4 NOW A CLASS
Out[47]:
30

Multiple callback funcs; *args and **kwargs

If we tweak our slow_calculator a bit, we can add callbacks at different stages of the loop:

In [48]:
def slow_calculator(cb=None):
    res = 0
    for i in range(5):
        if cb: cb.before_calc(i, res)
        res += i * i
        if cb: cb.after_calc(i, res)
        sleep(1)
    return res
In [49]:
class BeforeAfterCallback():
    def __init__(self, before, after): self.before, self.after = before, after
    def before_calc(self, *args): print(f"Here we are epoch {args[0]} about to calc and we're at {args[1]}: {self.before}")
    def after_calc(self, *args): print(f"Boom! Finished epoch {args[0]} now at {args[1]}: {self.after}")
In [50]:
bacall = BeforeAfterCallback("Going to Calc!", "Done")
In [51]:
slow_calculator(bacall)
Here we are epoch 0 about to calc and we're at 0: Going to Calc!
Boom! Finished epoch 0 now at 0: Done
Here we are epoch 1 about to calc and we're at 0: Going to Calc!
Boom! Finished epoch 1 now at 1: Done
Here we are epoch 2 about to calc and we're at 1: Going to Calc!
Boom! Finished epoch 2 now at 5: Done
Here we are epoch 3 about to calc and we're at 5: Going to Calc!
Boom! Finished epoch 3 now at 14: Done
Here we are epoch 4 about to calc and we're at 14: Going to Calc!
Boom! Finished epoch 4 now at 30: Done
Out[51]:
30

Modifying behavior

Now we can have as much control over the iteration as we want by adding callbacks into the process, letting them inspect intermediate values and shape the behavior of the calculator:

In [52]:
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        if cb and hasattr(cb, "before_calc"): cb.before_calc(i)
        res += i *i
        sleep(1)
        if cb and hasattr(cb, "after_calc"): 
            if cb.after_calc(i, res):
                print("Early stopping: returned True")
                break
    return res

Because the loop checks that a callback exists and uses hasattr before calling each hook, a custom callback can omit before_calc (or after_calc) entirely and the process will continue without throwing an error.

In [53]:
class PrintAfterCallback():
    def after_calc(self, epoch, val):
        print(f"Epoch {epoch} : {val}")
        if val > 12:
            return True
In [54]:
PC = PrintAfterCallback()
In [55]:
slow_calculation(PC)
Epoch 0 : 0
Epoch 1 : 1
Epoch 2 : 5
Epoch 3 : 14
Early stopping: returned True
Out[55]:
14

The next stage of abstraction is to turn the slow calculator into a class of its own, with a callback method that looks up the appropriate hook on the callback object at each stage.

In [56]:
class SlowCalculator():
    def __init__(self, cb=None):
        self.cb = cb
        self.res = 0
    
    def callback(self, name, *args):
        if not self.cb: return
        callback = getattr(self.cb, name, None) 
        if callback: return callback(*args)
    
    def calc(self):
        for i in range(5):
            self.callback("before_calc", i, self.res)
            self.res += i*i
            sleep(1)
            if self.callback("after_calc", i, self.res):
                print('Early Stopping')
                break
        return self.res
In [57]:
class ShowProgress():
    def before_calc(self, epoch, res, *args):
        print(f"Epoch: {epoch} Res: {res}")
    def after_calc(self, epoch, *args):
        if epoch == 3:
            return True
In [58]:
SP = ShowProgress()
In [59]:
calc = SlowCalculator(cb=SP)
In [60]:
calc.calc()
Epoch: 0 Res: 0
Epoch: 1 Res: 0
Epoch: 2 Res: 1
Epoch: 3 Res: 5
Early Stopping
Out[60]:
14

Callback Parent Class

Every custom callback we write will inherit from a Callback parent class, which defines a default method for each step of the loop; the handler iterates through its callbacks and calls the method that corresponds to the current step.

In [61]:
#export
class Callback():
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def after_fit(self):
        return True
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    def begin_validate(self):
        return True
    def after_epoch(self):
        return True
    def begin_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_backward(self):
        return True
    def after_step(self):
        return True

CallbackHandler Class

Now for our main container: the CallbackHandler.

The job of the handler is to hold all of the callbacks and to be called at every step of the fit process.

fit starts with cb.begin_fit(learn): the learner object (which is basically a container for all of our components - model, data, optimizer, and loss function) is made an attribute of the CallbackHandler.

Everything is then accessed through cb.learn, so it is the core of the process.

In [62]:
#export
class CallbackHandler():
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []
        
    def begin_fit(self, learn):
        self.learn = learn
        self.in_train = True
        learn.stop = False
        res = True
        for cb in self.cbs: res = res and cb.begin_fit(learn)
        return res
    
    def after_fit(self):
        res = not self.in_train
        for cb in self.cbs: res = res and cb.after_fit()
        return res
    
    def begin_epoch(self, epoch):
        self.learn.model.train()
        self.in_train = True
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res
    
    def begin_validate(self):
        self.learn.model.eval()
        self.in_train = False
        res = True
        for cb in self.cbs: res = res and cb.begin_validate()
        return res
    
    def after_epoch(self):
        res = True
        for cb in self.cbs: res = res and cb.after_epoch()
        return res
    
    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs: res = res and cb.begin_batch(xb,yb)
        return res
    
    def after_loss(self, loss):
        res = self.in_train
        for cb in self.cbs: res = res and cb.after_loss(loss)
        return res
    
    def after_backward(self):
        res = True
        for cb in self.cbs: res = res and cb.after_backward()
        return res

    def after_step(self):
        res = True
        for cb in self.cbs: res = res and cb.after_step()
        return res
    
    def do_stop(self):
        try: return self.learn.stop
        finally: self.learn.stop = False # reset the flag so the next fit starts cleanly

TestCallback

Let's make a TestCallback that stops training after a given number of iterations.

In [63]:
class TestCallback(Callback):
    
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.n_iters = 0
        return True
    
    def after_step(self):
        self.n_iters += 1
        print(self.n_iters)
        if self.n_iters>=10: self.learn.stop = True
        return True
In [64]:
fit(1, learn, cb= CallbackHandler([TestCallback()]))
1
2
3
4
5
6
7
8
9
10

Runner v1.0

We need to reformulate this interaction between callbacks and our learner.

The CallbackHandler approach leaves a lot to be desired:

  • there is a lot of repeated code
  • every callback has to inherit from the parent class
  • the CallbackHandler gets passed around and called everywhere
  • it's kind of bulky and unintuitive
In [65]:
#export
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')


def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()
In [66]:
camel2snake("CallbackHandler")
Out[66]:
'callback_handler'

New Callback class

Our new callback class is simplified.

It does 4 things:

  1. Has a class attribute called _order that determines the order in which callbacks are called - this is important later
  2. Sets self.run when the set_runner method is called on it
  3. Delegates attribute lookups (__getattr__) to self.run - illustrated right after the class below
  4. Exposes a name property that converts the callback's class name to snake case
In [67]:
#export
class Callback():
    
    _order = 0
    
    def set_runner(self, run):
        self.run = run
        
    def __getattr__(self, k): 
        return getattr(self.run, k)
    
    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__) # removes Callback from custom callback class name
        return camel2snake(name or "callback")
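
Point 3 is the subtle one: __getattr__ only fires for attributes the callback does not define itself, so anything else falls through to the runner. A quick illustration with a hypothetical stand-in runner (FakeRunner is just for demonstration):

class FakeRunner():
    n_iter = 7   # stands in for something a real Runner would track

cb = Callback()
cb.set_runner(FakeRunner())
cb.name    # 'callback' - defined on Callback itself, so __getattr__ is not involved
cb.n_iter  # 7 - not found on the callback, so the lookup is delegated to self.run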

TrainEval Callback

In the new runner, the process is managed by callbacks in a more essential way - giving them more control and making it easier for the user to customize training:

In [68]:
#export
class TrainEvalCallback(Callback):
    def begin_fit(self):
        self.run.n_epochs = 0.
        self.run.n_iter = 0
    
    def after_batch(self):
        if not self.in_train: return 
        self.run.n_epochs += 1./self.iters
        self.run.n_iter += 1
        
    def begin_epoch(self):
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True
    
    def begin_validate(self):
        self.model.eval()
        self.run.in_train=False
In [69]:
cbname = 'TrainEvalCallback'
camel2snake(cbname)
Out[69]:
'train_eval_callback'
In [70]:
TrainEvalCallback().name
Out[70]:
'train_eval'
In [71]:
#export
from typing import *

def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]
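
listify normalizes whatever it is given into a plain list, which lets the Runner accept a single callback, a list of callbacks, or nothing at all:

listify(None)         # []
listify('accuracy')   # ['accuracy'] - a string is wrapped, not iterated character by character
listify([1, 2, 3])    # [1, 2, 3] - already a list, returned as-is
listify(range(3))     # [0, 1, 2] - any other iterable is materialized into a list
listify(accuracy)     # [<function accuracy>] - a single object is wrapped in a list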

Runner Class

The Runner is the main piece of refactoring. It is the object that runs everything; callbacks can easily be added to it to change its behavior. Again, we're encapsulating our components to build an efficient and versatile training loop.

In [72]:
#export
class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf() # cb_funcs are callback factories: call each one to get a callback instance
            setattr(self, cb.name, cb) # make each callback reachable as an attribute of the runner
            cbs.append(cb)
        self.stop = False
        self.cbs = [TrainEvalCallback()]+cbs
        
    @property
    def opt(self):
        return self.learn.opt
    @property
    def model(self):
        return self.learn.model
    @property
    def loss_func(self):
        return self.learn.loss_func
    @property
    def data(self):
        return self.learn.data
    
    def one_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        if self('begin_batch'): return 
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train: return # skip the backward pass and optimizer step during validation
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return
        self.opt.zero_grad()
    
    def all_batches(self, dl):
        self.iters = len(dl)
        
        for xb, yb in dl:
            if self.stop: 
                print('self.stop')
                break
            self.one_batch(xb, yb)
            self('after_batch')
        
        self.stop=False
        
    
    def fit(self, epochs, learn):
        self.epochs = epochs
        self.learn = learn
        
        try:
            for cb in self.cbs: cb.set_runner(self) # passes self as the runner object to each callback
            if self("begin_fit"): return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): 
                    self.all_batches(self.data.train_dl)
                
                with torch.no_grad():
                    if not self('begin_validate'): 
                        self.all_batches(self.data.valid_dl)
                if self('after_epoch'): break
        
        finally:
            self('after_fit')
            self.learn = None
            
            
    def __call__(self, cb_name):
        for cb in sorted(self.cbs, key=lambda x: x._order): 
            f = getattr(cb, cb_name, None) 
            if f and f(): return True 
        return False 
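
Note how __call__ works: it looks up cb_name on every callback, lowest _order first, calls it if present, and a truthy return value both short-circuits the remaining callbacks and signals fit to skip or stop that step. A small illustration with two hypothetical callbacks (not part of the exported library):

class QuietCallback(Callback):
    _order = 0
    def begin_fit(self): print("QuietCallback runs first (_order=0)")

class LoudCallback(Callback):
    _order = 1
    def begin_fit(self): print("LoudCallback runs second (_order=1)")

demo = Runner(cbs=[LoudCallback(), QuietCallback()])
for cb in demo.cbs: cb.set_runner(demo)   # fit() normally does this for us
demo('begin_fit')   # dispatches begin_fit on every callback in _order
# QuietCallback runs first (_order=0)
# LoudCallback runs second (_order=1)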

Early Stopping

Here is the TestCallback again. It will wait until there have been 10 batches/iterations, then set self.run.stop to True, which will stop any more batches from being processed.

In [73]:
class TestCallback(Callback):
    def after_step(self):
        print(f"Batch: {self.n_iter}")
        if self.n_iter >= 10:
            print(f'Early Stopping')
            self.run.stop = True
    def begin_validate(self):
        self.run.stop = True
In [74]:
run = Runner(cbs=TestCallback())
learn = Learner(*get_model(data), loss_func, data)
In [75]:
run.fit(1, learn)
Batch: 0
Batch: 1
Batch: 2
Batch: 3
Batch: 4
Batch: 5
Batch: 6
Batch: 7
Batch: 8
Batch: 9
Batch: 10
Early Stopping
self.stop
self.stop

Metrics

It would be handy if we could simply pass metric functions to our runner; these would be given the predictions and true labels for each batch, accumulated, and printed at the end of every epoch.

We can do this with an AvgStatsCallback:

In [76]:
#export
class AvgStats():
    def __init__(self, metrics, in_train):
        self.metrics = listify(metrics)
        self.in_train = in_train
    
    def reset(self):
        self.tot_loss = 0.
        self.count = 0
        self.tot_mets = [0.]*len(self.metrics)
    
    @property
    def all_stats(self): return [self.tot_loss.item()] + self.tot_mets
    @property
    def avg_stats(self): return [o/self.count for o in self.all_stats]
    
    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"
    
    def accumulate(self, run):
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i, m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn
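
AvgStats keeps running sums weighted by batch size, so the per-epoch averages come out right even if the last batch is smaller. A quick standalone check with a fake run object (SimpleNamespace stands in for the Runner here, and this assumes accuracy from exp.nb_03 is the usual argmax-equals-target mean):

from types import SimpleNamespace

stats = AvgStats([accuracy], in_train=True)
stats.reset()

# one fake batch of 4 examples, 3 of which are classified correctly
pred = torch.tensor([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
yb   = torch.tensor([0, 0, 1, 1])
fake_run = SimpleNamespace(xb=torch.zeros(4, 2), loss=torch.tensor(0.5), pred=pred, yb=yb)

stats.accumulate(fake_run)
stats.avg_stats   # [0.5, tensor(0.7500)] - mean loss, mean accuracy over 4 examples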
    
    
In [77]:
#export
class AvgStatsCallback(Callback):
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)
    
    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()
    
    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad(): stats.accumulate(self.run)
            
    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)
In [78]:
learn = Learner(*get_model(data), loss_func, data)
In [79]:
run = Runner(cbs=AvgStatsCallback([accuracy]))
In [80]:
run.fit(2, learn)
train: [0.30912541015625, tensor(0.9046)]
valid: [0.256246044921875, tensor(0.9214)]
train: [0.141904765625, tensor(0.9574)]
valid: [0.204003759765625, tensor(0.9402)]
In [ ]:
 
In [81]:
!python notebook2script.py 04_databunch_learner_runner.ipynb
Converted 04_databunch_learner_runner.ipynb to exp\nb_04.py