diff options
author | 2018-09-28 18:41:31 -0700 | |
---|---|---|
committer | 2018-09-28 18:45:56 -0700 | |
commit | d37f771cc5a208cdc88a50a65f491b3c06c9f262 (patch) | |
tree | 1036470d10da26df9f5dcf897a74c78329fe57cc /tensorflow/python | |
parent | abd5c32c0fa6451e73b491affdd86d852a74177f (diff) |
Move TPU variables to the TPU device in TPUStrategy.
PiperOrigin-RevId: 215027511
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/eager/backprop.py | 2 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 4 | ||||
-rw-r--r-- | tensorflow/python/estimator/util.py | 8 | ||||
-rw-r--r-- | tensorflow/python/training/optimizer.py | 5 | ||||
-rw-r--r-- | tensorflow/python/training/session_manager.py | 5 |
5 files changed, 16 insertions, 8 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 78f3198011..deac29111f 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -619,7 +619,7 @@ pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
 
 def _handle_or_self(x):
   """If x is ResourceVariable, return its handle, else x."""
-  if isinstance(x, resource_variable_ops.ResourceVariable):
+  if resource_variable_ops.is_resource_variable(x):
     x = x.handle
   return x
 
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 34faf03bb0..e6d82f0db7 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -468,6 +468,10 @@ class Estimator(object):
 
       with ops.Graph().as_default():
         if self._eval_distribution:
+          # We want to create the iterations variable outside the distribution
+          # scope as that is just stored on the host and mainly used to drive
+          # the loop and doesn't need to be a Mirrored/Device variable.
+          training.get_or_create_steps_per_run_variable()
           with self._eval_distribution.scope():
             return _evaluate()
         else:
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index 31e4778e72..fb110c4b7b 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -22,7 +22,6 @@ from __future__ import print_function
 import os
 import time
 
-from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import training
@@ -144,14 +143,11 @@ class StrategyInitFinalizeHook(training.SessionRunHook):
     self._finalize_fn = finalize_fn
 
   def begin(self):
+    # We only create the init ops, but don't run it. We rely on SessionManager
+    # to run it for us.
     self._init_ops = self._initialization_fn()
     self._finalize_ops = self._finalize_fn()
 
-  def after_create_session(self, session, coord):
-    logging.info('Initialize system')
-    session.run(self._init_ops,
-                options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
-
   def end(self, session):
     logging.info('Finalize system.')
     session.run(self._finalize_ops)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index f004f3944a..30b0ed20c8 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -471,7 +471,10 @@ class Optimizer(
 
       if var_list is None:
         var_list = tape.watched_variables()
-      grads = tape.gradient(loss_value, var_list, grad_loss)
+      # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
+      # to be executed.
+      with ops.control_dependencies([loss_value]):
+        grads = tape.gradient(loss_value, var_list, grad_loss)
       return list(zip(grads, var_list))
 
     # Non-callable/Tensor loss case
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index a2e0645ba8..5e4749f306 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -182,6 +183,10 @@ class SessionManager(object):
     """
     self._target = master
     sess = session.Session(self._target, graph=self._graph, config=config)
+    # TODO(jhseu): Delete once tpu.initialize_system() goes away.
+    sess.run(
+        distribution_strategy_context.get_distribution_strategy().initialize()
+    )
 
     if checkpoint_dir and checkpoint_filename_with_path:
       raise ValueError("Can not provide both checkpoint_dir and "