aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2018-09-28 18:41:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 18:45:56 -0700
commitd37f771cc5a208cdc88a50a65f491b3c06c9f262 (patch)
tree1036470d10da26df9f5dcf897a74c78329fe57cc /tensorflow/contrib/tpu
parentabd5c32c0fa6451e73b491affdd86d852a74177f (diff)
Move TPU variables to the TPU device in TPUStrategy.
PiperOrigin-RevId: 215027511
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py2
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py11
2 files changed, 7 insertions, 6 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index 598da7418e..004b1012e5 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -78,7 +78,7 @@ class ReplicatedVariable(object):
if tpu_context is None:
return self._primary_var.handle
- return tpu_context.get_replicated_var_handle(self)
+ return tpu_context.get_replicated_var_handle(self._name, self._vars)
@contextlib.contextmanager
def _assign_dependencies(self):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 883e08bf47..11aaa1c66a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -155,19 +155,20 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._pivot = pivot
self._replicated_vars = {}
- def get_replicated_var_handle(self, var):
+ def get_replicated_var_handle(self, name, vars_):
"""Returns a variable handle for replicated TPU variable 'var'.
This is a method used by an experimental replicated variable implementation
and is not intended as a public API.
Args:
- var: The replicated TPU variable.
+ name: The common name of the variable.
+ vars_: The replicated TPU variables.
Returns:
The handle of the TPU replicated input node.
"""
- handle = self._replicated_vars.get(var)
+ handle = self._replicated_vars.get(name)
if handle is not None:
return handle
@@ -183,10 +184,10 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
saved_context = graph._get_control_flow_context()
graph._set_control_flow_context(self.outer_context)
handle = tpu_ops.tpu_replicated_input(
- [v.handle for v in var._vars], name=var.name + "/handle")
+ [v.handle for v in vars_], name=name + "/handle")
graph._set_control_flow_context(saved_context)
# pylint: enable=protected-access
- self._replicated_vars[var] = handle
+ self._replicated_vars[name] = handle
return handle
def report_unsupported_operations(self):