diff options
author | 2018-07-28 17:22:03 -0700 | |
---|---|---|
committer | 2018-07-28 17:25:35 -0700 | |
commit | 2ee2fee11391952ee9e73c66ee99f265d924e15a (patch) | |
tree | e5898653a69670475db1077697b25dc599215749 /tensorflow/contrib/eager | |
parent | bf12134843638748c5541aed1dbb8647ebf504fd (diff) |
Add batch norm updates for estimator.
PiperOrigin-RevId: 206459313
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/revnet/main_estimator.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py index df25b5066f..a21b573c06 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py @@ -55,8 +55,9 @@ def model_fn(features, labels, mode, params): learning_rate, momentum=config.momentum) logits, saved_hidden = model(inputs, training=True) grads, loss = model.compute_gradients(saved_hidden, labels, training=True) - train_op = optimizer.apply_gradients( - zip(grads, model.trainable_variables), global_step=global_step) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) else: |