diff options
author | Youlong Cheng <ylc@google.com> | 2018-08-28 20:40:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 20:44:54 -0700 |
commit | 8012cf52d8c7e23766ff2a3d89a3028241de50b9 (patch) | |
tree | edb64f00417d0e2094098aabd9f08cc8e7332e4a /tensorflow/contrib/tpu/python | |
parent | 3d35a07179d4d38d0cabac4415c550f1cbce00c0 (diff) |
[TF:XLA] Change group_assignment from 1d array attribute to 2d array input tensor with shape [num_groups, num_replica_per_group].
PiperOrigin-RevId: 210656091
Diffstat (limited to 'tensorflow/contrib/tpu/python')
-rw-r--r-- | tensorflow/contrib/tpu/python/ops/tpu_ops.py | 29 | ||||
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py | 27 |
2 files changed, 47 insertions, 9 deletions
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index bf442d9116..3ed571aff9 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -21,8 +21,10 @@ from __future__ import print_function import platform +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.platform import tf_logging as logging if platform.system() != "Windows": # pylint: disable=wildcard-import,unused-import,g-import-not-at-top @@ -36,10 +38,35 @@ if platform.system() != "Windows": _tpu_ops = loader.load_op_library( resource_loader.get_path_to_datafile("_tpu_ops.so")) + def cross_replica_sum(x, group_assignment=None, name=None): + """Sum the input tensor across replicas according to group_assignment. + + Args: + x: The local tensor to sum. + group_assignment: Optional 2d int32 lists with shape [num_groups, + num_replicas_per_group]. `group_assignment[i]` represents the replica + ids in the ith subgroup. + name: Optional op name. + + Returns: + A `Tensor` which is summed across replicas. + """ + if group_assignment is None: + num_shards = tpu_function.get_tpu_context().number_of_shards + if num_shards is None: + logging.warning( + "cross_replica_sum should be used within a tpu_shard_context, but " + "got unset number_of_shards. Assuming 1.") + num_shards = 1 + group_assignment = [list(range(num_shards))] + + return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name) + @ops.RegisterGradient("CrossReplicaSum") def _cross_replica_sum_grad(op, grad): # The gradient of a cross replica sum is also a cross-replica sum. - return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment")) + # The gradient with respect to group_assignment is None. 
+ return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None] # This extra type checking exists to give a more helpful error message in # the common case that uint8 and int64 values are infed. Remove when both diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py index 74a675b645..1e11de6421 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py @@ -19,7 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function @@ -44,8 +43,9 @@ class CrossShardOptimizer(optimizer.Optimizer): reduction: The reduction to apply to the shard losses. name: Optional name prefix for the operations created when applying gradients. Defaults to "CrossShardOptimizer". - group_assignment: Optional list of group ids for applying the optimizer - to subgroups. + group_assignment: Optional 2d int32 lists with shape + [num_groups, num_replicas_per_group] which describes how to apply + optimizer to subgroups. Raises: ValueError: If reduction is not a valid cross-shard reduction. @@ -74,11 +74,22 @@ class CrossShardOptimizer(optimizer.Optimizer): """ if not group_assignment: return None - if len(group_assignment) != num_shards: - raise ValueError("The size of group_assignment does not equal to " - "num_shard({0}). Got group_assignment={1}".format( - num_shards, self._group_assignment)) - subgroup_size_list = dict(collections.Counter(group_assignment)).values() + if not (isinstance(group_assignment, list) and + all(isinstance(i, list) for i in group_assignment)): + raise ValueError("group_assignment must be a list of list. 
Got {}".format( + group_assignment)) + + replica_ids = set() + for g in group_assignment: + for i in g: + replica_ids.add(i) + + if set(range(num_shards)) != replica_ids: + raise ValueError("group_assignment must be a permutation of range({0})." + " Got group_assignment={1}".format( + num_shards, group_assignment)) + + subgroup_size_list = [len(group) for group in group_assignment] if all(subgroup_size_list[0] == size for size in subgroup_size_list): return subgroup_size_list[0] else: |