In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

CNNs, CUDA and Hooks

Training convolutional neural networks with CUDA and Pytorch hooks

In [2]:
#export
from exp.nb_06 import *
In [3]:
# import torch.nn.functional as F
# import torch.nn as nn
# import torch.optim as optim

Get Data

In [4]:
x_train,y_train,x_valid,y_valid = get_data()
In [5]:
#export
def normalize_to(train, valid):
    m,s = train.mean(), train.std()
    return normalize(train, m, s), normalize(valid, m,s)
In [6]:
x_train, x_valid = normalize_to(x_train, x_valid)
In [7]:
train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)
nh,bs = 50,512
c = y_train.max().item()+1
loss_func = F.cross_entropy
In [8]:
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)
In [9]:
data.train_ds.x.mean(), data.train_ds.x.std()
Out[9]:
(tensor(-6.2598e-06), tensor(1.))

Basic CNN

We're going to implement a basic CNN using some 2d conv layers.

Lambda Class

If we want to put a plain function into nn.Sequential it needs to be an nn.Module.

To do this we'll use a Lambda class that takes a function, inherits from nn.Module, and calls the function on the forward pass:

In [10]:
#export
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func
    def forward(self, x):
        return self.func(x)

Get CNN Model

The first step is to use the Lambda class above to reshape our batches into the shape:

BATCH x CHANNEL x HEIGHT x WIDTH

In [11]:
#export
def flatten(x):
    return x.view(x.shape[0], -1)

def mnist_resize(x):
    return x.view(-1, 1, 28, 28)
In [12]:
xb, yb = next(iter(data.train_dl))
xb.shape
Out[12]:
torch.Size([512, 784])
In [13]:
nb = mnist_resize(xb)
nb.shape
Out[13]:
torch.Size([512, 1, 28, 28])
In [14]:
def get_cnn_model(data):
    return nn.Sequential(
        Lambda(mnist_resize),
        nn.Conv2d( 1,  8, 5, padding=2, stride=2), nn.ReLU(), # stride 2 reduces image to 14x14
        nn.Conv2d( 8, 16, 3, padding=1, stride=2), nn.ReLU(), # stride 2 reduces image to 7x7
        nn.Conv2d(16, 32, 3, padding=1, stride=2), nn.ReLU(), # stride 2 reduces image to 4x4
        nn.Conv2d(32, 32, 3, padding=1, stride=2), nn.ReLU(), # stride 2 reduces image to 2x2
        nn.AdaptiveAvgPool2d(1),
        Lambda(flatten),
        nn.Linear(32, data.c)
    )
In [15]:
model = get_cnn_model(data)
In [16]:
opt = optim.SGD(model.parameters(), lr=0.4)
learn = Learner(model, opt, loss_func, data)
run = Runner(cb_funcs = [Recorder, partial(AvgStatsCallback, accuracy)])
In [17]:
%time run.fit(1, learn)
train: [1.88377203125, tensor(0.3417)]
valid: [0.61258681640625, tensor(0.8151)]
Wall time: 11.6 s

Great. It appears to be working but it took more than 11 seconds to run.

We'll need to throw it on the GPU to speed up the matrix multiplications.

CUDA

Pytorch offers a few ways to set which GPU to work with.
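For reference, a few of the common options (a quick sketch of mine, not from the lesson):

device = torch.device('cuda', 0)     # an explicit device object
torch.cuda.set_device(device)        # make it the current default CUDA device
t = torch.ones(3).to(device)         # move a tensor (or a model) explicitly
m = nn.Linear(2, 2).cuda()           # .cuda() uses the current default device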

In [18]:
device = torch.device('cuda', 0)
In [19]:
torch.cuda.set_device(device)
In [20]:
#export
class CudaCallback(Callback):
    _order = 1
    def begin_fit(self):
        self.model = self.model.cuda()
    def begin_batch(self):
        self.run.xb, self.run.yb = self.xb.cuda(), self.yb.cuda()
In [21]:
model = get_cnn_model(data)
In [22]:
cbfs = [Recorder, CudaCallback, partial(AvgStatsCallback, accuracy)]
In [23]:
opt = optim.SGD(model.parameters(), lr=0.4)
learn = Learner(model, opt, loss_func, data)
run = Runner(cb_funcs=cbfs)
In [24]:
%time run.fit(3, learn)
train: [1.91082140625, tensor(0.3575, device='cuda:0')]
valid: [0.501539990234375, tensor(0.8400, device='cuda:0')]
train: [0.370878984375, tensor(0.8860, device='cuda:0')]
valid: [0.19286480712890625, tensor(0.9425, device='cuda:0')]
train: [0.1835596484375, tensor(0.9445, device='cuda:0')]
valid: [0.15307530517578125, tensor(0.9559, device='cuda:0')]
Wall time: 5.43 s

Much better. Now we can do 3x as many epochs in less than half the time.

Refactor CNN

We'll want to eventually make deeper models so let's find some ways to build layers quickly and easily.

In [25]:
#export
def conv2d(ni, nf, ks=3, stride=2):
    return nn.Sequential(
        nn.Conv2d(ni, nf,  ks, padding=ks//2, stride=stride), nn.ReLU())
In [26]:
conv2d(1, 8)
Out[26]:
Sequential(
  (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (1): ReLU()
)

And instead of using the Lambda class above let's write a general callback that transforms the batch with a given transformation function.

In [27]:
#export
class BatchTransformXCallback(Callback):
    _order = 2
    def __init__(self, tfm): 
        self.tfm = tfm
    
    def begin_batch(self):
        self.run.xb = self.tfm(self.xb)

The batch size can vary (the last batch is usually smaller), so we need to leave the first dimension flexible with -1.

After an hour spent troubleshooting my model and trying to track down the bug, I finally realized I had forgotten the comma after the -1 in this transform function, so (-1) stayed an int instead of becoming a one-element tuple.

The Runner's try/finally control flow hides errors that would otherwise be easy to spot.
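To make the bug concrete, here's what the missing comma does (a quick illustration, not part of the original code):

size = (1, 28, 28)
(-1,) + size    # (-1, 1, 28, 28)  -- a valid shape tuple for .view
# (-1) + size   # TypeError: unsupported operand type(s) for +: 'int' and 'tuple'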

In [28]:
#export
def view_tfm(*size):
    def _inner(x): return x.view(*((-1,)+size))
    return _inner
In [29]:
mnist_view = view_tfm(1,28,28) 
In [30]:
mnist_view(xb).shape
Out[30]:
torch.Size([512, 1, 28, 28])

CNN Layer Generator

In [31]:
nfs = [8, 16, 32, 32]
In [32]:
def get_cnn_layers(data, nfs):
    nfs = [1] + nfs
    return [conv2d(nfs[i], nfs[i+1], 5 if i==0 else 3) for i in range(len(nfs)-1)
           ] + [nn.AdaptiveAvgPool2d(1), Lambda(flatten), nn.Linear(nfs[-1], data.c)]

Why is the first layer's kernel size 5x5 and the rest 3x3?

We have 8 filters in the first layer --> nfs[0].

If our kernel size is 3x3, the output of each filter at a given position is computed from only 3x3x1 = 9 input values (the input is 1 x 28 x 28, a single channel). With 8 filters we'd be producing 8 numbers from those 9 inputs, so the first layer would barely be condensing any information. A 5x5 kernel looks at 25 input values per position, which gives those 8 filters something worthwhile to learn (see the quick check below).
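
Quick arithmetic check (my sketch, not a cell from the original notebook):

inputs_3x3 = 3 * 3 * 1   # values each 3x3 filter sees at one position (1 input channel)
inputs_5x5 = 5 * 5 * 1   # values each 5x5 filter sees at one position
n_filters  = 8           # nfs[0]: numbers produced at that position
print(inputs_3x3, inputs_5x5, n_filters)   # 9 25 8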


Now we'll use the layer builder to make the layers and unpack them into an nn.Sequential

In [33]:
def get_cnn_model(data, nfs): 
    return nn.Sequential(*get_cnn_layers(data, nfs))

Lastly, let's bundle everything we need:

In [34]:
#export
def get_runner(model, data, lr=0.6, cbs=None, opt_func=None, loss_func = F.cross_entropy):
    if opt_func is None: opt_func = optim.SGD
    opt = opt_func(model.parameters(), lr=lr)
    learn = Learner(model, opt, loss_func, data)
    return learn, Runner(cb_funcs=listify(cbs))
In [35]:
cbfs = [Recorder, CudaCallback, partial(AvgStatsCallback, accuracy), partial(BatchTransformXCallback, mnist_view)]
In [36]:
model = get_cnn_model(data, nfs)
learn, run = get_runner(model, data, lr=0.3, cbs=cbfs)
In [37]:
run.fit(3, learn)
train: [2.096260625, tensor(0.2690, device='cuda:0')]
valid: [0.99009228515625, tensor(0.6910, device='cuda:0')]
train: [0.5943169921875, tensor(0.8186, device='cuda:0')]
valid: [0.2552780029296875, tensor(0.9233, device='cuda:0')]
train: [0.22306740234375, tensor(0.9344, device='cuda:0')]
valid: [0.16564715576171876, tensor(0.9514, device='cuda:0')]

Hooks

Finding out what is inside our model.

Manual Insertion

If we wanted to record the mean and std of the activations of every layer during training, we could iterate through the layers on every forward pass and gather those statistics ourselves.

Here's a model that does this manually:

In [38]:
class SeqModel(nn.Module):
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.act_means = [[] for _ in layers]
        self.act_stds = [[] for _ in layers]
        
    def __call__(self, x):
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            self.act_means[idx].append(x.data.mean())
            self.act_stds[idx].append(x.data.std())
        return x
    
    def __iter__(self): return iter(self.layers)
In [39]:
model = SeqModel(*get_cnn_layers(data, nfs))
learn, runner = get_runner(model, data, cbs=cbfs)
In [40]:
runner.fit(2, learn)
train: [1.88146953125, tensor(0.3676, device='cuda:0')]
valid: [0.764405029296875, tensor(0.7519, device='cuda:0')]
train: [0.4974050390625, tensor(0.8445, device='cuda:0')]
valid: [0.325291455078125, tensor(0.9039, device='cuda:0')]
In [41]:
len(model.act_means)
Out[41]:
7
In [42]:
for l in model.act_means: plt.plot(l)
plt.legend(range(len(model.act_means)))
Out[42]:
<matplotlib.legend.Legend at 0x2bb2201b5b0>
In [43]:
for l in model.act_stds: plt.plot(l)
plt.legend(range(6))
Out[43]:
<matplotlib.legend.Legend at 0x2bb21c7d5b0>

If we look at the first 10 iterations we can see the mean of the first layer is right around 0.23, which is higher than we'd like but close enough.

In [44]:
model.act_means[0][:10]
Out[44]:
[tensor(0.2271, device='cuda:0'),
 tensor(0.2288, device='cuda:0'),
 tensor(0.2277, device='cuda:0'),
 tensor(0.2288, device='cuda:0'),
 tensor(0.2286, device='cuda:0'),
 tensor(0.2264, device='cuda:0'),
 tensor(0.2274, device='cuda:0'),
 tensor(0.2285, device='cuda:0'),
 tensor(0.2297, device='cuda:0'),
 tensor(0.2337, device='cuda:0')]
In [45]:
for l in model.act_means: plt.plot(l[:10])
plt.legend(range(6))
Out[45]:
<matplotlib.legend.Legend at 0x2bb2205ce20>

The standard deviation, however, is a problem.

The first layer starts out around 0.45 and each subsequent layer's std decreases roughly exponentially until the deepest layers nearly collapse. This means the variance (the space those later layers have to learn in) contracts.

In [46]:
model.act_stds[0][:10]
Out[46]:
[tensor(0.4344, device='cuda:0'),
 tensor(0.4372, device='cuda:0'),
 tensor(0.4354, device='cuda:0'),
 tensor(0.4374, device='cuda:0'),
 tensor(0.4383, device='cuda:0'),
 tensor(0.4349, device='cuda:0'),
 tensor(0.4372, device='cuda:0'),
 tensor(0.4389, device='cuda:0'),
 tensor(0.4434, device='cuda:0'),
 tensor(0.4499, device='cuda:0')]
In [47]:
for l in model.act_stds: plt.plot(l[:10])
plt.legend(range(6))
Out[47]:
<matplotlib.legend.Legend at 0x2bb2220fb20>

Pytorch Hooks

A hook is basically a function - a callback - that gets executed when forward or backward is called on a tensor or module.
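For reference, the signatures Pytorch expects (a summary of the built-in hook APIs used below, not from the original notebook):

# tensor hook:          fn(grad)                             via tensor.register_hook(fn)
# module forward hook:  fn(module, input, output)            via module.register_forward_hook(fn)
# module backward hook: fn(module, grad_input, grad_output)  via module.register_backward_hook(fn)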

Tensor Hooks

Normally we make some tensors with requires_grad=True, do some operations with them (like the forward pass through a network layer), and then call .backward() to calculate the gradients so we can adjust the weights and repeat.

When that process completes we can see the gradients in the .grad attribute, but we're kind of blind as to what is happening during the calculation:

In [48]:
a = torch.ones(5) # make a tensor
a.requires_grad = True

b = 2*a # do some operation with it
b.retain_grad()

c = b.mean() # calculate some type of final output
c.backward() # find the gradients of the output w.r.t the inputs

print(f'a: {a}', f'b: {b}', f'c: {c}',  sep='\n')
print(f'a grad: {a.grad}', f'b grad: {b.grad}', sep='\n')
a: tensor([1., 1., 1., 1., 1.], requires_grad=True)
b: tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
c: 2.0
a grad: tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000])
b grad: tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])

This time we register a hook on b so we get some feedback - telemetry - on the process:

In [49]:
a = torch.ones(5)
a.requires_grad = True

b = 2*a
b.retain_grad()

b.register_hook(lambda x: print(x))  
c = b.mean()
In [50]:
c.backward() 
tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])

Here we can see the hook in action. Once the hook is registered on b, it prints b's gradient when backward is called.
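register_hook also returns a handle, so we can detach the hook once we no longer want the telemetry (the handle variable here is my own addition, not a cell from the notebook):

handle = b.register_hook(lambda grad: print(grad))
# ... later, when we no longer want the printout:
handle.remove()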

In [51]:
print(f'a: {a}', f'b: {b}', f'c: {c}',  sep='\n')
print(f'a grad: {a.grad}', f'b grad: {b.grad}', sep='\n')
a: tensor([1., 1., 1., 1., 1.], requires_grad=True)
b: tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
c: 2.0
a grad: tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000])
b grad: tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])

Module Hooks

In [52]:
class myNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3,10,2, stride = 2)
        self.relu = nn.ReLU()
        self.flatten = lambda x: x.view(-1)
        self.fc1 = nn.Linear(160,5)
        
    def forward(self, x):
        x = self.relu(self.conv(x))
        return self.fc1(self.flatten(x))
  
In [53]:
def hook_fn(m, i, o):
    print(f'Module: {m}')
    
    print('-'*10, "Input Grad", '-'*10)

    for grad in i:
        try:
            print(grad.shape)
        except AttributeError: 
            print ("None found for Gradient")
    
    print('-'*10, "Output Grad", '-'*10)
    
    for grad in o:  
        try:
            print(grad.shape)
        except AttributeError: 
            print ("None found for Gradient")
    
    print("\n")    
In [54]:
net = myNet()

net.conv.register_backward_hook(hook_fn)
net.fc1.register_backward_hook(hook_fn)
Out[54]:
<torch.utils.hooks.RemovableHandle at 0x2bb22061e50>
In [55]:
inp = torch.randn(1,3,8,8)
out = net(inp)

Now when we call .backward our hook should go into action:

In [56]:
(1 - out.mean()).backward()
Module: Linear(in_features=160, out_features=5, bias=True)
---------- Input Grad ----------
torch.Size([5])
torch.Size([5])
---------- Output Grad ----------
torch.Size([5])


Module: Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2))
---------- Input Grad ----------
None found for Gradient
torch.Size([10, 3, 2, 2])
torch.Size([10])
---------- Output Grad ----------
torch.Size([1, 10, 4, 4])


We can see the backward pass traveling through these two modules: the hook on the linear layer fires first, then the one on the conv layer.

Note the "None found for Gradient" in the conv layer's input gradients: the input image itself doesn't require gradients, so there is no gradient to report for it.

Forward pass Hooks on CNN model

Alright now we are ready to use hooks to capture the means and stds of our forward pass activations like we did manually above.

In [57]:
model = get_cnn_model(data, nfs)
learn, run = get_runner(model, data, cbs=cbfs)
In [58]:
act_means = [[] for _ in model]
act_stds = [[] for _ in model]

The hook is a function - basically a callback - that takes four args (Pytorch passes the last three to every forward hook; we bind the first with partial):

  • layer number
  • module
  • input
  • output
In [59]:
def append_stats(i, mod, inp, out):
    act_means[i].append(out.data.mean())
    act_stds[i].append(out.data.std())
In [60]:
for i, m in enumerate(model): m.register_forward_hook(partial(append_stats, i))
In [61]:
run.fit(1, learn)
train: [2.1929325, tensor(0.2003, device='cuda:0')]
valid: [1.4022546875, tensor(0.5156, device='cuda:0')]
In [62]:
for o in act_means: plt.plot(o)
plt.legend(range(5));
In [63]:
for o in act_stds: plt.plot(o)
plt.legend(range(5));

The spikes are a problem.

Hook Class

model.children() vs model.modules()

model.children()

Returns an iterator over immediate child modules
In [64]:
c = list(model.children());c
Out[64]:
[Sequential(
   (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
   (1): ReLU()
 ),
 Sequential(
   (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (1): ReLU()
 ),
 Sequential(
   (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (1): ReLU()
 ),
 Sequential(
   (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (1): ReLU()
 ),
 AdaptiveAvgPool2d(output_size=1),
 Lambda(),
 Linear(in_features=32, out_features=10, bias=True)]
In [65]:
for i in c: print(type(i))
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>
<class '__main__.Lambda'>
<class 'torch.nn.modules.linear.Linear'>

model.modules() recursively goes into each module in the model.

The model is essentially a tree structure.

In [66]:
c = list(model.modules()); c
Out[66]:
[Sequential(
   (0): Sequential(
     (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
     (1): ReLU()
   )
   (1): Sequential(
     (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
     (1): ReLU()
   )
   (2): Sequential(
     (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
     (1): ReLU()
   )
   (3): Sequential(
     (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
     (1): ReLU()
   )
   (4): AdaptiveAvgPool2d(output_size=1)
   (5): Lambda()
   (6): Linear(in_features=32, out_features=10, bias=True)
 ),
 Sequential(
   (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
   (1): ReLU()
 ),
 Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2)),
 ReLU(),
 Sequential(
   (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (1): ReLU()
 ),
 Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
 ReLU(),
 Sequential(
   (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (1): ReLU()
 ),
 Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
 ReLU(),
 Sequential(
   (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (1): ReLU()
 ),
 Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
 ReLU(),
 AdaptiveAvgPool2d(output_size=1),
 Lambda(),
 Linear(in_features=32, out_features=10, bias=True)]
In [67]:
for i in c: print(type(i))
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>
<class '__main__.Lambda'>
<class 'torch.nn.modules.linear.Linear'>

Hook class cont'd

In [68]:
#export
def children(m): return list(m.children())
In [69]:
#export
class Hook():
    def __init__(self, m, f):
        self.hook = m.register_forward_hook(partial(f, self))
        
    def remove(self): 
        self.hook.remove()
    
    def __del__(self): 
        self.remove()
In [70]:
#export
def append_stats(hook, mod, inp, outp):
    if not hasattr(hook, 'stats'): hook.stats = ([],[])
    means, std = hook.stats
    means.append(outp.data.mean())
    std.append(outp.data.std())
In [71]:
model = get_cnn_model(data, nfs)
learn,run = get_runner(model, data, lr=0.5, cbs=cbfs)
In [72]:
hooks = [Hook(l, append_stats) for l in children(model[:4])]
In [73]:
run.fit(1, learn)
train: [1.65319671875, tensor(0.4401, device='cuda:0')]
valid: [0.9475978515625, tensor(0.7339, device='cuda:0')]
In [74]:
hooks[0].stats[0][:4]
Out[74]:
[tensor(0.1966, device='cuda:0'),
 tensor(0.1971, device='cuda:0'),
 tensor(0.1960, device='cuda:0'),
 tensor(0.1953, device='cuda:0')]
In [75]:
for h in hooks:
    plt.plot(h.stats[0])
    h.remove()
plt.legend(range(4))
Out[75]:
<matplotlib.legend.Legend at 0x2bb2228e0a0>
In [76]:
for h in hooks:
    plt.plot(h.stats[1])
    h.remove()
plt.legend(range(4))
Out[76]:
<matplotlib.legend.Legend at 0x2bb21e9a400>

We basically have the same problem as above: the same spikes appear.

Hooks Class v2 - container

ListContainer

We're going to write a helpful ListContainer class that will make it easier to store and access objects inside.

In [77]:
#export
class ListContainer():
    def __init__(self, items):
        self.items = listify(items) #turns items into a list
        
    def __getitem__(self, idx):
        if isinstance(idx, (int, slice)): # if int or slice 
            return self.items[idx] 
        if isinstance(idx[0], bool): # if bool mask
            assert len(idx) == len(self) # check len
            return [o for m,o in zip(idx, self.items) if m] # zip and return `True` items
        return [self.items[i] for i in idx] # else if list return idx numbers
    
    def __len__(self):
        return len(self.items)
    
    def __iter__(self):
        return iter(self.items)
    
    def __setitem__(self, i, o): # Defines behavior for when an item is assigned to, using the notation self[nkey] = value
        self.items[i] = o
    
    def __delitem__(self, i):
        del(self.items[i])
    
    def __repr__(self):
        res = f'{self.__class__.__name__} has ({len(self)} items)\n {self.items[:10]}'
        if len(self)>10: res = res[:-1] + '...]'
        return res

If it has 10 or fewer items then it prints out all of the items...

In [78]:
ListContainer(range(10))
Out[78]:
ListContainer has (10 items)
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

But if there are more than 10 we get '...'

In [79]:
ListContainer(range(100))
Out[79]:
ListContainer has (100 items)
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9...]

It can also take integer indexing and Boolean masking:

In [80]:
t = ListContainer(range(10))
t[[1,2]]
Out[80]:
[1, 2]
In [81]:
t[[False]*8 + [True,False]]
Out[81]:
[8]

With this ListContainer we can make a new Hooks class.

__enter__(self)

Defines what the context manager should do at the beginning of the block created by the with statement. Note that the return value of __enter__ is bound to the target of the with statement, or the name after the as.


__exit__(self, exception_type, exception_value, traceback)

Defines what the context manager should do after its block has been executed (or terminates). It can be used to handle exceptions, perform cleanup, or do something always done immediately after the action in the block. If the block executes successfully, exception_type, exception_value, and traceback will be None. Otherwise, you can choose to handle the exception or let the user handle it; if you want to handle it, make sure __exit__ returns True after all is said and done. If you don't want the exception to be handled by the context manager, just let it happen.
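A minimal sketch of the protocol (my example, not part of the notebook): this prints 'enter', then the body, then 'exit', even if the body raises.

class Demo():
    def __enter__(self):
        print('enter')
        return self                      # bound to the name after `as`
    def __exit__(self, exc_type, exc_value, tb):
        print('exit')                    # cleanup runs even if the block raised
        return False                     # don't swallow exceptions

with Demo() as d:
    print('inside the block')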
In [82]:
#export

class Hooks(ListContainer):
    
    def __init__(self, ms, f):
        super().__init__([Hook(m,f) for m in ms]) # stores hooks away inside ListContainer
    
    def __enter__(self, *args):
        return self
    
    def __exit__(self, *args):
        self.remove()
    
    def __del__(self): 
        self.remove()
    
    def __delitem__(self, i):
        self[i].remove()
        super().__delitem__(i)
        
    def remove(self):
        for h in self: h.remove()
In [83]:
model = get_cnn_model(data, nfs).cuda()
learn,run = get_runner(model, data, lr=0.9, cbs=cbfs)
In [84]:
hooks = Hooks(model, append_stats)
In [85]:
hooks
Out[85]:
Hooks has (7 items)
 [<__main__.Hook object at 0x000002BB24343F40>, <__main__.Hook object at 0x000002BB243439D0>, <__main__.Hook object at 0x000002BB243533D0>, <__main__.Hook object at 0x000002BB24353BE0>, <__main__.Hook object at 0x000002BB2226C070>, <__main__.Hook object at 0x000002BB2226CBE0>, <__main__.Hook object at 0x000002BB21AABAC0>]
In [86]:
hooks.remove()

Let's pull a batch of training data and put it through the first layer of our model:

In [87]:
x, y = next(iter(data.train_dl))
x = mnist_resize(x).cuda()
x.mean(), x.std()
Out[87]:
(tensor(-0.0075, device='cuda:0'), tensor(0.9912, device='cuda:0'))

The mean is nearly zero and the std is about 1. Great.

Now let's go ahead and put it through our first layer...

In [88]:
p = model[0](x)
p.mean(), p.std()
Out[88]:
(tensor(0.2140, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.5025, device='cuda:0', grad_fn=<StdBackward0>))

Something is wrong. Our mean jumps up and our std drops.

Let's initialize our model parameters:

In [89]:
for l in model:
    if isinstance(l, nn.Sequential):
        print('Initialized', l[0].weight.shape)
        init.kaiming_normal_(l[0].weight)
        l[0].bias.data.zero_()
Initialized torch.Size([8, 1, 5, 5])
Initialized torch.Size([16, 8, 3, 3])
Initialized torch.Size([32, 16, 3, 3])
Initialized torch.Size([32, 32, 3, 3])
In [90]:
p = model[0](x)
p.mean(), p.std()
Out[90]:
(tensor(0.5181, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.9297, device='cuda:0', grad_fn=<StdBackward0>))

Better. But our mean is still too high: the ReLU throws away all the negative activations, which shifts the mean up.

Having given an __enter__ and __exit__ method to our Hooks class, we can use it as a context manager. This makes sure that once we are out of the with block, all the hooks have been removed and aren't there to pollute our memory.

This is a chunky piece of code... I'll need some time with it.

In [91]:
with Hooks(model, append_stats) as hooks: # open context manager
    run.fit(2, learn) # train for 2 epochs
    
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10,4)) # makes 2 plots
    
    for h in hooks:
        ms, ss = h.stats
        ax0.plot(ms[:10])
        ax1.plot(ss[:10])
    plt.legend(range(6))
    
    fig,(ax0,ax1) = plt.subplots(1,2, figsize=(10,4))
    for h in hooks:
        ms, ss = h.stats
        ax0.plot(ms)
        ax1.plot(ss)
    plt.legend(range(6))
    
train: [2.248575625, tensor(0.2117, device='cuda:0')]
valid: [1.9958736328125, tensor(0.2561, device='cuda:0')]
train: [2.21365671875, tensor(0.1943, device='cuda:0')]
valid: [1.755112890625, tensor(0.3733, device='cuda:0')]

These plots give us an idea of where our mean and standard deviation are throughout the training process for each layer.

A few observations:

  • there is a spike immediately around the 4th batch
  • layers 1-5 appear to be closely correlated
  • the first layer seems to behave differently from the rest

Other Statistics

Now let's try to visualize our activations in addition to the means and standard deviations for each layer.

Again we will need to write the hook function which we'll register with each layer by using our context manager Hook class:

In [92]:
def append_stats(hook, mod, inp, outp):
    if not hasattr(hook, 'stats'): hook.stats = ([],[],[])
    means, stds, hists = hook.stats
    means.append(outp.data.mean().cpu())
    stds.append(outp.data.std().cpu())
    hists.append(outp.data.cpu().histc(40, 0, 10)) #40 bins between 0 and 10

Now let's get our model and initialize the weights:

In [93]:
model = get_cnn_model(data, nfs)
learn, run = get_runner(model, data, cbs=cbfs)
In [94]:
for l in model:
    if isinstance(l, nn.Sequential):
        print('Initialized', l[0].weight.shape)
        print(l[0])
        init.kaiming_normal_(l[0].weight)
        l[0].bias.data.zero_()
Initialized torch.Size([8, 1, 5, 5])
Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
Initialized torch.Size([16, 8, 3, 3])
Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Initialized torch.Size([32, 16, 3, 3])
Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Initialized torch.Size([32, 32, 3, 3])
Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
In [95]:
with Hooks(model, append_stats) as hooks: run.fit(1, learn)
train: [1.51281265625, tensor(0.5024, device='cuda:0')]
valid: [0.365566796875, tensor(0.8908, device='cuda:0')]

Let's take a closer look at how this is all put together and what our stats look like.

hooks is an instance of our Hooks class which inherits from the ListContainer class - descriptive naming.

In [96]:
isinstance(hooks, Hooks)
Out[96]:
True

The hooks object contains 7 items. Each one corresponds to a layer of the network and is an instance of the Hook class which actually does the registering for that layer.

In [97]:
hooks
Out[97]:
Hooks has (7 items)
 [<__main__.Hook object at 0x000002BB244B57F0>, <__main__.Hook object at 0x000002BB244B5790>, <__main__.Hook object at 0x000002BB244B5640>, <__main__.Hook object at 0x000002BB244B5B50>, <__main__.Hook object at 0x000002BB244B5610>, <__main__.Hook object at 0x000002BB244B56A0>, <__main__.Hook object at 0x000002BB360C39D0>]
In [98]:
isinstance(hooks[0], Hook)
Out[98]:
True

Then the append_stats function creates a stats attribute in each one of the hooks.

We compute the stats after every batch, so we'd expect the number of stats collected to equal the total number of batches, training plus validation, in the one epoch we ran:
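As a quick check: MNIST here has 50,000 training and 10,000 validation images, and (assuming get_dls gives the validation DataLoader a batch size of 2*bs, as in the earlier notebooks):

import math
math.ceil(50000 / 512)          # 98 training batches
math.ceil(10000 / (2 * 512))    # 10 validation batches
98 + 10                         # 108 -- matches the 108 histograms stacked below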

In [99]:
len(hooks[0].stats[0]) == len(data.train_dl) + len(data.valid_dl)
Out[99]:
True

Great. Now we can index into it to pull out the saved stats that were calculated when the model was training.

In [100]:
# means layer 0
hooks[0].stats[0][:5]
Out[100]:
[tensor(0.4387),
 tensor(0.4256),
 tensor(0.4328),
 tensor(0.4493),
 tensor(0.4702)]
In [101]:
# means layer 1
hooks[1].stats[0][:5]
Out[101]:
[tensor(0.4210),
 tensor(0.3931),
 tensor(0.4004),
 tensor(0.4467),
 tensor(0.5032)]
In [102]:
# stds layer 0
hooks[0].stats[1][:5]
Out[102]:
[tensor(0.9198),
 tensor(0.8872),
 tensor(0.9080),
 tensor(0.9476),
 tensor(1.0031)]
In [103]:
# stds layer 1
hooks[1].stats[1][:5]
Out[103]:
[tensor(0.7406),
 tensor(0.6976),
 tensor(0.7262),
 tensor(0.8235),
 tensor(0.9450)]

For the activations we can see for each layer the shape of the tensor is 108 x 40: 108 being the number of batches and 40 being the number of bins we specified in the append_stats function.

In [104]:
# activations layer 0
torch.stack(hooks[0].stats[2]).shape
Out[104]:
torch.Size([108, 40])

Plotting Hook Stats as histograms:

Pytorch provides a method, torch.Tensor.histc, to automatically bin a tensor into a histogram:

In [105]:
a = torch.stack(hooks[3].stats[2])[0];a
Out[105]:
tensor([3.9656e+04, 8.1830e+03, 6.0350e+03, 4.1800e+03, 2.6090e+03, 1.7420e+03,
        1.0920e+03, 7.3400e+02, 4.5600e+02, 2.9300e+02, 1.7200e+02, 1.1800e+02,
        8.7000e+01, 6.5000e+01, 4.1000e+01, 2.6000e+01, 2.1000e+01, 1.5000e+01,
        5.0000e+00, 1.0000e+00, 4.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])
In [106]:
plt.hist(a.histc(bins=40, min=0, max=10))
Out[106]:
(array([38.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.]),
 array([ 0. ,  1.8,  3.6,  5.4,  7.2,  9. , 10.8, 12.6, 14.4, 16.2, 18. ], dtype=float32),
 <a list of 10 Patch objects>)
In [107]:
def get_hist(h):
    return torch.stack(h.stats[2]).t().float().log1p()
In [108]:
mpl.rcParams['image.cmap'] = 'viridis'
In [109]:
fig, axes = plt.subplots(2,2, figsize=(15,6))
for ax, h in zip(axes.flatten(), hooks[:4]):
    ax.imshow(get_hist(h), origin='lower')
    ax.axis('off')
plt.tight_layout()

Each histogram is a different layer.

The x-axis is iterations through the data.

The y-axis is the activation value, split into the 40 bins we asked histc for (0 at the bottom, 10 at the top); the color shows how many activations fall into each bin (on a log scale, thanks to log1p).

We can see that for some iterations, especially at the start, the activations cluster near zero; for other iterations they are more spread out.

Let's try to determine how many of our activations are in that yellow strip at the bottom.

Meaning: how many of our activations are near zero?

To do this we need a function that calculates what fraction of the activations fall into the two lowest bins:

In [110]:
def get_min(h):
    h1 = torch.stack(h.stats[2]).t().float()
    return h1[:2].sum(0)/h1.sum(0) # fraction of activations in the two lowest bins, i.e. close to zero
In [111]:
fig, axes = plt.subplots(2,2, figsize=(15,6))

for ax,h in zip(axes.flatten(), hooks[:4]):
    ax.plot(get_min(h))
    ax.set_ylim(0,1)

Most of our activations - more than 90% in some cases - are near zero and therefore, totally wasted!

Generalized ReLU

To fix the problem of most of our activations being near zero we'll make a new non-linearity.

We'd like to be able to use a generalized ReLU function that takes parameters.

Let's start by refactoring our get_cnn_layers function: instead of hard-coding the conv2d function to get a conv and ReLU, we'll take the layer-building function as an argument.

In [112]:
#export
def get_cnn_layers(data, nfs, layer, **kwargs):
    nfs = [1] + nfs
    return [layer(nfs[i], nfs[i+1], 5 if i==0 else 3, **kwargs) for i in range(len(nfs)-1)] + [
        nn.AdaptiveAvgPool2d(1), Lambda(flatten), nn.Linear(nfs[-1], data.c)]

Next we'll define that layer function.

The main difference between this and the original conv2d function we've been using is that it replaces the nn.ReLU module with a GeneralRelu instance, which we'll define below.

In [113]:
#export
def conv_layer(ni, nf, ks=3, stride=2, **kwargs):
    return nn.Sequential(
    nn.Conv2d(ni, nf, kernel_size=ks, padding=ks//2, stride=stride), GeneralRelu(**kwargs))

The GeneralRelu class just inherits from the nn.Module class and we define the forward pass based on the passed args:

  • If leak then use F.leaky_relu
  • If sub then just subtract a given value
  • If maxv then clamp_max the tensor.
In [114]:
#export
class GeneralRelu(nn.Module):
    def __init__(self, leak=None, sub=None, maxv=None):
        super().__init__()
        self.leak, self.sub, self.maxv = leak, sub, maxv
    
    def forward(self, x):
        x = F.leaky_relu(x, self.leak) if self.leak is not None else F.relu(x)
        if self.sub is not None: x.sub_(self.sub)
        if self.maxv is not None: x.clamp_max_(self.maxv)
        return x

And let's make a conv weight initializer:

In [115]:
#export
def init_cnn(m, uniform=False):
    initzer = init.kaiming_normal_ if not uniform else init.kaiming_uniform_
    for l in m:
        if isinstance(l, nn.Sequential):
            print('Layer Initialized', l[0].weight.shape)
            initzer(l[0].weight, a=0.1) # a is the negative slope of the rectifier that follows; our leaky ReLU uses leak=0.1
            l[0].bias.data.zero_()

Lastly, let's alter our get_cnn_model so that we can pass parameters to our new generalized Relu

In [116]:
#export
def get_cnn_model(data, nfs, layer, **kwargs):
    return nn.Sequential(*get_cnn_layers(data, nfs, layer, **kwargs))

Our append_stats function above was binning the activation histograms between 0 and 10.

Now that the leaky ReLU can produce negative values, we'll bin between -7 and 10:
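This also makes the bins a bit wider; a quick bit of arithmetic (mine, not from the lesson):

(10 - 0) / 40      # 0.25  : old bin width for histc(40, 0, 10)
(10 - (-7)) / 40   # 0.425 : new bin width for histc(40, -7, 10)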

In [117]:
def append_stats(hook, mod, inp, outp):
    if not hasattr(hook, 'stats'): hook.stats = ([],[],[])
    means, stds, hists = hook.stats
    means.append(outp.data.mean().cpu())
    stds.append(outp.data.std().cpu())
    hists.append(outp.data.cpu().histc(40, -7, 10)) #40 bins between -7 and 10

Let's once again plot our means and stds for the layers and see where we are with our new leaky relu:

In [118]:
model = get_cnn_model(data, nfs, conv_layer, leak=0.1, sub=0.4, maxv=6.)
init_cnn(model)
learn, run = get_runner(model, data, lr=0.9, cbs=cbfs)
Layer Initialized torch.Size([8, 1, 5, 5])
Layer Initialized torch.Size([16, 8, 3, 3])
Layer Initialized torch.Size([32, 16, 3, 3])
Layer Initialized torch.Size([32, 32, 3, 3])
In [119]:
with Hooks(model, append_stats) as hooks:
    run.fit(1, learn)
    fig, (ax0, ax1) = plt.subplots(1,2, figsize=(10,4))
    
    for h in hooks:
        ms, ss, hi = h.stats
        ax0.plot(ms[:16])
        ax1.plot(ss[:16])
        h.remove()
    plt.legend(range(5))
    
    fig, (ax0, ax1) = plt.subplots(1,2, figsize=(10,4))
    
    for h in hooks:
        ms, ss, hi = h.stats
        ax0.plot(ms)
        ax1.plot(ss)
        h.remove()
    plt.legend(range(5))
    
train: [0.4730238671875, tensor(0.8503, device='cuda:0')]
valid: [0.1350355224609375, tensor(0.9603, device='cuda:0')]
In [120]:
fig,axes = plt.subplots(2,2, figsize=(15,6))
for ax,h in zip(axes.flatten(), hooks[:4]):
    ax.imshow(get_hist(h), origin='lower')
    ax.axis('off')
plt.tight_layout()

Now we can see that some of the activations are less than zero.

The color intensity is clearly more evenly distributed. Let's hope this means fewer of our activations are close to zero.

Again let's look at the minimum:

In [121]:
# bins 19:22 are the slice treated as 'near zero' here (with histc(40, -7, 10) each bin is (10 + 7)/40 = 0.425 wide)
def get_min(h):
    h1 = torch.stack(h.stats[2]).t().float()
    return h1[19:22].sum(0)/h1.sum(0)
In [122]:
fig, axes = plt.subplots(2,2, figsize=(15, 6))
for ax, h in zip(axes.flatten(), hooks[:4]):
    ax.plot(get_min(h))
    ax.set_ylim(0,1)
plt.tight_layout()

Look at that. The fraction of activations near zero stays below 20% throughout the training process.

The generalized leaky ReLU was the key.

Training

Let's make it easier to get our learn and run:

In [123]:
#export
def get_learn_run(data, nfs, layer, lr, cbs=None, opt_func=None, uniform=False, **kwargs):
    model = get_cnn_model(data, nfs, layer, **kwargs)
    init_cnn(model, uniform=uniform)
    return get_runner(model, data, lr=lr, cbs=cbs, opt_func=opt_func)
In [124]:
sched = combine_scheds([0.5, 0.5], [sched_cos(0.2, 1.), sched_cos(1., 0.1)])
In [125]:
learn, run = get_learn_run(data, nfs, conv_layer, lr=0.3, cbs=cbfs+[partial(ParamScheduler, 'lr', sched)])
Layer Initialized torch.Size([8, 1, 5, 5])
Layer Initialized torch.Size([16, 8, 3, 3])
Layer Initialized torch.Size([32, 16, 3, 3])
Layer Initialized torch.Size([32, 32, 3, 3])
In [126]:
run.fit(8, learn)
train: [1.031034765625, tensor(0.6704, device='cuda:0')]
valid: [0.424850927734375, tensor(0.8644, device='cuda:0')]
train: [0.31972357421875, tensor(0.9027, device='cuda:0')]
valid: [0.17987952880859376, tensor(0.9432, device='cuda:0')]
train: [0.4105227734375, tensor(0.8733, device='cuda:0')]
valid: [0.28678193359375, tensor(0.9107, device='cuda:0')]
train: [0.17106853515625, tensor(0.9489, device='cuda:0')]
valid: [0.10335501708984375, tensor(0.9667, device='cuda:0')]
train: [0.093937685546875, tensor(0.9712, device='cuda:0')]
valid: [0.1048633544921875, tensor(0.9665, device='cuda:0')]
train: [0.067144404296875, tensor(0.9790, device='cuda:0')]
valid: [0.07107382202148438, tensor(0.9786, device='cuda:0')]
train: [0.0486342578125, tensor(0.9848, device='cuda:0')]
valid: [0.06403697509765625, tensor(0.9803, device='cuda:0')]
train: [0.04022163330078125, tensor(0.9876, device='cuda:0')]
valid: [0.064214453125, tensor(0.9815, device='cuda:0')]

Now let's try with Kaiming uniform initialization.

In [127]:
learn,run = get_learn_run(data, nfs,conv_layer, lr=.5, uniform=True,
                          cbs=cbfs+[partial(ParamScheduler,'lr', sched)])
Layer Initialized torch.Size([8, 1, 5, 5])
Layer Initialized torch.Size([16, 8, 3, 3])
Layer Initialized torch.Size([32, 16, 3, 3])
Layer Initialized torch.Size([32, 32, 3, 3])
In [128]:
run.fit(8, learn)
train: [1.143342265625, tensor(0.6411, device='cuda:0')]
valid: [0.30316962890625, tensor(0.9109, device='cuda:0')]
train: [0.342483671875, tensor(0.8961, device='cuda:0')]
valid: [0.20152366943359376, tensor(0.9415, device='cuda:0')]
train: [0.2997586328125, tensor(0.9087, device='cuda:0')]
valid: [2.2584748046875, tensor(0.1733, device='cuda:0')]
train: [0.76085953125, tensor(0.7593, device='cuda:0')]
valid: [0.17687100830078126, tensor(0.9446, device='cuda:0')]
train: [0.14596837890625, tensor(0.9554, device='cuda:0')]
valid: [0.1203050048828125, tensor(0.9636, device='cuda:0')]
train: [0.089819482421875, tensor(0.9721, device='cuda:0')]
valid: [0.08982869873046875, tensor(0.9749, device='cuda:0')]
train: [0.0675619091796875, tensor(0.9794, device='cuda:0')]
valid: [0.0816985107421875, tensor(0.9771, device='cuda:0')]
train: [0.0568674560546875, tensor(0.9831, device='cuda:0')]
valid: [0.07918004760742188, tensor(0.9775, device='cuda:0')]

Similar results but Kaiming Normal worked a bit better.

Improving Accuracy

How can we improve the accuracy and what does the telemetry look like?

In [129]:
#export
from IPython.display import display, Javascript
def nb_auto_export():
    display(Javascript("""{
const ip = IPython.notebook
if (ip) {
    ip.save_notebook()
    console.log('a')
    const s = `!python notebook2script.py ${ip.notebook_name}`
    if (ip.kernel) { ip.kernel.execute(s) }
}
}"""))
In [130]:
nb_auto_export()