The following papers are discussed in this post:

  1. Understanding the difficulty of training deep feedforward neural networks paper
  2. Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification paper
import gzip, pickle, math
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from fastai2.data.external import download_data # if you do not have fastai2, you can
                                                # download the dataset manually and place
                                                # in the extra folder
import torch
import torch.nn as nn
import torch.nn.functional as F

Get Data

The MNIST dataset is used for quick experimentation. The ideas discussed in this post are independent of the dataset used.

def get_data():
    # if you do not have fastai, you can download the dataset directly from the given URL
    MNIST_URL = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'
    
    path = download_data(MNIST_URL, 'extra/mnist.pkl.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    
    # By default these are numpy arrays, so convert them to pytorch tensors
    return map(torch.tensor, (x_train,y_train,x_valid,y_valid))

def normalize(x, m, s): return (x-m)/s
x_train, y_train, x_valid, y_valid = get_data()
train_mean, train_std = x_train.mean(), x_train.std()

x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std) # use training stats for validation set also
x_train.shape, y_train.shape, x_valid.shape, y_valid.shape
(torch.Size([50000, 784]),
 torch.Size([50000]),
 torch.Size([10000, 784]),
 torch.Size([10000]))
x_train.mean(), x_train.std(), x_valid.mean(), x_valid.std()
(tensor(-7.6999e-06), tensor(1.), tensor(-0.0059), tensor(0.9924))

Now the input comes from a distribution with $mean = 0$ and $std = 1$.

Objective

We want the activations in the neural network to have mean=0 and std=1. By activation I mean the output of activation_function(Linear(x, w, b)). A convolution layer can also be represented as a linear layer (take a vector of 0's and replace the 0's with the filter's weights at the positions where the convolution filter would be applied).
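
As a small illustration of the convolution-as-linear-layer claim, here is a sketch with a made-up 1D example (not code from the post): the filter's weights are scattered into the rows of an otherwise-zero matrix, and the matrix product reproduces the convolution.

# 1D "image" of length 5 and a filter of length 3 (stride 1, no padding).
signal = torch.randn(5)
kernel = torch.randn(3)

# Build the equivalent (3, 5) weight matrix: each row is zeros with the
# kernel placed where the filter would be applied at that output position.
W = torch.zeros(3, 5)
for row in range(3):
    W[row, row:row + 3] = kernel

as_linear = W @ signal
as_conv = F.conv1d(signal.view(1, 1, 5), kernel.view(1, 1, 3)).flatten()
print(torch.allclose(as_linear, as_conv))  # True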

Important: Batch Normalization is not used in this post. BatchNorm is still needed in practice because if we just use Kaiming init in a ResNet, the activations would explode. This happens due to the shortcut connection, which doubles the variance at every block and thus grows the activations exponentially with depth.
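
A quick toy check of the variance-doubling effect (my own sketch, not a real ResNet block): adding an independent branch with the same variance as the identity path roughly doubles the variance at every step, so without something like BatchNorm the scale grows exponentially with depth.

x = torch.randn(10_000)
for i in range(5):
    # "Shortcut + branch": the branch output is independent of x and has
    # the same std, so each addition roughly doubles the variance.
    x = x + torch.randn(10_000) * x.std()
    print(i, x.var().item())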

Kaiming Init (Code)

Kaiming Init (He Init) is the most commonly used initialization. Its implementation is as follows:

w1 = torch.randn(784, 50) * math.sqrt(2/784)
b1 = torch.zeros(50)
w2 = torch.randn(50,1) * math.sqrt(2/50)
b2 = torch.zeros(1)

And that's it. You sample from a standard normal distribution (mean=0, std=1) and then scale the samples by math.sqrt(2/num_inputs), which sets their standard deviation to that value. This preserves the standard deviation of the activations during forward propagation.

In PyTorch you can implement this using nn.init.kaiming_normal_(...). The implementation details will be discussed at the end.
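
For reference, here is a minimal sketch of the PyTorch equivalent of the manual code above. Note that nn.Linear stores its weight as (out_features, in_features), so mode='fan_in' divides by the number of inputs (784 here):

lin1 = nn.Linear(784, 50)
nn.init.kaiming_normal_(lin1.weight, mode='fan_in', nonlinearity='relu')
nn.init.zeros_(lin1.bias)

# The resulting std should be close to math.sqrt(2/784) used above.
print(lin1.weight.std().item(), math.sqrt(2 / 784))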

Now we have the bigger picture: initialization is just scaling randomly sampled weights by an appropriate value. Let's get into the details of initialization.

Why is Initialization Important?

Example 1

Pass the input through linear and relu layers and see the standard deviation of the outputs.

m = 784
nh = 10 #Num hidden units
w1 = torch.randn(m,nh)
b1 = torch.zeros(nh) # bias initialized to 0
w2 = torch.randn(nh,1)
b2 = torch.zeros(1)
def stats(x):
    # Utility function to print mean and std
    return x.mean(), x.std()
stats(x_valid)
(tensor(-0.0059), tensor(0.9924))
def lin(x, w, b): return x@w + b
t = lin(x_valid, w1, b1)

stats(t)
(tensor(-2.8136), tensor(26.8307))

The mean and std are way off from 0 and 1 respectively. Now we can take ReLU of this.

def relu(x): return x.clamp_min(0.)
t = relu(t)

stats(t)
(tensor(9.3641), tensor(14.4010))

We can see from this example why initialization is important. After just one layer (linear + ReLU), our mean and standard deviation have diverged from their initial values of 0 and 1.

Now imagine if we had 100 of these layers. Let's try that now.

Example 2

To understand why initialization is important in a neural net, we'll focus on the basic operation it performs: matrix multiplications. So let's just take a vector x and a matrix a initialized randomly, then multiply them 100 times (as if we had 100 layers).

x = torch.randn(512)
a = torch.randn(512, 512)
for i in range(100): x = a@x
stats(x)
(tensor(nan), tensor(nan))

The problem we get is activation explosion: very soon our activations go to nan, i.e. their values become too large to be represented. We can see when that happens.

x = torch.randn(512)
a = torch.randn(512, 512)

for i in range(100):
    x = a@x
    if x.mean() != x.mean(): break # nan != nan
        
i
27

So it took just 28 multiplications (the loop breaks at i = 27) for the activations to overflow. The natural way to mitigate this problem is to reduce the scale of the matrix a.

x = torch.randn(512)
a = torch.randn(512,512)*0.01

for i in range(100): x = a@x
stats(x)
(tensor(0.), tensor(0.))

Now we have the opposite problem: the activations vanish to 0 (and the gradients vanish with them).

We can try Xavier initialization to solve this problem. (Xavier init is the same as Kaiming init, except that the numerator is 1 instead of 2.)

x = torch.randn(512)
a = torch.randn(512,512)*math.sqrt(1/512)

for i in range(100): x = a@x
stats(x)
(tensor(-0.1040), tensor(2.8382))

It works.

We can repeat this experiment with a ReLU after each multiplication, in which case we would use Kaiming init.

The difference between Xavier and Kaiming init is:

  • Xavier init: only the linear layer is considered (no nonlinearity is accounted for)
  • Kaiming init: a linear layer followed by ReLU is considered
x = torch.randn(512)
a = torch.randn(512,512)*math.sqrt(2/512)

for i in range(100):
    x = a@x
    x = x.clamp_min(0.)
    
stats(x)
(tensor(1.2160), tensor(1.7869))

And it works: the activations neither exploded nor vanished.

Example 3

Train a simple feedforward network on the MNIST dataset and visualize the histogram of the activation values.

def lin_relu(c_in, c_out):
    return nn.Sequential(
                nn.Linear(c_in, c_out),
                nn.ReLU(inplace=True),
            )

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Create a linear-relu model
        self.model = nn.Sequential(
                        *lin_relu(784, 1000),
                        *lin_relu(1000,1000),
                        *lin_relu(1000,1000),
                        *lin_relu(1000,1000),
                        *lin_relu(1000,1000),
                        nn.Linear(1000,10), # No need of softmax as it will be in the loss function
                    )
        self.initialize_model()
        
    def forward(self, x):
        return self.model(x)
    
    def initialize_model(self):
        # Before Xavier init, people used to sample weights from
        # a uniform distribution
        for layer in self.model:
            if isinstance(layer, nn.Linear):
                value = 1./math.sqrt(layer.weight.shape[1])
                layer.weight.data.uniform_(-value, value)
                layer.bias.data.fill_(0)
                
net = Net()

net.model
Sequential(
  (0): Linear(in_features=784, out_features=1000, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=1000, out_features=1000, bias=True)
  (3): ReLU(inplace=True)
  (4): Linear(in_features=1000, out_features=1000, bias=True)
  (5): ReLU(inplace=True)
  (6): Linear(in_features=1000, out_features=1000, bias=True)
  (7): ReLU(inplace=True)
  (8): Linear(in_features=1000, out_features=1000, bias=True)
  (9): ReLU(inplace=True)
  (10): Linear(in_features=1000, out_features=10, bias=True)
)
bs = 128
device = torch.device('cuda')

net = net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

dataset = torch.utils.data.TensorDataset(x_train, y_train)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
for epoch in range(70):
    running_loss = .0
    for i, data in enumerate(dataloader,0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    print(f'{epoch+1}: {running_loss/len(dataloader)}')
1: 2.2928382635116575
2: 2.2065723312206758
3: 1.194048885657237
4: 0.5237417409817378
5: 0.3836347458072198
6: 0.3146293921730457
7: 0.2697077517135021
8: 0.2339217932942586
9: 0.20382011488844187
10: 0.17836485357047654
11: 0.1592684537745439
12: 0.14289972114448363
13: 0.1286666537133547
14: 0.11684789568758928
15: 0.10570534034990348
16: 0.09746135014754076
17: 0.08864978084770533
18: 0.08108058688827814
19: 0.07408739045166816
20: 0.06816183782349794
21: 0.06298899033512824
22: 0.05734651746849219
23: 0.05177393974497532
24: 0.047053486519517046
25: 0.04282796499438775
26: 0.039077873800236446
27: 0.03567600295138665
28: 0.03188744994023671
29: 0.029368048552901316
30: 0.02597655968215221
31: 0.023195109860255168
32: 0.020861454260272857
33: 0.018848492100070686
34: 0.017209559499930877
35: 0.015161331571065462
36: 0.013735617821415266
37: 0.012041004417607417
38: 0.010973294661977353
39: 0.010113576904703409
40: 0.009030632378581243
41: 0.008074290845065545
42: 0.007224243325300706
43: 0.006513816581513637
44: 0.0059723575384571
45: 0.00554538136586929
46: 0.004938531619233963
47: 0.004390296693413685
48: 0.004122716288727063
49: 0.0038293207541872294
50: 0.0034256023139907763
51: 0.003229047467884345
52: 0.002975824174399559
53: 0.002749748910084749
54: 0.002592976238483038
55: 0.0023717502848460124
56: 0.0021688310572734247
57: 0.002044256757467221
58: 0.0019431500003123895
59: 0.0017717176236403294
60: 0.001694328805957085
61: 0.0015888030234819804
62: 0.001510932248754379
63: 0.0014456520477930705
64: 0.0013670944441587496
65: 0.0013064238218924939
66: 0.0012553433672739909
67: 0.0011863003556544965
68: 0.0011301248310468135
69: 0.0010923029807133552
70: 0.001042095409371914

Now visualize the activation values using the validation images.

torch.save(net.state_dict(), 'uniform.pt')
actvs = {}
net = net.to('cpu')

images = x_valid[:10000]

out = images.clone()
for i in range(len(net.model)):
    if i in (1,3,5,7,9):
        out = net.model[i](out)
        actvs[i] = out.squeeze(0).detach().mean(axis=0).numpy()
    else:
        out = net.model[i](out)

Visualize the activations using a kdeplot. It is similar to a histogram, but gives a smoother view of the data. Alternatively, a histogram can also be plotted to see the effect.

fig, ax = plt.subplots(figsize=(15,5))
ax.set_xlim(-1, 1)
for k,v in actvs.items():
    sns.kdeplot(v, ax=ax, label=f'Layer {k}')
ax.legend()

Test the Above Example with Kaiming Init

def lin_relu(c_in, c_out):
    return nn.Sequential(
                nn.Linear(c_in, c_out),
                nn.ReLU(inplace=True),
            )

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Create a linear-relu model
        self.model = nn.Sequential(
                        *lin_relu(784, 1000),
                        *lin_relu(1000,1000),
                        *lin_relu(1000,1000),
                        *lin_relu(1000,1000),
                        *lin_relu(1000,1000),
                        nn.Linear(1000,10), # No need of softmax as it will be in the loss function
                    )
        self.initialize_model()
        
    def forward(self, x):
        return self.model(x)
    
    def initialize_model(self):
        for layer in self.model:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight.data, mode='fan_out')
                layer.bias.data.fill_(0)
                
net = Net()
bs = 128
device = torch.device('cuda')

net = net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

dataset = torch.utils.data.TensorDataset(x_train, y_train)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
for epoch in range(70):
    running_loss = .0
    for i, data in enumerate(dataloader,0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    print(f'{epoch+1}: {running_loss/len(dataloader)}')
1: 0.7095098866484104
2: 0.0923736249263852
3: 0.03983800867333626
4: 0.01766058729531673
5: 0.009082728055998301
6: 0.005672304465984687
7: 0.004172981537591953
8: 0.003315979051284301
9: 0.0027411463121191047
10: 0.0023460902225894806
11: 0.0020474239801749204
12: 0.0018151276721022067
13: 0.0016298005023063758
14: 0.0014792730888495077
15: 0.0013441605111345267
16: 0.0012390905131514256
17: 0.001146641082297533
18: 0.0010655345013126348
19: 0.0009982335691650708
20: 0.000935702603787948
21: 0.0008791005668731837
22: 0.0008290120758689366
23: 0.0007855443713756708
24: 0.0007456283538769453
25: 0.0007100430006782214
26: 0.0006764667108654976
27: 0.0006458146258806571
28: 0.0006196549830910488
29: 0.0005923675229916206
30: 0.0005684418173936697
31: 0.0005472387879704817
32: 0.0005264173046900676
33: 0.0005066754296422004
34: 0.0004908624558876722
35: 0.00047273215575095934
36: 0.0004578854984197861
37: 0.00044289983044832183
38: 0.0004285678219718811
39: 0.00041579040579306774
40: 0.0004036462746369533
41: 0.00039161906983607855
42: 0.000380707326798867
43: 0.0003695801998942326
44: 0.00035994008947641423
45: 0.00035058893263339996
46: 0.000341620133855404
47: 0.0003329002513335301
48: 0.0003244552044914319
49: 0.0003169629149712049
50: 0.0003093158348630636
51: 0.0003023809108596582
52: 0.0002947433923299496
53: 0.00028860634909226346
54: 0.00028179665215504477
55: 0.00027597242823013894
56: 0.00027051137712521433
57: 0.0002648675336669653
58: 0.0002595099023519418
59: 0.0002543628406830323
60: 0.0002490593359256402
61: 0.0002443719846315873
62: 0.0002396430533665877
63: 0.00023537879953017602
64: 0.00023074357364422237
65: 0.00022677934895723293
66: 0.00022275162239869436
67: 0.00021877073897765234
68: 0.00021507212748894325
69: 0.0002110987806167358
70: 0.00020773537839070344
torch.save(net.state_dict(), 'kaiming.pt')
actvs = {}
net = net.to('cpu')

images = x_valid[:10000]

out = images.clone()
for i in range(len(net.model)):
    if i in (1,3,5,7,9):
        out = net.model[i](out)
        actvs[i] = out.squeeze(0).detach().mean(axis=0).numpy()
    else:
        out = net.model[i](out)
fig, ax = plt.subplots(figsize=(15,5))
ax.set_xlim(-1, 1)
for k,v in actvs.items():
    sns.distplot(v, ax=ax, label=f'Layer {k}', hist=False)
ax.legend()

We can see a clear improvement over the previous figure: with Kaiming init, the activation values in the later layers are larger and less concentrated around zero, so the signal is preserved much deeper into the network.
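
To put numbers on this, a small helper (my own, not from the post; it assumes net, x_valid and the two checkpoints saved above are available) prints the mean and std of every ReLU output for each set of weights:

def layer_stats(net, images):
    # Forward the images and report mean/std of every ReLU output
    # (indices 1, 3, 5, 7, 9 in the Sequential).
    out = images.clone()
    with torch.no_grad():
        for i, layer in enumerate(net.model):
            out = layer(out)
            if i in (1, 3, 5, 7, 9):
                print(f'Layer {i}: mean={out.mean().item():.3f}, std={out.std().item():.3f}')

net.load_state_dict(torch.load('kaiming.pt', map_location='cpu'))
layer_stats(net, x_valid[:10000])

net.load_state_dict(torch.load('uniform.pt', map_location='cpu'))
layer_stats(net, x_valid[:10000])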

Kaiming Init Detailed Derivation

I will derive the initialization formula for the forward pass only. In practice we set the initialization values so as to maintain the mean and std during the forward pass (the paper shows that the backward-pass derivation gives a similar condition, with the fan-out in place of the fan-in, and either one is sufficient).

Consider a layer $l$ $$\vec{y}_l = W_l\vec{x}_l+\vec{b}_l$$

  • $\vec{x}_l \rightarrow$ is a $(n_l,1)$ vector that represents the activations of the previous layer $\vec{y}_{l-1}$ that were passed through an activation function $f$ i.e. $$\vec{x}_l = f(\vec{y}_{l-1})$$
  • $W_l \rightarrow$ is a $(d_l,n_l)$ matrix from layer $l-1$ to layer $l$, with $d_l$ the number of filters of the convolutional layer
  • $\vec{b}_l \rightarrow$ is a $(d_l,1)$ vector of biases of layer $l$ (initialized to 0)
  • $\vec{y}_l \rightarrow$ is a $(d_l,1)$ vector of activations of layer l before they go through the activation function.

Note: Vectors are shown with an arrow on top of them. For matrices I do not use any special notation.

A few assumptions are made:

  • Each element in $W_l$ is mutually independent from other elements and all the elements come from the same distribution.
  • Same is true for $\vec{x}_l$.
  • $\vec{x}_l$ and $W_l$ are independent of each other.

Compute Variance of $\vec{y}_l$

Note: This is a math-heavy section. I explain the complete derivation step by step, but I assume that you know the variance and expectation formulas and what independence means. The result is what matters most, so you can skip this section if you want.

Our end goal is to have unit standard deviation in all layers of the model. To do so, we first compute the variance of the above equation.

$$Var[\vec{y}_l] = Var[W_l\vec{x}_l] + Var[\vec{b}_l] + 2*Cov(W_l\vec{x}_l, \vec{b}_l)$$

  • $Var[\vec{b}_l] \rightarrow$ is 0, because $\vec{b}_l$ is initialized to a zero vector.
  • $Cov(W_l\vec{x}_l, \vec{b}_l) \rightarrow$ is 0, because $\vec{b}_l$ is always initialized to 0, no matter what value the first argument takes.

Therefore, $$Var[\vec{y}_l] = Var[W_l\vec{x}_l]$$

Let $W_l = (w_{i,j})_{1\leq i\leq d_l, 1\leq j \leq n_l}$, where all $w_{i,j}$ follow the same distribution. Similarly, let $\vec{x}_l = (x_j)_{1\leq j \leq n_l}$, with all $x_j$ following the same distribution.

Plugging all this into the variance equation, we have $$ Var[\vec{y}_l] = Var \left( \begin{bmatrix} \sum\limits_{j=1}^{n_l}w_{1,j}x_j \\ \sum\limits_{j=1}^{n_l}w_{2,j}x_j \\ \vdots \\ \sum\limits_{j=1}^{n_l}w_{d_l,j}x_j \end{bmatrix} \right) $$

To derive the variance of a random vector, check these videos for more details:

  1. Expectations and variance of a random vector - part 2 link
  2. Expectations and variance of a random vector - part 3 link

The variance of $\vec{y}_l$ is defined as $Var[\vec{y}_l] = \mathbb{E}[(\vec{y}_l - \mu)(\vec{y}_l - \mu)^T]$, where $\mu$ is the expectation/mean of $\vec{y}_l$. This gives us a $(d_l,d_l)$ matrix as shown below:

$$ Var[\vec{y}_l] = \mathbb{E} \left( \begin{bmatrix} Var\left(\sum\limits_{j=1}^{n_l}w_{1,j}x_j\right) & Cov\left(\sum\limits_{j=1}^{n_l}w_{1,j}x_j,\sum\limits_{j=1}^{n_l}w_{2,j}x_j\right) & \cdots & Cov\left(\sum\limits_{j=1}^{n_l}w_{1,j}x_j,\sum\limits_{j=1}^{n_l}w_{d_l,j}x_j\right) \\ Cov\left(\sum\limits_{j=1}^{n_l}w_{2,j}x_j,\sum\limits_{j=1}^{n_l}w_{1,j}x_j\right) & Var\left(\sum\limits_{j=1}^{n_l}w_{2,j}x_j\right) & \cdots & \vdots \\ \vdots & & \ddots& Cov\left(\sum\limits_{j=1}^{n_l}w_{d_l-1,j}x_j,\sum\limits_{j=1}^{n_l}w_{d_l,j}x_j\right) \\ Cov\left(\sum\limits_{j=1}^{n_l}w_{d_l,j}x_j,\sum\limits_{j=1}^{n_l}w_{1,j}x_j\right) & \cdots & Cov\left(\sum\limits_{j=1}^{n_l}w_{d_l,j}x_j,\sum\limits_{j=1}^{n_l}w_{d_l-1,j}x_j\right) & Var\left(\sum\limits_{j=1}^{n_l}w_{d_l,j}x_j\right) \end{bmatrix} \right) $$

Now $$Cov\left(\sum\limits_{j=1}^{n_l}w_{k,j}x_j,\sum\limits_{j=1}^{n_l}w_{m,j}x_j\right) = 0\text{ , for } k\neq m$$ This is because of the three assumptions that we defined earlier, i.e.

  • Each element in $W_l$ is mutually independent from other elements and all the elements come from the same distribution.
  • Same is true for $\vec{x}_l$.
  • $\vec{x}_l$ and $W_l$ are independent of each other.

Putting this into the above equation we have $$ Var[\vec{y}_l] = \mathbb{E} \left( \begin{bmatrix} Var\left(\sum\limits_{j=1}^{n_l}w_{1,j}x_j\right) & 0 & \cdots & 0 \\ 0 & Var\left(\sum\limits_{j=1}^{n_l}w_{2,j}x_j\right) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & Var\left(\sum\limits_{j=1}^{n_l}w_{d_l,j}x_j\right) \end{bmatrix} \right) $$

Because all the $w_{i,j}$ and $x_i$ are independent from each other, the variance of the sum is the sum of the variances. Thus: $$ Var[\vec{y}_l] = \mathbb{E} \left( \begin{bmatrix} \sum\limits_{j=1}^{n_l}Var(w_{1,j}x_j) & 0 & \cdots & 0 \\ 0 & \sum\limits_{j=1}^{n_l}Var(w_{2,j}x_j) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & \sum\limits_{j=1}^{n_l}Var(w_{d_l,j}x_j) \end{bmatrix} \right) $$

Now because all the $w_{i,j}$ and $x_i$ follow the same distribution (respectively), all the variances in the sum are equal. Thus: $$ Var[\vec{y}_l] = \mathbb{E} \left( \begin{bmatrix} n_l*Var(w_1x_1) & 0 & \cdots & 0 \\ 0 & n_l*Var(w_2x_2) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & n_l*Var(w_{d_l}x_{d_l}) \end{bmatrix} \right) $$

Next, expectation of a matrix is the expectation of each value (this is also true for a vector). $$ Var[\vec{y}_l] = \left( \begin{bmatrix} \mathbb{E}[n_l*Var(w_1x_1)] & \mathbb{E}[0] & \cdots & \mathbb{E}[0] \\ \mathbb{E}[0] & \mathbb{E}[n_l*Var(w_2x_2)] & \cdots & \vdots \\ \vdots & & \ddots & \mathbb{E}[0] \\ \mathbb{E}[0] & \cdots & \mathbb{E}[0] & \mathbb{E}[n_l*Var(w_{d_l}x_{d_l})] \end{bmatrix} \right) $$

Now,

  • $\mathbb{E}[n_l*Var(w_ix_i)] = n_l*Var(w_ix_i)$, because $n_l$ is a constant and $Var(w_ix_i)$ is also a constant (a fixed number, not a random variable).
  • $\mathbb{E}[0] = 0$ because 0 is a constant.

Putting this in the equation, we have, $$ Var[\vec{y}_l] = \left( \begin{bmatrix} n_l*Var(w_1x_1) & 0 & \cdots & 0 \\ 0 & n_l*Var(w_2x_2) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & n_l*Var(w_{d_l}x_{d_l}) \end{bmatrix} \right) $$

We can apply the same logic to expand $Var[\vec{y}_l]$ as we used to expand the right hand side (as demonstrated in the above equations).

The final equation we get is as follows:

$$ \begin{bmatrix} Var(y_1) & 0 & \cdots & 0 \\ 0 & Var(y_2) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & Var(y_{d_l}) \end{bmatrix} = \begin{bmatrix} n_l*Var(w_1x_1) & 0 & \cdots & 0 \\ 0 & n_l*Var(w_2x_2) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & n_l*Var(w_{d_l}x_{d_l}) \end{bmatrix} $$

Comparing both sides, we get the variance of a single activation $y_l$: $$Var[y_l] = n_l*Var[w_lx_l]$$

Note: $y_l,x_l$ are scalars and $w_l$ is a $(1,n_l)$ row-vector. To keep the notation the same as the original paper, I did not use the vector notation. From this point on it does not matter much whether $w_l$ is a row-vector or not.
For the derivation below, I drop the subscript $l$ (as it would just clutter the equations). Using the standard formula for variance to expand the product term: $$Var[wx] = \mathbb{E}[w^2x^2] - \mathbb{E}[wx]^2 = \left(Cov(w^2,x^2)+\mathbb{E}[w^2]\mathbb{E}[x^2]\right) - \left(Cov(w,x)+\mathbb{E}[w]\mathbb{E}[x]\right)^2$$

Now we use the fact that $w_l$ and $x_l$ are independent of each other. This means

  • $Cov(w^2,x^2)=0$
  • $Cov(w,x)=0$

Also, the weights ($w_l$) have zero mean, as we sample them from a zero-mean normal distribution. I also use the fact that $Var[w_l]=\mathbb{E}[w_l^2]-\mathbb{E}[w_l]^2$, where the second term is zero, so $\mathbb{E}[w_l^2]=Var[w_l]$.

$Var[x_l] = \mathbb{E}[x_l^2] - \mathbb{E}[x_l]^2$. As $x_l$ is the output of an activation function (like ReLU), it has non-zero mean, thus $\mathbb{E}[x_l]\neq 0$, so we cannot replace $\mathbb{E}[x_l^2]$ with $Var[x_l]$. Putting these facts together, $Var[w_lx_l]=\mathbb{E}[w_l^2]\mathbb{E}[x_l^2]=Var[w_l]\mathbb{E}[x_l^2]$, and therefore $$Var[y_l] = n_lVar[w_l]\mathbb{E}[x_l^2]$$
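
A quick numerical sanity check of this identity (my own sketch, not from the post): draw many samples of $y=\sum_j w_jx_j$ with zero-mean weights and non-negative "activations", and compare the empirical variance with $n\,Var[w]\,\mathbb{E}[x^2]$.

n, samples = 512, 10_000
w = torch.randn(samples, n) * 0.1          # zero-mean weights, std 0.1
x = torch.relu(torch.randn(samples, n))    # activations with non-zero mean
y = (w * x).sum(dim=1)                     # one pre-activation per sample

# The two numbers below should be close.
print(y.var().item(), n * w.var().item() * (x ** 2).mean().item())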

Derivation for ReLU activation function

$$ ReLU(x) = \begin{cases} x\text{ if }x\geq 0 \\ 0\text{ if }x<0 \end{cases} = \max(x,0) $$

$w_l$ is initialized from a normal distribution, so half the probability mass lies to the left of the origin and half to the right.

a = torch.randn(1000,1000)
plt.hist(a.flatten().numpy(), bins=100);

Our objective is to solve this equation: $$Var[y_l] = n_lVar[w_l]\mathbb{E}[x_l^2]$$

For this we have to solve $\mathbb{E}[x_l^2]=\mathbb{E}[\max(0,y_{l-1})^2]$.

Note: $x_l$ is the output of activation function applied to a linear layer.

We first show that $y_{l-1}$ is symmetric around 0. The formula for $y_{l-1}$ is $$y_{l-1}=w_{l-1}x_{l-1}+b_{l-1}$$ Taking the expectation of the above gives: $$\mathbb{E}[y_{l-1}] = \mathbb{E}[w_{l-1}]\mathbb{E}[x_{l-1}] + \mathbb{E}[b_{l-1}] = 0$$ since $w_{l-1}$ and $x_{l-1}$ are independent, $\mathbb{E}[w_{l-1}]=0$ and $b_{l-1}=0$.

This shows that $y_{l-1}$ is centered around 0. It is also symmetric around 0: $w_{l-1}$ is symmetric around 0, and this, combined with the fact that $y_{l-1}$ is centered around 0, gives the conclusion that $y_{l-1}$ is symmetric around 0 (i.e. half the probability lies to the left of the origin and half to the right).

The above fact can also be derived more formally as follows (I drop the $l-1$ subscript for cleaner equations): since $w$ is drawn from a zero-mean Gaussian, $w$ and $-w$ have the same distribution; because $w$ is independent of $x$, $wx$ and $(-w)x=-(wx)$ also have the same distribution, so $y=wx+0$ is symmetric around 0.

Now we are ready to compute $\mathbb{E}[x_l^2]$: $$\mathbb{E}[x_l^2] = \mathbb{E}[\max(0,y_{l-1})^2] = \frac{1}{2}\mathbb{E}[y_{l-1}^2] = \frac{1}{2}Var[y_{l-1}]$$ The middle step uses the symmetry of $y_{l-1}$: $\max(0,y_{l-1})^2$ equals $y_{l-1}^2$ on the positive half and 0 on the negative half, and the two halves carry equal probability. The last step uses $\mathbb{E}[y_{l-1}]=0$.

Using this result in the equation of variance we have: $$Var[y_l] = \frac{1}{2}n_lVar[w_l]Var[y_{l-1}]$$

With $L$ layers put together, we have $$Var[y_L] = Var[y_1]\left(\prod\limits_{l=2}^{L}\frac{1}{2}n_lVar[w_l]\right)$$

This product is the key to the initialization design. A proper initialization method should avoid reducing or magnifying the magnitudes of input signals exponentially. So we expect the product to take a proper scalar (e.g., 1). A sufficient condition is:

$$\frac{1}{2}n_lVar[w_l] = 1, \text{ }\forall l$$

This leads to a zero-mean Gaussian whose standard deviation is $\sqrt{2/n_l}$, with the biases initialized to 0.

And this is it. To initialize the weight matrix we can use the following code:

# so n_l in this case is 784 (i.e. preserve standard deviation
# along the forward pass). 
weights = torch.randn(784,10)*math.sqrt(2/784)

General Recipe for any activation function

  1. Find $$\mathbb{E}[x_l^2]=\mathbb{E}[activation\_function(y_{l-1})^2],\ \text{where}\ y_{l-1}=w_{l-1}x_{l-1}+b_{l-1}$$ in terms of $Var[y_{l-1}]$, i.e. find the constant $c$ such that $\mathbb{E}[x_l^2]=c\,Var[y_{l-1}]$. Use the fact that $y_{l-1}$ is symmetric around 0.
  2. Use the above value to find the variance of layer $l$: $$Var[y_l] = n_lVar[w_l]\mathbb{E}[x_l^2]=c\,n_lVar[w_l]\,Var[y_{l-1}]$$
  3. Chain this relation to compute the variance at the last layer: $$Var[y_L] = Var[y_1]\left(\prod\limits_{l=2}^{L}c\,n_lVar[w_l]\right)$$
  4. Choose $Var[w_l]$ so that each factor in the product equals 1, which keeps the variance at the last layer constant (a numerical sketch of this recipe is given below).
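
To make the recipe concrete, here is a small numerical sketch (my own, not from the original post): estimate the constant $c=\mathbb{E}[f(y)^2]/Var[y]$ by Monte Carlo for a given activation $f$, set $Var[w_l]=1/(c\,n_l)$, and repeat the 100-layer experiment from Example 2. The helper names estimate_c and deep_pass are made up for this illustration.

def estimate_c(act_fn, n_samples=1_000_000):
    # Monte Carlo estimate of c = E[f(y)^2] / Var[y] for y ~ N(0, 1).
    # For positively homogeneous activations (ReLU, LeakyReLU) the ratio
    # does not depend on the scale of y, so sampling N(0, 1) is enough.
    y = torch.randn(n_samples)
    return act_fn(y).pow(2).mean().item()

def deep_pass(act_fn, n=512, depth=100):
    # Deep forward pass with weights scaled by sqrt(1 / (c * n)).
    c = estimate_c(act_fn)
    x = torch.randn(n)
    for _ in range(depth):
        w = torch.randn(n, n) * math.sqrt(1 / (c * n))
        x = act_fn(w @ x)
    return stats(x)

print(estimate_c(torch.relu))                      # ~0.5, i.e. Var[w] = 2/n as derived
print(deep_pass(torch.relu))                       # neither explodes nor vanishes
print(deep_pass(lambda y: F.leaky_relu(y, 0.1)))   # matches the LeakyReLU result below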

Derivation for LeakyReLU activation function

$$ LeakyReLU(x) = \begin{cases} x\text{ if }x\geq 0\\ ax\text{ if }x<0 \end{cases} = \max(ax,x), a>0 $$

Step 1: Find $\mathbb{E}[x_l^2]$ in terms of $Var[y_{l-1}]$. Using the symmetry of $y_{l-1}$ around 0: $$\mathbb{E}[x_l^2] = \mathbb{E}[\max(ay_{l-1},y_{l-1})^2] = \frac{a^2}{2}\mathbb{E}[y_{l-1}^2] + \frac{1}{2}\mathbb{E}[y_{l-1}^2] = \frac{1+a^2}{2}Var[y_{l-1}]$$

Step 2: Compute variance of each layer. $$ Var[y_l] = n_lVar[w_l]\left(\frac{1+a^2}{2}Var[y_{l-1}]\right) $$

Step 3: Compute variance at last layer. $$ Var[y_L] = Var[y_1]\left(\prod\limits_{l=2}^{L}\frac{1+a^2}{2}n_lVar[w_l]\right) $$

Step 4: Make the variance constant (=1 in this case would be enough).

$$\frac{1+a^2}{2}n_lVar[w_l] = 1\\Var[w_l] = \frac{2}{(1+a^2)n_l}$$

This leads to a zero-mean Gaussian whose standard deviation is $\sqrt{2/((1+a^2)\,n_l)}$ and the biases are initialized to 0. With $a=0$ this reduces to the ReLU result above.
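
This is also what PyTorch computes when you pass the negative slope to its Kaiming initializer. A quick check (my own sketch, not from the post), assuming a fan-in of 784:

a, fan_in = 0.1, 784
w = torch.empty(1000, 784)   # nn.Linear stores weights as (out_features, in_features)
nn.init.kaiming_normal_(w, a=a, mode='fan_in', nonlinearity='leaky_relu')

# Empirical std of the initialized weights vs. the derived value.
print(w.std().item(), math.sqrt(2 / ((1 + a**2) * fan_in)))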

Conclusion

Weight initialization plays an important role in training a model and can noticeably speed up the training process. Kaiming init is a good default initialization technique for most tasks.