diff options
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/main_estimator.py')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/revnet/main_estimator.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py index df25b5066f..a21b573c06 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py @@ -55,8 +55,9 @@ def model_fn(features, labels, mode, params): learning_rate, momentum=config.momentum) logits, saved_hidden = model(inputs, training=True) grads, loss = model.compute_gradients(saved_hidden, labels, training=True) - train_op = optimizer.apply_gradients( - zip(grads, model.trainable_variables), global_step=global_step) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) else: |