Neural Networks: Must-Know Model Training Tricks

Training a neural network to get the best results is still not an easy task. This post lists many important tricks and methods that help in training better neural networks.

abhishek kushwaha
3 min read · Aug 14, 2019

Image source: https://i.stack.imgur.com/gzrsx.png

Having done a lot of research and paper implementations, and having trained more than a dozen different CNN and LSTM models, I decided to share the tricks and methods I learned during the process. Not all of these tricks are my own; many come from papers and other sources. I will write only about those that proved beneficial in my experiments. They should definitely help you improve your training accuracy and validation loss.

This is an ongoing post, meaning I will add more tricks as and when I come across them.

Trick 1 : No Augmentation

After finishing training the network as you regularly do, restart the training with a very low learning rate for a small number of epochs (3–6) WITHOUT data augmentation. This fine-tuning makes a lot of difference and gives you a 2%–3% accuracy improvement (icing on the cake). This was discovered by Prof. Ben Graham (winner of the CIFAR-10 competition).
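A minimal Keras sketch of this fine-tuning phase, assuming a model that has already been trained with augmentation; model, train_ds and val_ds are hypothetical names for your model and your non-augmented training/validation data:

import tensorflow as tf

# model has already been trained the usual way, with data augmentation.
# train_ds is the same training data WITHOUT any augmentation applied.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # very low LR
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# A handful of epochs (3-6) is enough for this icing-on-the-cake step.
model.fit(train_ds, epochs=4, validation_data=val_ds)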

Trick 2 : Balance your data

If your data is not balanced (meaning you do not have roughly equal numbers of samples per class), this can be handled by changing the cost function.

E.g. class 1 = 5000 samples, class 2 = 10000 samples. If y1, y2 are the class labels and p1, p2 the predicted class probabilities, then:

Old cost function = y1*log(p1) + y2*log(p2)
New cost function = y1*(10000/5000)*log(p1) + y2*log(p2)
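A hedged sketch of one way to apply such per-class weights in Keras, using the class_weight argument of model.fit with the same 5000/10000 example (model, x_train, y_train, x_val, y_val are placeholder names):

# Weight each class inversely to its frequency: the rarer class
# (5000 samples) gets weight 10000/5000 = 2.0, the other keeps 1.0.
class_weight = {0: 10000 / 5000,  # class 1, 5000 samples
                1: 1.0}           # class 2, 10000 samples

model.fit(x_train, y_train,
          epochs=20,
          class_weight=class_weight,
          validation_data=(x_val, y_val))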

Trick 3: Focal loss

One more way to handle imbalanced data is focal loss, which performs much better than trick 2. Below is a TensorFlow (1.14.0 and above) implementation with alpha = 0.35 and gamma = 3. (Make sure y_pred is the logit, not the probability.)

import tensorflow as tf

def get_focal_params(y_pred):
    # y_pred here is the predicted probability p = sigmoid(logit).
    epsilon = tf.constant(1e-9)
    gamma = tf.constant(3.)
    y_pred = y_pred + epsilon           # avoid division by zero
    pinv = 1. / y_pred
    pos_weight_f = (pinv - 1) ** gamma  # ((1 - p) / p)^gamma for positive examples
    weight_f = y_pred ** gamma          # p^gamma for negative examples
    return pos_weight_f, weight_f

def custom_loss(y_true, y_pred):
    # y_pred must be the raw logit; the sigmoid is applied here.
    y_pred_prob = tf.keras.backend.sigmoid(y_pred)
    pos_weight_f, weight_f = get_focal_params(y_pred_prob)
    alpha = tf.constant(.35)
    alpha_ = 1 - alpha
    alpha_div = alpha / alpha_
    pos_weight = pos_weight_f * alpha_div
    weight = weight_f * alpha_

    # Folding these weights into the weighted cross-entropy gives
    # alpha * (1 - p)^gamma * -log(p) for positives and
    # (1 - alpha) * p^gamma * -log(1 - p) for negatives.
    l2 = weight * tf.nn.weighted_cross_entropy_with_logits(
        labels=y_true, logits=y_pred, pos_weight=pos_weight)
    return l2
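To use it, pass custom_loss when compiling the model; remember that the last layer must output raw logits (no sigmoid activation), since the loss applies the sigmoid itself. A minimal usage sketch with placeholder model and data names:

# Final layer has no activation, so the y_pred reaching custom_loss is a logit.
model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
model.fit(x_train, y_train, epochs=20, validation_data=(x_val, y_val))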

Trick 4: Cos / modified Tanh learning rate with Warmup

When the weights are initialised randomly at the start of training, the initial gradient descent steps can disturb the network (dead ReLUs). A warmup phase helps the weights adjust in the right direction without dying. Then a cos/Tanh decreasing LR helps reach the minimum adaptively.

Warmup & Cos/modified Tanh learning rate.

A simple implementation of the Tanh LR schedule for TensorFlow 1.14.0:

class CustomScheduleTanh(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, warmup_steps=4000, phase_step=3000, max_lr=.001):
        super(CustomScheduleTanh, self).__init__()
        self.phase_step = phase_step      # step at which the decay phase ends
        self.warmup_steps = warmup_steps  # number of linear warmup steps
        self.max_lr = max_lr              # peak learning rate
        self.lr = 0
        self.step = 0

    def __call__(self, step):
        self.step = step
        # Map the post-warmup step into the range where tanh decays smoothly.
        current_shifted_step = tf.math.minimum(
            tf.math.maximum(step - self.warmup_steps, -3.0) * 5 / (self.phase_step - self.warmup_steps),
            5.)
        arg1 = -tf.math.tanh(current_shifted_step - 2.) + tf.constant(1.)      # tanh decay from 2 down to 0
        arg2 = self.max_lr * step / self.warmup_steps                          # linear warmup
        arg3 = tf.math.maximum(self.max_lr * arg1 / 2., self.max_lr / 10000)   # decay, floored at max_lr/10000

        lr = tf.math.minimum(arg2, arg3)  # warmup applies while it is the smaller of the two
        self.lr = lr
        return lr

Creating the schedule object and using it with an optimizer:

learning_rate_schedule = CustomScheduleTanh(warmup_steps=3000, phase_step=25000, max_lr=.001)
optimizer = tf.keras.optimizers.Adam(learning_rate_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
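The cosine variant mentioned in the heading follows the same pattern: a linear warmup up to max_lr, then a cosine decay down to a small floor. Below is a rough sketch under that assumption (not the exact schedule from the figure), using the same LearningRateSchedule interface:

import numpy as np

class CustomScheduleCos(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, warmup_steps=3000, phase_step=25000, max_lr=.001):
        super(CustomScheduleCos, self).__init__()
        self.warmup_steps = warmup_steps
        self.phase_step = phase_step
        self.max_lr = max_lr

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Linear warmup from 0 to max_lr over warmup_steps.
        warmup_lr = self.max_lr * step / self.warmup_steps
        # Cosine decay from max_lr down to max_lr/10000 over the remaining phase.
        progress = tf.clip_by_value(
            (step - self.warmup_steps) / (self.phase_step - self.warmup_steps), 0., 1.)
        cos_lr = tf.math.maximum(
            self.max_lr * 0.5 * (1. + tf.math.cos(np.pi * progress)),
            self.max_lr / 10000)
        # During warmup the linear ramp is the smaller of the two.
        return tf.math.minimum(warmup_lr, cos_lr)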

Trick 5: Random Restart

When you see that your loss is not improving, stopping the training and restarting it with a slightly different learning rate will help you decrease the validation loss further. (This is not guaranteed, but it has worked well almost every time in my case.)
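A hedged sketch of how a random restart can look with Keras checkpoints (the file name, epoch counts and learning rates are illustrative; model and data names are placeholders):

# First run: keep the best weights seen so far.
ckpt = tf.keras.callbacks.ModelCheckpoint('best_weights.h5',
                                          save_best_only=True,
                                          save_weights_only=True)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss='categorical_crossentropy')
model.fit(x_train, y_train, epochs=30,
          validation_data=(x_val, y_val), callbacks=[ckpt])

# Restart: reload the best weights and continue with a slightly
# different learning rate (e.g. 3e-4 instead of 1e-3).
model.load_weights('best_weights.h5')
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
              loss='categorical_crossentropy')
model.fit(x_train, y_train, epochs=15,
          validation_data=(x_val, y_val), callbacks=[ckpt])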

More are coming. Meanwhile, if you have any tricks of your own, leave them in a comment and I will add them to the post for everyone's benefit.


abhishek kushwaha

A data scientist & deep learning engineer specialising in computer vision and NLP