aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/normalization.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers/normalization.py')
-rw-r--r--tensorflow/python/layers/normalization.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 34b663119e..41846ae0cd 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -364,6 +364,23 @@ def batch_normalization(inputs,
Sergey Ioffe, Christian Szegedy
+ Note: the operations which update the `moving_mean` and `moving_variance`
+ variables will not be added as dependencies of your training operation and so
+ must be run separately. For example:
+ ```
+ extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+ sess.run([train_op, extra_update_ops], ...)
+ ```
+ Alternatively, add the operations as a dependency to your training operation
+ manually, and then just run your training operation as normal:
+ ```
+ extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+ with tf.control_dependencies(extra_update_ops):
+ train_op = optimizer.minimize(loss)
+ ...
+ sess.run([train_op], ...)
+ ```
+
Arguments:
inputs: Tensor input.
axis: Integer, the axis that should be normalized (typically the features