diff options
Diffstat (limited to 'tensorflow/python/layers/normalization.py')
-rw-r--r-- | tensorflow/python/layers/normalization.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 11daf01670..29fb92ccb5 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -664,9 +664,16 @@ def batch_normalization(inputs, Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they - need to be added as a dependency to the `train_op`. For example: + need to be added as a dependency to the `train_op`. Also, be sure to add + any batch_normalization ops before getting the update_ops collection. + Otherwise, update_ops will be empty, and training/inference will not work + properly. For example: ```python + x_norm = tf.layers.batch_normalization(x, training=training) + + # ... + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss) |