aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/main_estimator.py')
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/main_estimator.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
index df25b5066f..a21b573c06 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
@@ -55,8 +55,9 @@ def model_fn(features, labels, mode, params):
learning_rate, momentum=config.momentum)
logits, saved_hidden = model(inputs, training=True)
grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
- train_op = optimizer.apply_gradients(
- zip(grads, model.trainable_variables), global_step=global_step)
+ with tf.control_dependencies(model.get_updates_for(inputs)):
+ train_op = optimizer.apply_gradients(
+ zip(grads, model.trainable_variables), global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
else: