Deep Learning Model Initialization in Detail
Kaiming initialization is discussed in detail, and the initialization formula for ReLU is derived step by step.
- Get Data
- Objective
- Kaiming Init (Code)
- Why is Initialization Important?
- Test the above example with Kaiming Init
- Kaiming Init Detailed Derivation
- Conclusion
import gzip, pickle, math
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from fastai2.data.external import download_data # if you do not have fastai2, you can
# download the dataset manually and place
# in the extra folder
import torch
import torch.nn as nn
import torch.nn.functional as F
The MNIST dataset is used for quick experimentation. Everything discussed in this post is independent of the dataset used.
def get_data():
# if you do not have fastai, you can download the dataset directly from the given URL
MNIST_URL = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'
path = download_data(MNIST_URL, 'extra/mnist.pkl.gz')
with gzip.open(path, 'rb') as f:
((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
# By default these are numpy arrays, so convert them to pytorch tensors
return map(torch.tensor, (x_train,y_train,x_valid,y_valid))
def normalize(x, m, s): return (x-m)/s
x_train, y_train, x_valid, y_valid = get_data()
train_mean, train_std = x_train.mean(), x_train.std()
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std) # use training stats for validation set also
x_train.shape, y_train.shape, x_valid.shape, y_valid.shape
x_train.mean(), x_train.std(), x_valid.mean(), x_valid.std()
Now the input comes from a distribution with $mean = 0$ and $std = 1$.
We want the activations in the neural network to have mean=0 and std=1. By activation I mean the output of activation_function(Linear(x, w, b)). A convolution layer can also be represented as a linear layer: start from a weight matrix of 0's and place the filter values at the positions where the convolution filter would be applied.
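To make that equivalence concrete, here is a minimal sketch (with arbitrary example shapes) showing that a convolution can be computed as a single matrix multiplication by unfolding the input into patches:
img = torch.randn(1, 1, 4, 4)                  # one 1-channel 4x4 image
filt = torch.randn(1, 1, 3, 3)                 # one 3x3 filter
patches = F.unfold(img, kernel_size=3)[0]      # (9, 4): each column is a flattened 3x3 patch
out_matmul = filt.view(1, 9) @ patches         # (1, 4): the convolution as a matrix multiply
out_conv = F.conv2d(img, filt).view(1, -1)     # (1, 4): the same values via conv2d
torch.allclose(out_matmul, out_conv)           # True (up to floating point error)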
Kaiming Init (He Init) is the most commonly used initialization. Its implementation is as follows:
w1 = torch.randn(784, 50) * math.sqrt(2/784)
b1 = torch.zeros(50)
w2 = torch.randn(50,1) * math.sqrt(2/50)
b2 = torch.zeros(1)
And that's it. You sample from a standard normal distribution (mean=0, std=1) and then scale the samples by math.sqrt(2/num_inputs), where num_inputs is the fan-in of the layer. This preserves the standard deviation of the activations during forward propagation.
In PyTorch you can implement this using nn.init.kaiming_normal_(...). The implementation details will be discussed at the end.
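As a small sanity check (a sketch, not part of the training code below), the built-in initializer should produce roughly the same scale as the manual version above:
lin = nn.Linear(784, 50)
nn.init.kaiming_normal_(lin.weight, mode='fan_in', nonlinearity='relu')
lin.weight.std(), math.sqrt(2/784)   # both should be close to 0.0505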
Now we have the big picture: initialization is just scaling the weights by some value. Let's get into the details.
Pass the input through linear and relu layers and see the standard deviation of the outputs.
m = 784
nh = 10 #Num hidden units
w1 = torch.randn(m,nh)
b1 = torch.zeros(nh) # bias initialized to 0
w2 = torch.randn(nh,1)
b2 = torch.zeros(1)
def stats(x):
# Utility function to print mean and std
return x.mean(), x.std()
stats(x_valid)
def lin(x, w, b): return x@w + b
t = lin(x_valid, w1, b1)
stats(t)
The mean and std are way off from 0 and 1 respectively. Now we can take the ReLU of this.
def relu(x): return x.clamp_min(0.)
t = relu(t)
stats(t)
This example shows why initialization is important. After just one layer (linear + relu), our mean and standard deviation have already diverged from their initial values of 0 and 1.
Now imagine if we had 100 of these layers. Let's try that now.
To understand why initialization is important in a neural net, we'll focus on its basic operation: matrix multiplication. So let's just take a vector x and a matrix a initialized randomly, then multiply them 100 times (as if we had 100 layers).
x = torch.randn(512)
a = torch.randn(512, 512)
for i in range(100): x = a@x
stats(x)
The problem we get is activation explosion: very soon our activations grow so large that they overflow the floating point range and become nan. We can check when that happens.
x = torch.randn(512)
a = torch.randn(512, 512)
for i in range(100):
x = a@x
if x.mean() != x.mean(): break # nan != nan
i
So it only took 27 multiplications. One possible way to mitigate this problem is to reduce the scale of the matrix a.
x = torch.randn(512)
a = torch.randn(512,512)*0.01
for i in range(100): x = a@x
stats(x)
Now we have the opposite problem, vanishing activations: the outputs shrink towards 0 (which in turn leads to vanishing gradients).
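Analogous to the nan check above, we can also watch for the point where the activations underflow to exactly zero in float32 (a sketch; the exact iteration depends on the random draw):
x = torch.randn(512)
a = torch.randn(512, 512) * 0.01
for i in range(100):
    x = a@x
    if (x == 0).all(): break   # every element has underflowed to zero
i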
We can try Xavier initialization to solve this problem. (Xavier init is the same as Kaiming init, except that the numerator is 1 instead of 2.)
x = torch.randn(512)
a = torch.randn(512,512)*math.sqrt(1/512)
for i in range(100): x = a@x
stats(x)
It works.
We can also try this experiment with a relu added, in which case we would use Kaiming init. The difference between Xavier and Kaiming init is:
- Xavier Init: only the linear layer is considered
- Kaiming Init: the linear + ReLU pair is considered
x = torch.randn(512)
a = torch.randn(512,512)*math.sqrt(2/512)
for i in range(100):
x = a@x
x = x.clamp_min(0.)
stats(x)
And it works: the activations neither exploded nor vanished.
Next, train a simple feedforward network on the MNIST dataset and visualize the distribution of the activation values.
def lin_relu(c_in, c_out):
return nn.Sequential(
nn.Linear(c_in, c_out),
nn.ReLU(inplace=True),
)
class Net(nn.Module):
def __init__(self):
super().__init__()
# Create a linear-relu model
self.model = nn.Sequential(
*lin_relu(784, 1000),
*lin_relu(1000,1000),
*lin_relu(1000,1000),
*lin_relu(1000,1000),
*lin_relu(1000,1000),
nn.Linear(1000,10), # No need of softmax as it will be in the loss function
)
self.initialize_model()
def forward(self, x):
return self.model(x)
def initialize_model(self):
# Before Xavier init, people used to sample weights from
# a uniform distribution
for layer in self.model:
if isinstance(layer, nn.Linear):
value = 1./math.sqrt(layer.weight.shape[1])
layer.weight.data.uniform_(-value, value)
layer.bias.data.fill_(0)
net = Net()
net.model
bs = 128
device = torch.device('cuda')
net = net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
dataset = torch.utils.data.TensorDataset(x_train, y_train)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
for epoch in range(70):
running_loss = .0
for i, data in enumerate(dataloader,0):
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'{epoch+1}: {running_loss/len(dataloader)}')
Now visualize the activation values using the validation images.
torch.save(net.state_dict(), 'uniform.pt')
actvs = {}
net = net.to('cpu')
images = x_valid[:10000]
out = images.clone()
for i in range(len(net.model)):
if i in (1,3,5,7,9):
out = net.model[i](out)
actvs[i] = out.squeeze(0).detach().mean(axis=0).numpy()
else:
out = net.model[i](out)
Visualize the activations using a kdeplot. It is similar to a histogram, but gives a smoother view of the data. Alternatively, a histogram can be plotted to see the same effect.
fig, ax = plt.subplots(figsize=(15,5))
ax.set_xlim(-1, 1)
for k,v in actvs.items():
sns.kdeplot(v, ax=ax, label=f'Layer {k}')
Now repeat the experiment, initializing the same network with Kaiming init instead.
def lin_relu(c_in, c_out):
return nn.Sequential(
nn.Linear(c_in, c_out),
nn.ReLU(inplace=True),
)
class Net(nn.Module):
def __init__(self):
super().__init__()
# Create a linear-relu model
self.model = nn.Sequential(
*lin_relu(784, 1000),
*lin_relu(1000,1000),
*lin_relu(1000,1000),
*lin_relu(1000,1000),
*lin_relu(1000,1000),
nn.Linear(1000,10), # No need of softmax as it will be in the loss function
)
self.initialize_model()
def forward(self, x):
return self.model(x)
def initialize_model(self):
for layer in self.model:
if isinstance(layer, nn.Linear):
nn.init.kaiming_normal_(layer.weight.data, mode='fan_out')
layer.bias.data.fill_(0)
net = Net()
bs = 128
device = torch.device('cuda')
net = net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
dataset = torch.utils.data.TensorDataset(x_train, y_train)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
for epoch in range(70):
running_loss = .0
for i, data in enumerate(dataloader,0):
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'{epoch+1}: {running_loss/len(dataloader)}')
torch.save(net.state_dict(), 'kaiming.pt')
actvs = {}
net = net.to('cpu')
images = x_valid[:10000]
out = images.clone()
for i in range(len(net.model)):
if i in (1,3,5,7,9):
out = net.model[i](out)
actvs[i] = out.squeeze(0).detach().mean(axis=0).numpy()
else:
out = net.model[i](out)
fig, ax = plt.subplots(figsize=(15,5))
ax.set_xlim(-1, 1)
for k,v in actvs.items():
sns.distplot(v, ax=ax, label=f'Layer {k}', hist=False)
We can see a clear improvement over the previous figure: the activations of the later layers are larger than those observed before, instead of shrinking towards zero layer by layer.
I will derive the initialization formula for the forward pass only. In practice we set the initialization so as to maintain the mean and std during the forward pass; an analogous derivation exists for the backward pass.
Consider a layer $l$ $$\vec{y}_l = W_l\vec{x}_l+\vec{b}_l$$
- $\vec{x}_l \rightarrow$ is a $(n_l,1)$ vector that represents the activations of the previous layer $\vec{y}_{l-1}$ that were passed through an activation function $f$ i.e. $$\vec{x}_l = f(\vec{y}_{l-1})$$
- $W_l \rightarrow$ is a $(d_l,n_l)$ weight matrix from layer $l-1$ to layer $l$, where $d_l$ is the number of output units (or filters, in the convolutional case)
- $\vec{b}_l \rightarrow$ is a $(d_l,1)$ vector of biases of layer $l$ (initialized to 0)
- $\vec{y}_l \rightarrow$ is a $(d_l,1)$ vector of activations of layer l before they go through the activation function.
A few assumptions are made:
- Each element in $W_l$ is mutually independent from other elements and all the elements come from the same distribution.
- Same is true for $\vec{x}_l$.
- $\vec{x}_l$ and $W_l$ are independent of each other.
Our end goal is to have unit standard deviation in all layers of the model. To do so, we first compute the variance of the above equation.
$$Var[\vec{y}_l] = Var[W_l\vec{x}_l] + Var[\vec{b}_l] + 2*Cov(W_l\vec{x}_l, \vec{b}_l)$$
- $Var[\vec{b}_l] \rightarrow$ is 0, because $\vec{b}_l$ is initialized to a zero vector.
- $Cov(W_l\vec{x}_l, \vec{b}_l) \rightarrow$ is 0, because $\vec{b}_l$ is always initialized to 0, no matter what value the first argument takes.
Therefore, $$Var[\vec{y}_l] = Var[W_l\vec{x}_l]$$
Let $W_l = (w_{i,j})_{1\leq i\leq d_l, 1\leq j \leq n_l}$, where all $w_{i,j}$ follow the same distribution. Similarly, let $\vec{x}_l = (x_j)_{1\leq j \leq n_l}$, with all $x_j$ following the same distribution.
Plugging all this into the variance equation, we have $$ Var[\vec{y}_l] = Var \left( \begin{bmatrix} \sum\limits_{j=1}^{n_l}w_{1,j}x_j \\ \sum\limits_{j=1}^{n_l}w_{2,j}x_j \\ \vdots \\ \sum\limits_{j=1}^{n_l}w_{d_l,j}x_j \end{bmatrix} \right) $$
To derive the variance of a random vector, we start from its definition.
The variance of $\vec{y}_l$ is defined as $Var[\vec{y}_l] = \mathbb{E}[(\vec{y}_l - \mu)(\vec{y}_l - \mu)^T]$, where $\mu$ is the expectation/mean of $\vec{y}_l$. This gives a $(d_l,d_l)$ matrix whose diagonal entries are the variances of the components of $\vec{y}_l$ and whose off-diagonal entries are their pairwise covariances.
Now $$Cov\left(\sum\limits_{j=1}^{n_l}w_{k,j}x_j,\ \sum\limits_{j=1}^{n_l}w_{m,j}x_j\right) = 0\text{ , for } k\neq m$$ i.e. all the off-diagonal entries are zero. This follows from the three assumptions we defined earlier:
- Each element in $W_l$ is mutually independent from other elements and all the elements come from the same distribution.
- Same is true for $\vec{x}_l$.
- $\vec{x}_l$ and $W_l$ are independent of each other.
Putting this into the above equation we have $$ Var[\vec{y}_l] = \mathbb{E} \left( \begin{bmatrix} Var\left(\sum\limits_{j=1}^{n_l}w_{1,j}x_j\right) & 0 & \cdots & 0 \\ 0 & Var\left(\sum\limits_{j=1}^{n_l}w_{2,j}x_j\right) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & Var\left(\sum\limits_{j=1}^{n_l}w_{d_l,j}x_j\right) \end{bmatrix} \right) $$
Because all the $w_{i,j}$ and $x_i$ are independent from each other, the variance of the sum is the sum of the variances. Thus: $$ Var[\vec{y}_l] = \mathbb{E} \left( \begin{bmatrix} \sum\limits_{j=1}^{n_l}Var(w_{1,j}x_j) & 0 & \cdots & 0 \\ 0 & \sum\limits_{j=1}^{n_l}Var(w_{2,j}x_j) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & \sum\limits_{j=1}^{n_l}Var(w_{d_l,j}x_j) \end{bmatrix} \right) $$
Now because all the $w_{i,j}$ and $x_i$ follow the same distribution (respectively), all the variances in the sum are equal. Thus: $$ Var[\vec{y}_l] = \mathbb{E} \left( \begin{bmatrix} n_l*Var(w_1x_1) & 0 & \cdots & 0 \\ 0 & n_l*Var(w_2x_2) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & n_l*Var(w_{d_l}x_{d_l}) \end{bmatrix} \right) $$
Next, expectation of a matrix is the expectation of each value (this is also true for a vector). $$ Var[\vec{y}_l] = \left( \begin{bmatrix} \mathbb{E}[n_l*Var(w_1x_1)] & \mathbb{E}[0] & \cdots & \mathbb{E}[0] \\ \mathbb{E}[0] & \mathbb{E}[n_l*Var(w_2x_2)] & \cdots & \vdots \\ \vdots & & \ddots & \mathbb{E}[0] \\ \mathbb{E}[0] & \cdots & \mathbb{E}[0] & \mathbb{E}[n_l*Var(w_{d_l}x_{d_l})] \end{bmatrix} \right) $$
Now,
- $\mathbb{E}[n_l*Var(w_ix_i)] = n_l*Var(w_ix_i)$, because $n_l$ is a constant and $Var(w_ix_i)$ is also a constant (the variance of a fixed distribution is a single number, so its expectation is itself).
- $\mathbb{E}[0] = 0$ because 0 is a constant.
Putting this in the equation, we have, $$ Var[\vec{y}_l] = \left( \begin{bmatrix} n_l*Var(w_1x_1) & 0 & \cdots & 0 \\ 0 & n_l*Var(w_2x_2) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & n_l*Var(w_{d_l}x_{d_l}) \end{bmatrix} \right) $$
The left hand side $Var[\vec{y}_l]$ can be expanded with the same logic we used for the right hand side (as demonstrated in the equations above).
The final equation we get is as follows:
$$ \begin{bmatrix} Var(y_1) & 0 & \cdots & 0 \\ 0 & Var(y_2) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & Var(y_{d_l}) \end{bmatrix} = \begin{bmatrix} n_l*Var(w_1x_1) & 0 & \cdots & 0 \\ 0 & n_l*Var(w_2x_2) & \cdots & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & n_l*Var(w_{d_l}x_{d_l}) \end{bmatrix} $$ Comparing both sides (and remembering that all the entries along each diagonal follow the same distribution), we get the variance of a single activation of layer $l$: $$Var[y_l] = n_l*Var[w_lx_l]$$
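As a quick numerical sanity check (a sketch, not part of the derivation), the variance of a sum of $n$ i.i.d. products $w_jx_j$ should indeed be $n$ times the variance of a single product:
# Sketch: Var[sum_j w_j x_j] should be close to n * Var[w x] when the products are i.i.d.
n, trials = 784, 10_000
w = torch.randn(trials, n)
x = torch.randn(trials, n)
y = (w * x).sum(dim=1)        # one realization of y per trial
y.var(), n * (w * x).var()    # both should be close to 784 (since Var[w x] = 1 here)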
Now we use the fact that $w_l$ and $x_l$ are independent of each other. This means
- $Cov(w^2,x^2)=0$
- $Cov(w,x)=0$
Also, the weights ($w_l$) have zero mean, as we sample them from a zero-mean normal distribution, so $Var[w_l]=\mathbb{E}[w_l^2]-\mathbb{E}[w_l]^2=\mathbb{E}[w_l^2]$. Using these facts, $$Var[w_lx_l] = \mathbb{E}[w_l^2x_l^2]-\mathbb{E}[w_lx_l]^2 = \mathbb{E}[w_l^2]\mathbb{E}[x_l^2]-\mathbb{E}[w_l]^2\mathbb{E}[x_l]^2 = Var[w_l]\mathbb{E}[x_l^2]$$ and substituting into $Var[y_l] = n_lVar[w_lx_l]$ gives $Var[y_l] = n_lVar[w_l]\mathbb{E}[x_l^2]$.
Note that $Var[x_l] = \mathbb{E}[x_l^2] - \mathbb{E}[x_l]^2$, and since $x_l$ is the output of an activation function (like ReLU), it has non-zero mean, i.e. $\mathbb{E}[x_l]\neq 0$, so we cannot simply replace $\mathbb{E}[x_l^2]$ by $Var[x_l]$.
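A quick numerical check of this step (a sketch, using the relu helper defined earlier): with zero-mean $w$ independent of $x$, $Var[wx] = Var[w]\,\mathbb{E}[x^2]$ even though $x$ has non-zero mean:
# Sketch: Var[w x] = Var[w] * E[x^2] when E[w] = 0 and w is independent of x.
w = torch.randn(1_000_000)
x = relu(torch.randn(1_000_000))           # x has non-zero mean
(w * x).var(), w.var() * (x ** 2).mean()   # both should be close to 0.5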
$w_l$ is initialized from a normal distribution, so half of its probability mass lies to the left of the origin and half to the right.
a = torch.randn(1000,1000)
plt.hist(a.flatten().numpy(), bins=100);
Our objective is to solve this equation: $$Var[y_l] = n_lVar[w_l]\mathbb{E}[x_l^2]$$
For this we have to compute $\mathbb{E}[x_l^2]=\mathbb{E}[\max(0,y_{l-1})^2]$.
We first show that $y_{l-1}$ is symmetric around 0. The formula for $y_{l-1}$ is $$y_{l-1}=w_{l-1}x_{l-1}+b_{l-1}$$ Taking the expectation of the above gives: $$\mathbb{E}[y_{l-1}] = \mathbb{E}[w_{l-1}x_{l-1}]+\mathbb{E}[b_{l-1}] = \mathbb{E}[w_{l-1}]\mathbb{E}[x_{l-1}]+0 = 0$$ since $w_{l-1}$ has zero mean and is independent of $x_{l-1}$, and $b_{l-1}$ is initialized to 0.
This shows that $y_{l-1}$ is centered around 0. It is also symmetric around 0: $w_{l-1}$ is symmetric around 0, and combined with the fact that $y_{l-1}$ is centered around 0, this gives the conclusion that $y_{l-1}$ is symmetric around 0 (i.e. half of its probability mass lies to the left of the origin and half to the right).
The above fact can also be derived mathematically as follows (I drop the $l-1$ subscript for cleaner equations). With $b=0$, for any $t$: $$P(y \leq -t) = P(wx \leq -t) = P((-w)x \geq t) = P(wx \geq t) = P(y \geq t)$$ because $-w$ has the same distribution as $w$ (a zero-mean normal is symmetric) and is independent of $x$. Hence the distribution of $y$ is symmetric around 0.
Now we are ready to compute $\mathbb{E}[x_l^2]$. Since $y_{l-1}$ is symmetric around 0, $\max(0,y_{l-1})^2$ equals $y_{l-1}^2$ on half of the probability mass and 0 on the other half, so: $$\mathbb{E}[x_l^2]=\mathbb{E}[\max(0,y_{l-1})^2]=\frac{1}{2}\mathbb{E}[y_{l-1}^2]=\frac{1}{2}Var[y_{l-1}]$$ where the last equality uses $\mathbb{E}[y_{l-1}]=0$.
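A small numerical check of this identity (a sketch, using the relu helper defined earlier):
# Sketch: for zero-mean symmetric y, E[relu(y)^2] should be close to Var[y] / 2.
y = torch.randn(1_000_000) * 3.0       # zero mean, symmetric, std = 3
(relu(y) ** 2).mean(), y.var() / 2     # both should be close to 4.5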
Using this result in the equation of variance we have: $$Var[y_l] = \frac{1}{2}n_lVar[w_l]Var[y_{l-1}]$$
With $L$ layers put together, we have $$Var[y_L] = Var[y_1]\left(\prod\limits_{l=2}^{L}\frac{1}{2}n_lVar[w_l]\right)$$
This product is the key to the initialization design. A proper initialization method should avoid reducing or magnifying the magnitudes of input signals exponentially. So we expect the product to take a proper scalar (e.g., 1). A sufficient condition is:
$$\frac{1}{2}n_lVar[w_l] = 1, \text{ }\forall l$$
This leads to a zero-mean Gaussian whose standard deviation is $\sqrt{2/n_l}$, with the biases initialized to 0.
And this is it. To initialize the weight matrix we can use the following code:
# Here the fan-in n_l is 784, so we scale by sqrt(2/784) to preserve
# the standard deviation along the forward pass.
weights = torch.randn(784,10)*math.sqrt(2/784)
General Recipe for any activation function
- Find $$\mathbb{E}[x_l^2]=\mathbb{E}[f(y_{l-1})^2],\ \text{where}\ y_{l-1}=w_{l-1}x_{l-1}+b_{l-1}\ \text{and}\ f\ \text{is the activation function}$$ in terms of $Var[y_{l-1}]$. Use the fact that $y_{l-1}$ is symmetric around 0.
- Use the above value to find the variance of layer $l$: $$Var[y_l] = n_lVar[w_l]\mathbb{E}[x_l^2]$$
- Compute the variance at the last layer by applying the previous step repeatedly: $$Var[y_L] = Var[y_1]\left(\prod\limits_{l=2}^{L}c\,n_lVar[w_l]\right)$$ where $c$ is the constant from the first step such that $\mathbb{E}[x_l^2]=c\,Var[y_{l-1}]$.
- Find a way to make the variance at the last layer constant.
As an example, apply the recipe to Leaky ReLU, $f(y)=\max(0,y)+a\min(0,y)$, where $a$ is the negative slope (this is the nonlinearity assumed by PyTorch's kaiming_normal_). Step 1: Find $\mathbb{E}[x_l^2]$ in terms of $Var[y_{l-1}]$. By the same symmetry argument as before: $$\mathbb{E}[x_l^2]=\frac{1+a^2}{2}Var[y_{l-1}]$$
Step 2: Compute variance of each layer. $$ Var[y_l] = n_lVar[w_l]\left(\frac{1+a^2}{2}Var[y_{l-1}]\right) $$
Step 3: Compute variance at last layer. $$ Var[y_L] = Var[y_1]\left(\prod\limits_{l=2}^{L}\frac{1+a^2}{2}n_lVar[w_l]\right) $$
Step 4: Make the variance constant (=1 in this case would be enough).
$$\frac{1+a^2}{2}n_lVar[w_l] = 1\\Var[w_l] = \frac{2}{(1+a^2)n_l}$$
This leads to a zero-mean Gaussian whose standard deviation is $\sqrt{2/((1+a^2)n_l)}$, with the biases initialized to 0. For $a=0$ (plain ReLU) this reduces to the Kaiming init derived above.
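This is what PyTorch's nn.init.kaiming_normal_ implements. A small sketch checking that the produced scale matches the formula above (the slope 0.1 and the (50, 784) shape are just arbitrary examples):
# Sketch: kaiming_normal_ with leaky-ReLU slope a uses std = sqrt(2 / ((1 + a^2) * fan_in)).
a = 0.1                                        # arbitrary example slope
w = torch.empty(50, 784)                       # fan_in = 784 for a (50, 784) weight matrix
nn.init.kaiming_normal_(w, a=a, mode='fan_in', nonlinearity='leaky_relu')
w.std(), math.sqrt(2 / ((1 + a**2) * 784))     # both should be close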