Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers.py')
 tensorflow/contrib/layers/python/layers/layers.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index adbbcea02f..07be8e9990 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -160,9 +160,8 @@ def _fused_batch_norm(
   they need to be added as a dependency to the `train_op`, example:
 
     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-    if update_ops:
-      updates = tf.group(*update_ops)
-      total_loss = control_flow_ops.with_dependencies([updates], total_loss)
+    with tf.control_dependencies(update_ops):
+      train_op = optimizer.minimize(loss)
 
   One can set updates_collections=None to force the updates in place, but that
   can have speed penalty, especially in distributed settings.
@@ -393,9 +392,8 @@ def batch_norm(inputs,
   they need to be added as a dependency to the `train_op`, example:
 
     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-    if update_ops:
-      updates = tf.group(*update_ops)
-      total_loss = control_flow_ops.with_dependencies([updates], total_loss)
+    with tf.control_dependencies(update_ops):
+      train_op = optimizer.minimize(loss)
 
   One can set updates_collections=None to force the updates in place, but that
   can have speed penalty, especially in distributed settings.
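
Note: in a full TF 1.x-style training setup, the pattern the updated
docstrings recommend looks like the sketch below. This is illustrative
context, not part of the patch; the network shape, loss, and optimizer
choices are placeholders.

    import tensorflow as tf

    inputs = tf.placeholder(tf.float32, [None, 64])
    labels = tf.placeholder(tf.float32, [None, 10])

    # With the default updates_collections, batch_norm registers its
    # moving_mean/moving_variance update ops in tf.GraphKeys.UPDATE_OPS
    # rather than running them inline.
    net = tf.contrib.layers.batch_norm(inputs, is_training=True)
    logits = tf.contrib.layers.fully_connected(net, 10, activation_fn=None)
    loss = tf.losses.softmax_cross_entropy(labels, logits)

    # Making train_op depend on the collected update ops guarantees the
    # moving statistics are refreshed on every training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss)

Compared with the removed snippet, this form attaches the updates to
train_op itself instead of to the loss tensor, so the dependency holds
regardless of how the loss is later consumed.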