diff options
Diffstat (limited to 'tensorflow/python/layers/normalization.py')
-rw-r--r-- | tensorflow/python/layers/normalization.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 34b663119e..41846ae0cd 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -364,6 +364,23 @@ def batch_normalization(inputs, Sergey Ioffe, Christian Szegedy + Note: the operations which update the `moving_mean` and `moving_variance` + variables will not be added as dependencies of your training operation and so + must be run separately. For example: + ``` + extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + sess.run([train_op, extra_update_ops], ...) + ``` + Alternatively, add the operations as a dependency to your training operation + manually, and then just run your training operation as normal: + ``` + extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + with tf.control_dependencies(extra_update_ops): + train_op = optimizer.minimize(loss) + ... + sess.run([train_op], ...) + ``` + Arguments: inputs: Tensor input. axis: Integer, the axis that should be normalized (typically the features |