aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2018-09-28 18:41:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 18:45:56 -0700
commitd37f771cc5a208cdc88a50a65f491b3c06c9f262 (patch)
tree1036470d10da26df9f5dcf897a74c78329fe57cc /tensorflow/python/estimator
parentabd5c32c0fa6451e73b491affdd86d852a74177f (diff)
Move TPU variables to the TPU device in TPUStrategy.
PiperOrigin-RevId: 215027511
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/estimator.py4
-rw-r--r--tensorflow/python/estimator/util.py8
2 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 34faf03bb0..e6d82f0db7 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -468,6 +468,10 @@ class Estimator(object):
with ops.Graph().as_default():
if self._eval_distribution:
+ # We want to create the iterations variable outside the distribution
+ # scope as that is just stored on the host and mainly used to drive
+ # the loop and doesn't need to be a Mirrored/Device variable.
+ training.get_or_create_steps_per_run_variable()
with self._eval_distribution.scope():
return _evaluate()
else:
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index 31e4778e72..fb110c4b7b 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import os
import time
-from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training
@@ -144,14 +143,11 @@ class StrategyInitFinalizeHook(training.SessionRunHook):
self._finalize_fn = finalize_fn
def begin(self):
+ # We only create the init ops, but don't run it. We rely on SessionManager
+ # to run it for us.
self._init_ops = self._initialization_fn()
self._finalize_ops = self._finalize_fn()
- def after_create_session(self, session, coord):
- logging.info('Initialize system')
- session.run(self._init_ops,
- options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
-
def end(self, session):
logging.info('Finalize system.')
session.run(self._finalize_ops)