What is the difference between tf.truncated_normal and tf.random_normal?

Math Problem Overview


tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None) outputs random values from a normal distribution.

tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None) outputs random values from a truncated normal distribution.

I tried googling 'truncated normal distribution' but didn't understand much.

Math Solutions


Solution 1 - Math

The documentation says it all. For the truncated normal distribution:

> The values are drawn from a normal distribution with specified mean and standard deviation, discarding and re-drawing any samples that are more than two standard deviations from the mean.
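The re-drawing rule quoted from the documentation can be sketched as a simple rejection sampler. This is a minimal illustration in plain Python, not TensorFlow's actual implementation. The function name `truncated_normal_sample` is hypothetical, and the sketch uses the standard library's `random.gauss` in place of TensorFlow's generator.

```python
import random


def truncated_normal_sample(n, mean=0.0, stddev=1.0, seed=None):
    """Hypothetical helper: draw n values from a normal distribution,
    discarding and re-drawing any sample more than 2 stddevs from the mean."""
    rng = random.Random(seed)
    samples = []
    while len(samples) < n:
        x = rng.gauss(mean, stddev)
        # Keep only values inside [mean - 2*stddev, mean + 2*stddev]; otherwise draw again
        if abs(x - mean) <= 2 * stddev:
            samples.append(x)
    return samples


values = truncated_normal_sample(10_000, mean=0.0, stddev=0.5, seed=42)
print(min(values), max(values))  # both stay within [-1.0, 1.0]
```

Only the cutoff distinguishes this from an ordinary normal sampler. Dropping the `if` would give plain normal draws.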

The difference is probably easiest to understand if you plot the two distributions yourself (the %matplotlib magic is there because I use a Jupyter notebook):

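# TensorFlow 1.x API (tf.Session is not available at the top level in TF 2.x)
import tensorflow as tf
import matplotlib.pyplot as plt

# Jupyter-only magic: render plots inline in the notebook
%matplotlib inline

n = 500000
A = tf.truncated_normal((n,))  # re-draws samples beyond mean ± 2 stddev
B = tf.random_normal((n,))     # plain normal samples, no cutoff
with tf.Session() as sess:
    # Evaluate both ops once to get NumPy arrays
    a, b = sess.run([A, B])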

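Now plot both histograms with the same 100 bins over [-4.2, 4.2]:

plt.hist(a, 100, (-4.2, 4.2));  # truncated: no values beyond ±2
plt.hist(b, 100, (-4.2, 4.2));  # normal: tails extend past ±2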
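You can also compare summary numbers instead of reading them off the histograms. This sketch assumes Python's standard `random` module as a stand-in for the two TensorFlow ops, so it runs without TensorFlow; the truncation rule mirrors the documented one. The truncated sample never exceeds ±2. Its standard deviation should come out near 0.88 rather than 1.0, because the tails have been removed.

```python
import random
import statistics

rng = random.Random(0)  # fixed seed for repeatability
n = 100_000

# Stand-in for tf.random_normal: plain standard-normal draws, no cutoff
normal = [rng.gauss(0.0, 1.0) for _ in range(n)]

# Stand-in for tf.truncated_normal: re-draw anything beyond 2 stddevs
truncated = []
while len(truncated) < n:
    x = rng.gauss(0.0, 1.0)
    if abs(x) <= 2.0:
        truncated.append(x)

for name, xs in [("normal", normal), ("truncated", truncated)]:
    print(f"{name:9s} std={statistics.pstdev(xs):.3f}  max|x|={max(abs(v) for v in xs):.2f}")
```

The smaller spread is worth remembering: with truncation, the stddev argument no longer equals the actual standard deviation of the output.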

The point of using a truncated normal is to avoid saturating functions such as the sigmoid. If a neuron's input is too large or too small, the sigmoid's gradient is nearly zero and the neuron stops learning.
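To see the saturation effect, here is a small sketch of the sigmoid and its derivative, sigmoid(x) * (1 - sigmoid(x)), in plain Python. The helper names are illustrative. The gradient peaks at 0.25 at x = 0 and falls below 0.001 by x = 8. This is why initial weights that push pre-activations far from 0 can stall learning.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x):
    # Derivative of the sigmoid: sigmoid(x) * (1 - sigmoid(x))
    s = sigmoid(x)
    return s * (1.0 - s)


for x in [0.0, 2.0, 4.0, 8.0]:
    print(f"x={x:4.1f}  sigmoid={sigmoid(x):.4f}  grad={sigmoid_grad(x):.5f}")
```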

Solution 2 - Math

tf.truncated_normal() draws random numbers from a normal distribution but keeps only values that lie within two standard deviations of the mean. For example, with mean 0 and stddev 0.05, every value falls between -0.1 and 0.1. It's called truncated because you are cutting off the tails of the normal distribution.

tf.random_normal() draws from the same kind of normal distribution without any cutoff, so values can land further from the mean. For example, with mean 0 and stddev 1, about 95% of values fall between -2 and 2, but some lie outside that range.
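To put a number on "some lie outside that range", this sketch draws standard-normal values with Python's `random.gauss`, used here as a stand-in for tf.random_normal. It then compares the share of draws beyond ±2 with the exact tail probability 1 - erf(√2) ≈ 0.0455.

```python
import math
import random

rng = random.Random(1)  # fixed seed for repeatability
n = 100_000
samples = [rng.gauss(0.0, 1.0) for _ in range(n)]

# Share of standard-normal draws that fall outside [-2, 2]
outside = sum(1 for x in samples if abs(x) > 2.0) / n

# Exact tail probability: P(|X| > 2) = 1 - erf(2 / sqrt(2))
exact = 1.0 - math.erf(2.0 / math.sqrt(2.0))

print(f"empirical: {outside:.4f}  exact: {exact:.4f}")
```

A truncated normal with the same parameters would report 0 for both numbers.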

In machine learning practice, you usually want your initial weights to be close to 0.

Solution 3 - Math

The API documentation for tf.truncated_normal() describes the function as:

> Outputs random values from a truncated normal distribution.
>
> The generated values follow a normal distribution with specified mean and standard deviation, except that values whose magnitude is more than 2 standard deviations from the mean are dropped and re-picked.
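Re-picking is rare, because only values beyond 2 standard deviations are dropped. The worked calculation below uses standard normal-distribution facts; these facts are not stated in the TensorFlow docs. It assumes a naive sampler that re-draws independently until a value is accepted, which may not be how TensorFlow actually implements the op.

```latex
p = P(|X - \mu| \le 2\sigma) = \Phi(2) - \Phi(-2) = \operatorname{erf}\left(\sqrt{2}\right) \approx 0.9545

\mathbb{E}[\text{draws per returned value}] = \frac{1}{p} \approx 1.048
```

So under that assumption, fewer than 5 in 100 draws would need to be re-picked.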

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

| Content Type | Original Author | Original Content on Stackoverflow |
| --- | --- | --- |
| Question | Tarun Wadhwa | View Question on Stackoverflow |
| Solution 1 - Math | Salvador Dali | View Answer on Stackoverflow |
| Solution 2 - Math | Kenan | View Answer on Stackoverflow |
| Solution 3 - Math | Martin Svedin | View Answer on Stackoverflow |