Diffstat (limited to 'tensorflow/contrib/tpu/python/ops/tpu_ops.py')
 tensorflow/contrib/tpu/python/ops/tpu_ops.py | 64
 1 file changed, 57 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index 3ed571aff9..d92a0652bb 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -38,6 +38,62 @@ if platform.system() != "Windows":
   _tpu_ops = loader.load_op_library(
       resource_loader.get_path_to_datafile("_tpu_ops.so"))
 
+  def _create_default_group_assignment():
+    num_shards = tpu_function.get_tpu_context().number_of_shards
+    if num_shards is None:
+      logging.warning(
+          "cross_replica_sum should be used within a tpu_shard_context, but "
+          "got unset number_of_shards. Assuming 1.")
+      num_shards = 1
+    group_assignment = [list(range(num_shards))]
+    return group_assignment
+
+  def all_to_all(x,
+                 concat_dimension,
+                 split_dimension,
+                 split_count,
+                 group_assignment=None,
+                 name=None):
+    """Exchange data across TPU replicas.
+
+    Args:
+      x: The local tensor.
+      concat_dimension: The dimension number to concatenate.
+      split_dimension: The dimension number to split.
+      split_count: The number of splits; this number must equal the sub-group
+        size (group_assignment.get_shape()[1]).
+      group_assignment: Optional 2d int32 list with shape [num_groups,
+        num_replicas_per_group]. `group_assignment[i]` represents the replica
+        ids in the ith subgroup.
+      name: Optional op name.
+
+    Returns:
+      A `Tensor` which is the concatenation of data from different replicas.
+    """
+    if group_assignment is None:
+      group_assignment = _create_default_group_assignment()
+    return gen_tpu_ops.all_to_all(
+        x,
+        group_assignment,
+        concat_dimension=concat_dimension,
+        split_dimension=split_dimension,
+        split_count=split_count,
+        name=name)
+
+  @ops.RegisterGradient("AllToAll")
+  def _all_to_all_grad(op, grad):
+    # The gradient of an all-to-all is also an all-to-all, but with the
+    # split_dimension and concat_dimension swapped.
+    # The gradient with respect to group_assignment is None.
+    return [
+        gen_tpu_ops.all_to_all(
+            grad,
+            op.inputs[1],
+            concat_dimension=op.get_attr("split_dimension"),
+            split_dimension=op.get_attr("concat_dimension"),
+            split_count=op.get_attr("split_count")), None
+    ]
+
   def cross_replica_sum(x, group_assignment=None, name=None):
     """Sum the input tensor across replicas according to group_assignment.
 
@@ -52,13 +108,7 @@ if platform.system() != "Windows":
       A `Tensor` which is summed across replicas.
     """
     if group_assignment is None:
-      num_shards = tpu_function.get_tpu_context().number_of_shards
-      if num_shards is None:
-        logging.warning(
-            "cross_replica_sum should be used within a tpu_shard_context, but "
-            "got unset number_of_shards. Assuming 1.")
-        num_shards = 1
-      group_assignment = [list(range(num_shards))]
+      group_assignment = _create_default_group_assignment()
     return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
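
With the default group assignment factored out, cross_replica_sum behaves as
before. A minimal sketch of its use under a shard context; the shard count and
input value are illustrative assumptions:

    # Sketch (assumed TF 1.x contrib APIs): sum a per-replica value across
    # 4 replicas. Inside tpu.shard, number_of_shards is set, so the default
    # group_assignment becomes [[0, 1, 2, 3]], one group of all replicas.
    import tensorflow as tf
    from tensorflow.contrib import tpu
    from tensorflow.contrib.tpu.python.ops import tpu_ops

    def computation():
      x = tf.constant(1.0)
      return tpu_ops.cross_replica_sum(x)  # each replica receives 4.0

    outputs = tpu.shard(computation, num_shards=4)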