aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/normalization.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers/normalization.py')
-rw-r--r--tensorflow/python/layers/normalization.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 11daf01670..29fb92ccb5 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -664,9 +664,16 @@ def batch_normalization(inputs,
Note: when training, the moving_mean and moving_variance need to be updated.
By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
- need to be added as a dependency to the `train_op`. For example:
+ need to be added as a dependency to the `train_op`. Also, be sure to add
+ any batch_normalization ops before getting the update_ops collection.
+ Otherwise, update_ops will be empty, and training/inference will not work
+ properly. For example:
```python
+ x_norm = tf.layers.batch_normalization(x, training=training)
+
+ # ...
+
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)