Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu.py')
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu.py | 60
1 file changed, 43 insertions(+), 17 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 6a64893d9a..06885bbc25 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -151,6 +151,41 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     self._name = name
     self._unsupported_ops = []
     self._pivot = pivot
+    self._replicated_vars = {}
+
+  def get_replicated_var_handle(self, var):
+    """Returns a variable handle for replicated TPU variable 'var'.
+
+    This is a method used by an experimental replicated variable
+    implementation and is not intended as a public API.
+
+    Args:
+      var: The replicated TPU variable.
+
+    Returns:
+      The handle of the TPU replicated input node.
+    """
+    handle = self._replicated_vars.get(var)
+    if handle is not None:
+      return handle
+
+    # Builds a TPUReplicatedInput node for the variable, if one does not
+    # already exist. The TPUReplicatedInput node must belong to the enclosing
+    # control-flow scope of the TPUReplicateContext.
+    # TODO(phawkins): consider changing the contract of the TPU encapsulation
+    # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
+    # instead.
+
+    # pylint: disable=protected-access
+    graph = ops.get_default_graph()
+    saved_context = graph._get_control_flow_context()
+    graph._set_control_flow_context(self.outer_context)
+    handle = tpu_ops.tpu_replicated_input(
+        [v.handle for v in var._vars], name=var.name + "/handle")
+    graph._set_control_flow_context(saved_context)
+    # pylint: enable=protected-access
+    self._replicated_vars[var] = handle
+    return handle
 
   def report_unsupported_operations(self):
     if self._unsupported_ops:
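The new method above does two separable things: it memoizes one TPUReplicatedInput node per variable in self._replicated_vars, and it temporarily swaps the graph's control-flow context so that node is created in the scope enclosing the TPUReplicateContext. Below is a minimal, self-contained sketch of that save/set/restore-plus-cache pattern; FakeGraph, FakeReplicateContext, and the tuple standing in for the op are illustrative stand-ins, not TensorFlow APIs.

class FakeGraph(object):
  """Stand-in for tf.Graph's private control-flow-context accessors."""

  def __init__(self):
    self._context = None

  def _get_control_flow_context(self):
    return self._context

  def _set_control_flow_context(self, ctx):
    self._context = ctx


class FakeReplicateContext(object):
  """Caches one combined-input node per replicated variable."""

  def __init__(self, graph, outer_context):
    self._graph = graph
    self._outer_context = outer_context
    self._replicated_vars = {}

  def get_replicated_var_handle(self, var_name, per_replica_handles):
    handle = self._replicated_vars.get(var_name)
    if handle is not None:
      return handle  # cache hit: at most one node per variable
    # Step out to the enclosing scope, build the node, then restore.
    saved = self._graph._get_control_flow_context()
    self._graph._set_control_flow_context(self._outer_context)
    # The real code calls tpu_ops.tpu_replicated_input(...) here; a plain
    # tuple stands in for the op so this sketch runs without TensorFlow.
    handle = ("TPUReplicatedInput", tuple(per_replica_handles))
    self._graph._set_control_flow_context(saved)
    self._replicated_vars[var_name] = handle
    return handle


ctx = FakeReplicateContext(FakeGraph(), outer_context=None)
h1 = ctx.get_replicated_var_handle("v", ["v/replica_0", "v/replica_1"])
h2 = ctx.get_replicated_var_handle("v", ["v/replica_0", "v/replica_1"])
assert h1 is h2  # the second lookup reuses the memoized handle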
@@ -598,23 +633,14 @@ def split_compile_and_replicate(computation,
     with tpu_function.tpu_shard_context(
         num_replicas), ops.control_dependencies([metadata]):
 
-      # For backward compatibility reasons, we tag replicated inputs with the
-      # _tpu_replicated_input attribute. This does nothing and exists only for
-      # backward compatibility.
-      # TODO(phawkins): delete the attr_scope after 6/28/2018.
-      # pylint: disable=protected-access
-      with graph._attr_scope({
-          "_tpu_replicated_input": attr_value_pb2.AttrValue(b=True)
-      }):
-        # Add identity ops so even unused inputs are "consumed" by the
-        # computation. This is to avoid orphaned TPUReplicatedInput nodes.
-        # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
-        # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
-        computation_inputs = [
-            array_ops.identity(x, name="replicated_input_{}".format(i))
-            for i, x in enumerate(computation_inputs)
-        ]
-      # pylint: enable=protected-access
+      # Add identity ops so even unused inputs are "consumed" by the
+      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
+      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
+      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
+      computation_inputs = [
+          array_ops.identity(x, name="replicated_input_{}".format(i))
+          for i, x in enumerate(computation_inputs)
+      ]
 
       # If there is an infeed queue, adds the dequeued values to the
       # computation's inputs.
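The surviving lines keep the identity-wrapping idiom: passing every input through array_ops.identity gives even unused inputs at least one consumer, so no TPUReplicatedInput node is left orphaned in the graph. A rough illustration under the TF 1.x graph-mode API follows; the placeholders stand in for the real replicated inputs, and the tpu_shard_context and control-dependency wiring from the real code are omitted.

import tensorflow as tf  # assumes the TF 1.x graph-mode API

graph = tf.Graph()
with graph.as_default():
  # Placeholders stand in for the TPUReplicatedInput outputs.
  raw_inputs = [tf.placeholder(tf.float32, name="in_%d" % i) for i in range(3)]
  # Wrap each input in an identity op; even inputs the computation never
  # reads now have a consumer in the graph.
  computation_inputs = [
      tf.identity(x, name="replicated_input_{}".format(i))
      for i, x in enumerate(raw_inputs)
  ]

# Every replicated_input_* op appears in the graph regardless of later use.
print([op.name for op in graph.get_operations()])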