aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2018-09-28 18:41:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 18:45:56 -0700
commitd37f771cc5a208cdc88a50a65f491b3c06c9f262 (patch)
tree1036470d10da26df9f5dcf897a74c78329fe57cc /tensorflow/python/training
parentabd5c32c0fa6451e73b491affdd86d852a74177f (diff)
Move TPU variables to the TPU device in TPUStrategy.
PiperOrigin-RevId: 215027511
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/optimizer.py5
-rw-r--r--tensorflow/python/training/session_manager.py5
2 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index f004f3944a..30b0ed20c8 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -471,7 +471,10 @@ class Optimizer(
if var_list is None:
var_list = tape.watched_variables()
- grads = tape.gradient(loss_value, var_list, grad_loss)
+ # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
+ # to be executed.
+ with ops.control_dependencies([loss_value]):
+ grads = tape.gradient(loss_value, var_list, grad_loss)
return list(zip(grads, var_list))
# Non-callable/Tensor loss case
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index a2e0645ba8..5e4749f306 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.tf_export import tf_export
@@ -182,6 +183,10 @@ class SessionManager(object):
"""
self._target = master
sess = session.Session(self._target, graph=self._graph, config=config)
+ # TODO(jhseu): Delete once tpu.initialize_system() goes away.
+ sess.run(
+ distribution_strategy_context.get_distribution_strategy().initialize()
+ )
if checkpoint_dir and checkpoint_filename_with_path:
raise ValueError("Can not provide both checkpoint_dir and "