Diffstat (limited to 'tensorflow/contrib/tpu/python/ops/tpu_ops.py')
-rw-r--r-- | tensorflow/contrib/tpu/python/ops/tpu_ops.py | 64
1 file changed, 57 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index 3ed571aff9..d92a0652bb 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -38,6 +38,62 @@ if platform.system() != "Windows":
   _tpu_ops = loader.load_op_library(
       resource_loader.get_path_to_datafile("_tpu_ops.so"))
 
+  def _create_default_group_assignment():
+    num_shards = tpu_function.get_tpu_context().number_of_shards
+    if num_shards is None:
+      logging.warning(
+          "cross_replica_sum should be used within a tpu_shard_context, but "
+          "got unset number_of_shards. Assuming 1.")
+      num_shards = 1
+    group_assignment = [list(range(num_shards))]
+    return group_assignment
+
+  def all_to_all(x,
+                 concat_dimension,
+                 split_dimension,
+                 split_count,
+                 group_assignment=None,
+                 name=None):
+    """Exchange data across TPU replicas.
+
+    Args:
+      x: The local tensor.
+      concat_dimension: The dimension number to concatenate.
+      split_dimension: The dimension number to split.
+      split_count: The number of splits; this number must equal the sub-group
+        size (group_assignment.get_shape()[1]).
+      group_assignment: Optional 2d int32 lists with shape [num_groups,
+        num_replicas_per_group]. `group_assignment[i]` represents the replica
+        ids in the ith subgroup.
+      name: Optional op name.
+
+    Returns:
+      A `Tensor` which is concatenated by data from different replicas.
+    """
+    if group_assignment is None:
+      group_assignment = _create_default_group_assignment()
+    return gen_tpu_ops.all_to_all(
+        x,
+        group_assignment,
+        concat_dimension=concat_dimension,
+        split_dimension=split_dimension,
+        split_count=split_count,
+        name=name)
+
+  @ops.RegisterGradient("AllToAll")
+  def _all_to_all_grad(op, grad):
+    # The gradient of an all-to-all is also an all-to-all, but with the
+    # split_dimension and concat_dimension swapped.
+    # The gradient with respect to group_assignment is None.
+    return [
+        gen_tpu_ops.all_to_all(
+            grad,
+            op.inputs[1],
+            concat_dimension=op.get_attr("split_dimension"),
+            split_dimension=op.get_attr("concat_dimension"),
+            split_count=op.get_attr("split_count")), None
+    ]
+
   def cross_replica_sum(x, group_assignment=None, name=None):
     """Sum the input tensor accorss replicas according to group_assignment.
 
@@ -52,13 +108,7 @@ if platform.system() != "Windows":
 
     Returns:
       A `Tensor` which is summed across replicas.
     """
     if group_assignment is None:
-      num_shards = tpu_function.get_tpu_context().number_of_shards
-      if num_shards is None:
-        logging.warning(
-            "cross_replica_sum should be used within a tpu_shard_context, but "
-            "got unset number_of_shards. Assuming 1.")
-        num_shards = 1
-      group_assignment = [list(range(num_shards))]
+      group_assignment = _create_default_group_assignment()
     return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
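
The exchange that the new all_to_all docstring describes can be sketched outside TensorFlow. The NumPy simulation below is a hypothetical illustration only, not the TPU kernel or its registered op; simulate_all_to_all and its arguments are invented for this sketch. It mimics what the docstring says: each replica in a group splits its tensor into split_count blocks along split_dimension, block j goes to the j-th replica of the group, and each replica concatenates what it receives, in sender order, along concat_dimension.

import numpy as np

def simulate_all_to_all(replica_tensors, concat_dimension, split_dimension,
                        split_count):
  # replica_tensors[i] is the tensor held by replica i of one group.
  assert len(replica_tensors) == split_count
  # blocks[i][j]: the j-th block of replica i's tensor along split_dimension.
  blocks = [np.split(t, split_count, axis=split_dimension)
            for t in replica_tensors]
  # Replica j receives block j from every replica i and concatenates the
  # blocks in sender order along concat_dimension.
  return [np.concatenate([blocks[i][j] for i in range(split_count)],
                         axis=concat_dimension)
          for j in range(split_count)]

# Two replicas, each holding a (2, 4) tensor.
r0 = np.arange(8).reshape(2, 4)
r1 = np.arange(8, 16).reshape(2, 4)
out = simulate_all_to_all([r0, r1], concat_dimension=0, split_dimension=1,
                          split_count=2)
print(out[0].shape)  # (4, 2): split dim halved, concat dim doubled
# Applying the exchange again with the dimensions swapped restores r0 and r1.
back = simulate_all_to_all(out, concat_dimension=1, split_dimension=0,
                           split_count=2)
assert np.array_equal(back[0], r0) and np.array_equal(back[1], r1)

The round trip in the last few lines matches the intuition behind the gradient registered in the patch: running the exchange with split_dimension and concat_dimension swapped routes every block back to the replica it came from.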