aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Xuechen Li <lxuechen@google.com>2018-07-28 17:22:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-28 17:25:35 -0700
commit2ee2fee11391952ee9e73c66ee99f265d924e15a (patch)
treee5898653a69670475db1077697b25dc599215749 /tensorflow/contrib/eager
parentbf12134843638748c5541aed1dbb8647ebf504fd (diff)
Add batch norm updates for estimator.
PiperOrigin-RevId: 206459313
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/main_estimator.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
index df25b5066f..a21b573c06 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
@@ -55,8 +55,9 @@ def model_fn(features, labels, mode, params):
learning_rate, momentum=config.momentum)
logits, saved_hidden = model(inputs, training=True)
grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
- train_op = optimizer.apply_gradients(
- zip(grads, model.trainable_variables), global_step=global_step)
+ with tf.control_dependencies(model.get_updates_for(inputs)):
+ train_op = optimizer.apply_gradients(
+ zip(grads, model.trainable_variables), global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
else: