diff options
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers.py')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index adbbcea02f..07be8e9990 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -160,9 +160,8 @@ def _fused_batch_norm( they need to be added as a dependency to the `train_op`, example: update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - if update_ops: - updates = tf.group(*update_ops) - total_loss = control_flow_ops.with_dependencies([updates], total_loss) + with tf.control_dependencies(update_ops): + train_op = optimizer.minimize(loss) One can set updates_collections=None to force the updates in place, but that can have speed penalty, especially in distributed settings. @@ -393,9 +392,8 @@ def batch_norm(inputs, they need to be added as a dependency to the `train_op`, example: update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - if update_ops: - updates = tf.group(*update_ops) - total_loss = control_flow_ops.with_dependencies([updates], total_loss) + with tf.control_dependencies(update_ops): + train_op = optimizer.minimize(loss) One can set updates_collections=None to force the updates in place, but that can have speed penalty, especially in distributed settings. |