How to do Xavier initialization on TensorFlow


Python Problem Overview


I'm porting my Caffe network over to TensorFlow, but it doesn't seem to have Xavier initialization. I'm using truncated_normal, but this seems to make training a lot harder.

Python Solutions


Solution 1 - Python

Since version 0.8 there is a Xavier initializer, tf.contrib.layers.xavier_initializer; see the TensorFlow documentation for details.

You can use something like this:

W = tf.get_variable("W", shape=[784, 256],
           initializer=tf.contrib.layers.xavier_initializer())

Solution 2 - Python

Just to add another example of how to define a tf.Variable initialized with the method of Xavier Glorot and Yoshua Bengio:

graph = tf.Graph()
with graph.as_default():
    ...
    initializer = tf.contrib.layers.xavier_initializer()
    w1 = tf.Variable(initializer(w1_shape))
    b1 = tf.Variable(initializer(b1_shape))
    ...

This kept me from getting NaN values in my loss function, which were caused by numerical instabilities when using multiple layers with ReLUs.

Solution 3 - Python

In TensorFlow 2.0 and later, tf.contrib.* has been removed and tf.get_variable() is only available under tf.compat.v1. To do Xavier initialization you now have to switch to:

init = tf.initializers.GlorotUniform()
var = tf.Variable(init(shape=shape))
# or as a one-liner, with slightly confusing parentheses
var = tf.Variable(tf.initializers.GlorotUniform()(shape=shape))

Glorot uniform and Xavier uniform are two names for the same initialization type. If you want to know more about how to use initializers in TF 2.0, with or without Keras, refer to the documentation.
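
For intuition, the sampling rule behind Glorot (Xavier) uniform can be sketched in plain Python, without TensorFlow. This is only a minimal illustration, assuming a 2-D weight matrix of shape (fan_in, fan_out); the helper name glorot_uniform_matrix is made up for this example and is not a TensorFlow API.

```python
import math
import random


def glorot_uniform_matrix(fan_in, fan_out, seed=None):
    """Return a fan_in x fan_out nested list drawn from U(-limit, limit).

    limit = sqrt(6 / (fan_in + fan_out)), the bound from Glorot & Bengio (2010).
    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    rng = random.Random(seed)
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)]
            for _ in range(fan_in)]


weights = glorot_uniform_matrix(784, 256, seed=0)
limit = math.sqrt(6.0 / (784 + 256))
# Every sampled weight lies inside the [-limit, limit] interval.
assert all(-limit <= w <= limit for row in weights for w in row)
```

In real code you would let tf.initializers.GlorotUniform do this for you; the sketch only shows where the bound comes from.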

Solution 4 - Python

@Aleph7, Xavier/Glorot initialization depends on the number of incoming connections (fan_in), the number of outgoing connections (fan_out), and the kind of activation function (sigmoid or tanh) of the neuron. See this: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf

So now, to your question. This is how I would do it in TensorFlow:

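import numpy as np
import tensorflow as tf

def init_weights(shape):  # wrapper name is illustrative; the original omitted the def line
    (fan_in, fan_out) = ...  # elided in the original
    low = -4*np.sqrt(6.0/(fan_in + fan_out)) # use 4 for sigmoid, 1 for tanh activation
    high = 4*np.sqrt(6.0/(fan_in + fan_out))
    return tf.Variable(tf.random_uniform(shape, minval=low, maxval=high, dtype=tf.float32))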

Note that we should be sampling from a uniform distribution, and not the normal distribution as suggested in the other answer.
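
The scaling used in the snippet above can be written out as follows. This is just a sketch of the arithmetic; the symbol g is a label introduced here for the activation-dependent factor (4 for sigmoid, 1 for tanh, per the comment in the code), and a is shorthand for the resulting bound.

```latex
W \sim \mathcal{U}[-a,\, a],
\qquad
a = g \sqrt{\frac{6}{\mathrm{fan\_in} + \mathrm{fan\_out}}},
\qquad
\operatorname{Var}(W) = \frac{(2a)^2}{12} = \frac{a^2}{3}
  = \frac{2 g^2}{\mathrm{fan\_in} + \mathrm{fan\_out}}
```

With g = 1 the variance reduces to 2 / (fan_in + fan_out), which is where the constant 6 under the square root comes from.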

Incidentally, I wrote a post yesterday for something different using TensorFlow that happens to also use Xavier initialization. If you're interested, there's also a python notebook with an end-to-end example: https://github.com/delip/blog-stuff/blob/master/tensorflow_ufp.ipynb

Solution 5 - Python

A nice wrapper around TensorFlow called prettytensor gives an implementation in its source code (copied directly from there):

import math
import tensorflow as tf

def xavier_init(n_inputs, n_outputs, uniform=True):
  """Set the parameter initialization using the method described.
  This method is designed to keep the scale of the gradients roughly the same
  in all layers.
  Xavier Glorot and Yoshua Bengio (2010):
           Understanding the difficulty of training deep feedforward neural
           networks. International conference on artificial intelligence and
           statistics.
  Args:
    n_inputs: The number of input nodes into each output.
    n_outputs: The number of output nodes for each input.
    uniform: If true use a uniform distribution, otherwise use a normal.
  Returns:
    An initializer.
  """
  if uniform:
    # 6 was used in the paper.
    init_range = math.sqrt(6.0 / (n_inputs + n_outputs))
    return tf.random_uniform_initializer(-init_range, init_range)
  else:
    # 3 gives us approximately the same limits as above since this repicks
    # values greater than 2 standard deviations from the mean.
    stddev = math.sqrt(3.0 / (n_inputs + n_outputs))
    return tf.truncated_normal_initializer(stddev=stddev)

Solution 6 - Python

TF-contrib has xavier_initializer. Here is an example of how to use it:

import tensorflow as tf
a = tf.get_variable("a", shape=[4, 4], initializer=tf.contrib.layers.xavier_initializer())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))

In addition to this, TensorFlow has other initializers.

Solution 7 - Python

I looked and I couldn't find anything built in. However, according to this:

http://andyljones.tumblr.com/post/110998971763/an-explanation-of-xavier-initialization

Xavier initialization is just sampling from a (usually Gaussian) distribution whose variance is a function of the number of neurons. tf.random_normal can do that for you; you just need to compute the stddev from the number of neurons represented by the weight matrix you're trying to initialize.
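
As a rough illustration of computing that stddev, here is a plain-Python sketch. It assumes the Glorot and Bengio choice Var(W) = 2 / (fan_in + fan_out); other variance rules exist, and both helper names are made up for this example rather than taken from TensorFlow.

```python
import math
import random


def xavier_normal_stddev(fan_in, fan_out):
    """Standard deviation giving Var(W) = 2 / (fan_in + fan_out) (assumed rule)."""
    return math.sqrt(2.0 / (fan_in + fan_out))


def xavier_normal_matrix(fan_in, fan_out, seed=None):
    """Sample a fan_in x fan_out nested list from N(0, stddev^2)."""
    rng = random.Random(seed)
    stddev = xavier_normal_stddev(fan_in, fan_out)
    return [[rng.gauss(0.0, stddev) for _ in range(fan_out)]
            for _ in range(fan_in)]


stddev = xavier_normal_stddev(784, 256)
sample = xavier_normal_matrix(784, 256, seed=0)
flat = [w for row in sample for w in row]
# Empirical second moment of the zero-mean samples.
empirical_var = sum(w * w for w in flat) / len(flat)
print(stddev ** 2, empirical_var)  # the two values should be close
```

In TensorFlow, the same stddev would be passed as the stddev argument of tf.random_normal, as the answer suggests.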

Solution 8 - Python

You can pass the initializer via the kernel_initializer parameter of tf.layers.conv2d, tf.layers.conv2d_transpose, tf.layers.Dense, etc.

For example:

layer = tf.layers.conv2d(
     input, 128, 5, strides=2, padding='SAME',
     kernel_initializer=tf.contrib.layers.xavier_initializer())

https://www.tensorflow.org/api_docs/python/tf/layers/conv2d

https://www.tensorflow.org/api_docs/python/tf/layers/conv2d_transpose

https://www.tensorflow.org/api_docs/python/tf/layers/Dense
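
For convolution kernels, fan_in and fan_out also depend on the kernel size. The sketch below shows one way to compute them, assuming a channels-last kernel shape of (*spatial_dims, in_channels, out_channels). To my understanding this matches how TensorFlow derives the fans, but treat that as an assumption; compute_fans is a hypothetical helper, and the 3 input channels in the example are made up.

```python
def compute_fans(shape):
    """Return (fan_in, fan_out) for a weight tensor shape.

    Assumes a channels-last kernel layout:
    (*spatial_dims, in_channels, out_channels).
    """
    if len(shape) < 2:
        raise ValueError("expected at least a 2-D shape")
    receptive_field = 1
    for dim in shape[:-2]:
        receptive_field *= dim
    fan_in = shape[-2] * receptive_field
    fan_out = shape[-1] * receptive_field
    return fan_in, fan_out


# Dense layer: 784 inputs, 256 outputs.
print(compute_fans((784, 256)))      # (784, 256)
# 5x5 convolution with 128 filters; 3 input channels is an assumption.
print(compute_fans((5, 5, 3, 128)))  # (75, 3200)
```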

Solution 9 - Python

In case you want a one-liner, as you would write with truncated_normal:

W = tf.Variable(tf.truncated_normal((n_prev, n), stddev=0.1))

You can do:

W = tf.Variable(tf.contrib.layers.xavier_initializer()((n_prev, n)))

Solution 10 - Python

Tensorflow 1:

W1 = tf.get_variable("W1", [25, 12288],
    initializer = tf.contrib.layers.xavier_initializer(seed=1))

Tensorflow 2:

W1 = tf.Variable(tf.initializers.GlorotUniform(seed=1)(shape=[25, 12288]),
    name="W1")

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type          Original Author      Original Content on Stackoverflow
Question              Alejandro            View Question on Stackoverflow
Solution 1 - Python   Sung Kim             View Answer on Stackoverflow
Solution 2 - Python   Saullo G. P. Castro  View Answer on Stackoverflow
Solution 3 - Python   y.selivonchyk        View Answer on Stackoverflow
Solution 4 - Python   Delip                View Answer on Stackoverflow
Solution 5 - Python   Hooked               View Answer on Stackoverflow
Solution 6 - Python   Salvador Dali        View Answer on Stackoverflow
Solution 7 - Python   Vince Gatto          View Answer on Stackoverflow
Solution 8 - Python   xilef                View Answer on Stackoverflow
Solution 9 - Python   Tony Power           View Answer on Stackoverflow
Solution 10 - Python  mruanova             View Answer on Stackoverflow