diff options
author | Jonathan Hseu <jhseu@google.com> | 2018-09-28 18:41:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 18:45:56 -0700 |
commit | d37f771cc5a208cdc88a50a65f491b3c06c9f262 (patch) | |
tree | 1036470d10da26df9f5dcf897a74c78329fe57cc /tensorflow/contrib/tpu | |
parent | abd5c32c0fa6451e73b491affdd86d852a74177f (diff) |
Move TPU variables to the TPU device in TPUStrategy.
PiperOrigin-RevId: 215027511
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu.py | 11 |
2 files changed, 7 insertions, 6 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py index 598da7418e..004b1012e5 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py @@ -78,7 +78,7 @@ class ReplicatedVariable(object): if tpu_context is None: return self._primary_var.handle - return tpu_context.get_replicated_var_handle(self) + return tpu_context.get_replicated_var_handle(self._name, self._vars) @contextlib.contextmanager def _assign_dependencies(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 883e08bf47..11aaa1c66a 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -155,19 +155,20 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._pivot = pivot self._replicated_vars = {} - def get_replicated_var_handle(self, var): + def get_replicated_var_handle(self, name, vars_): """Returns a variable handle for replicated TPU variable 'var'. This is a method used by an experimental replicated variable implementation and is not intended as a public API. Args: - var: The replicated TPU variable. + name: The common name of the variable. + vars_: The replicated TPU variables. Returns: The handle of the TPU replicated input node. """ - handle = self._replicated_vars.get(var) + handle = self._replicated_vars.get(name) if handle is not None: return handle @@ -183,10 +184,10 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): saved_context = graph._get_control_flow_context() graph._set_control_flow_context(self.outer_context) handle = tpu_ops.tpu_replicated_input( - [v.handle for v in var._vars], name=var.name + "/handle") + [v.handle for v in vars_], name=name + "/handle") graph._set_control_flow_context(saved_context) # pylint: enable=protected-access - self._replicated_vars[var] = handle + self._replicated_vars[name] = handle return handle def report_unsupported_operations(self): |