author     Youlong Cheng <ylc@google.com>  2018-08-28 20:40:03 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-28 20:44:54 -0700
commit     8012cf52d8c7e23766ff2a3d89a3028241de50b9 (patch)
tree       edb64f00417d0e2094098aabd9f08cc8e7332e4a /tensorflow/contrib/tpu/python
parent     3d35a07179d4d38d0cabac4415c550f1cbce00c0 (diff)
[TF:XLA] Change group_assignment from 1d array attribute to 2d array input tensor with shape [num_groups, num_replica_per_group].
PiperOrigin-RevId: 210656091
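
For orientation before the diff: under the new scheme, group_assignment is a 2d list in which group_assignment[i] names the replica ids of the ith reduction subgroup. A minimal sketch of the shape contract (plain Python; the 8-replica topology is hypothetical, not from this commit):

    # Hypothetical topology: 8 replicas split into 2 subgroups of 4.
    num_groups = 2
    num_replicas_per_group = 4
    group_assignment = [[0, 1, 2, 3],   # subgroup 0 sums among replicas 0-3
                        [4, 5, 6, 7]]   # subgroup 1 sums among replicas 4-7
    assert len(group_assignment) == num_groups
    assert all(len(g) == num_replicas_per_group for g in group_assignment)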
Diffstat (limited to 'tensorflow/contrib/tpu/python')
-rw-r--r--  tensorflow/contrib/tpu/python/ops/tpu_ops.py        | 29
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py  | 27
2 files changed, 47 insertions(+), 9 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index bf442d9116..3ed571aff9 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -21,8 +21,10 @@ from __future__ import print_function
 
 import platform
 
+from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
 
 if platform.system() != "Windows":
   # pylint: disable=wildcard-import,unused-import,g-import-not-at-top
@@ -36,10 +38,35 @@ if platform.system() != "Windows":
   _tpu_ops = loader.load_op_library(
       resource_loader.get_path_to_datafile("_tpu_ops.so"))
 
+  def cross_replica_sum(x, group_assignment=None, name=None):
+    """Sum the input tensor across replicas according to group_assignment.
+
+    Args:
+      x: The local tensor to sum.
+      group_assignment: Optional 2d int32 list with shape [num_groups,
+        num_replicas_per_group]. `group_assignment[i]` represents the replica
+        ids in the ith subgroup.
+      name: Optional op name.
+
+    Returns:
+      A `Tensor` which is summed across replicas.
+    """
+    if group_assignment is None:
+      num_shards = tpu_function.get_tpu_context().number_of_shards
+      if num_shards is None:
+        logging.warning(
+            "cross_replica_sum should be used within a tpu_shard_context, but "
+            "got unset number_of_shards. Assuming 1.")
+        num_shards = 1
+      group_assignment = [list(range(num_shards))]
+
+    return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
+
   @ops.RegisterGradient("CrossReplicaSum")
   def _cross_replica_sum_grad(op, grad):
     # The gradient of a cross replica sum is also a cross-replica sum.
-    return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment"))
+    # The gradient with respect to group_assignment is None.
+    return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
 
   # This extra type checking exists to give a more helpful error message in
   # the common case that uint8 and int64 values are infed. Remove when both
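
To show how the wrapper added above is called, a hedged usage sketch (assumes it runs inside a TPU computation with 8 replicas; loss stands in for any replica-local tensor):

    # Sketch, not from the commit: sum a per-replica value within two
    # subgroups of four replicas each.
    from tensorflow.contrib.tpu.python.ops import tpu_ops

    grouped_sum = tpu_ops.cross_replica_sum(
        loss, group_assignment=[[0, 1, 2, 3], [4, 5, 6, 7]])

    # With group_assignment omitted, the wrapper falls back to a single
    # group spanning every shard reported by the enclosing tpu_shard_context.
    global_sum = tpu_ops.cross_replica_sum(loss)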
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
index 74a675b645..1e11de6421 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
@@ -44,8 +43,9 @@ class CrossShardOptimizer(optimizer.Optimizer):
       reduction: The reduction to apply to the shard losses.
       name: Optional name prefix for the operations created when applying
         gradients. Defaults to "CrossShardOptimizer".
-      group_assignment: Optional list of group ids for applying the optimizer
-        to subgroups.
+      group_assignment: Optional 2d int32 list with shape
+        [num_groups, num_replicas_per_group] which describes how to apply the
+        optimizer to subgroups.
 
     Raises:
       ValueError: If reduction is not a valid cross-shard reduction.
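
A usage sketch for the documented argument (the wrapped optimizer and learning rate are illustrative; only the group_assignment plumbing comes from this file):

    # Sketch: reduce gradients within two 4-replica subgroups instead of
    # across all replicas at once.
    from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
    from tensorflow.python.training import gradient_descent

    opt = tpu_optimizer.CrossShardOptimizer(
        gradient_descent.GradientDescentOptimizer(learning_rate=0.1),
        group_assignment=[[0, 1, 2, 3], [4, 5, 6, 7]])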
@@ -74,11 +74,22 @@ class CrossShardOptimizer(optimizer.Optimizer):
     """
     if not group_assignment:
       return None
-    if len(group_assignment) != num_shards:
-      raise ValueError("The size of group_assignment does not equal to "
-                       "num_shard({0}). Got group_assignment={1}".format(
-                           num_shards, self._group_assignment))
-    subgroup_size_list = dict(collections.Counter(group_assignment)).values()
+    if not (isinstance(group_assignment, list) and
+            all(isinstance(i, list) for i in group_assignment)):
+      raise ValueError("group_assignment must be a list of lists. Got {}".format(
+          group_assignment))
+
+    replica_ids = set()
+    for g in group_assignment:
+      for i in g:
+        replica_ids.add(i)
+
+    if set(range(num_shards)) != replica_ids:
+      raise ValueError("group_assignment must be a permutation of range({0})."
+                       " Got group_assignment={1}".format(
+                           num_shards, group_assignment))
+
+    subgroup_size_list = [len(group) for group in group_assignment]
     if all(subgroup_size_list[0] == size for size in subgroup_size_list):
       return subgroup_size_list[0]
     else:
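
Read as a standalone function, the validation above amounts to the following sketch (reimplemented outside the class for illustration; the helper name is mine, and returning None for unequal subgroup sizes is an assumption, since the excerpt cuts off at the else: branch):

    def verify_and_get_subgroup_size(group_assignment, num_shards):
      # Mirrors the checks above: group_assignment must be a 2d list whose
      # ids cover exactly range(num_shards); returns the common subgroup
      # size, or None when subgroup sizes differ.
      if not group_assignment:
        return None
      if not (isinstance(group_assignment, list) and
              all(isinstance(i, list) for i in group_assignment)):
        raise ValueError("group_assignment must be a list of lists.")
      replica_ids = {i for g in group_assignment for i in g}
      if set(range(num_shards)) != replica_ids:
        raise ValueError("group_assignment must cover range(num_shards).")
      sizes = [len(g) for g in group_assignment]
      return sizes[0] if all(s == sizes[0] for s in sizes) else None

    assert verify_and_get_subgroup_size([[0, 1], [2, 3]], 4) == 2
    assert verify_and_get_subgroup_size([[0], [1, 2, 3]], 4) is None
    # verify_and_get_subgroup_size([[0, 1]], 4) raises ValueError: ids 2, 3
    # are missing.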