author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-08 17:10:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-08 17:15:13 -0700
commit    3e02edc1f33fb3bfa43b5828d8ecea0dbc7738ea (patch)
tree      926591cdb6e0b62c2dc8c041b1c61cbaf3c2ab85 /tensorflow/compiler/xla/service/shape_inference.cc
parent    aacb29a4ab88f9fa27c3301977e7f2cc289a3976 (diff)
[XLA] Add the XLA interface for AllToAll.
PiperOrigin-RevId: 207971529
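
Note on the shape rule: InferAllToAllShape keeps the operand's element type
and dimensions, except that the split dimension is divided by split_count and
the concat dimension is multiplied by split_count. A minimal standalone sketch
of that arithmetic, with plain std::vector<int64_t> standing in for xla::Shape
(the AllToAllDims helper below is hypothetical, for illustration only):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirrors the arithmetic in InferAllToAllShape: the split dimension
    // shrinks by split_count, the concat dimension grows by split_count.
    std::vector<int64_t> AllToAllDims(std::vector<int64_t> dims,
                                      int64_t split_dimension,
                                      int64_t concat_dimension,
                                      int64_t split_count) {
      dims[split_dimension] /= split_count;   // each replica keeps 1/split_count
      dims[concat_dimension] *= split_count;  // and concatenates split_count pieces
      return dims;
    }

    int main() {
      // f32[8,16] with split_dimension=0, concat_dimension=1, split_count=4
      // infers to f32[2,64].
      for (int64_t d : AllToAllDims({8, 16}, /*split_dimension=*/0,
                                    /*concat_dimension=*/1, /*split_count=*/4)) {
        std::cout << d << " ";  // prints: 2 64
      }
      std::cout << "\n";
      return 0;
    }
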
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 45
1 file changed, 45 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index c888bbf144..a4ea2b28f4 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1779,6 +1779,51 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   return ShapeUtil::MakeTupleShape(operand_shape_values);
 }
 
+/* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
+    const Shape& shape, int64 split_dimension, int64 concat_dimension,
+    int64 split_count) {
+  TF_RET_CHECK(split_count > 0);
+  if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) {
+    return InvalidArgument(
+        "AllToAll split_dimension %lld is out-of-bounds in shape %s.",
+        split_dimension, ShapeUtil::HumanString(shape).c_str());
+  }
+  if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) {
+    return InvalidArgument(
+        "AllToAll concat_dimension %lld is out-of-bounds in shape %s.",
+        concat_dimension, ShapeUtil::HumanString(shape).c_str());
+  }
+  if (shape.dimensions(split_dimension) % split_count != 0) {
+    return InvalidArgument(
+        "AllToAll split dimension size %lld must be divisible by split_count "
+        "%lld.",
+        shape.dimensions(split_dimension), split_count);
+  }
+  std::vector<int64> new_dimensions(shape.dimensions().begin(),
+                                    shape.dimensions().end());
+  new_dimensions[split_dimension] /= split_count;
+  new_dimensions[concat_dimension] *= split_count;
+  return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
+    tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+  // An AllToAll HLO instruction receives N operands (with the same shape) and
+  // returns a tuple that contains N array shapes.
+  TF_RET_CHECK(!operand_shapes.empty());
+  for (int i = 0; i < operand_shapes.size(); i++) {
+    if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) {
+      return InvalidArgument(
+          "HLO all-to-all has operands with different shapes: the 0th "
+          "operand has shape %s, but the %dth operand has shape %s.",
+          ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i,
+          ShapeUtil::HumanString(*operand_shapes[i]).c_str());
+    }
+  }
+
+  return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
+}
+
 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
     tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
     tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
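
The tuple variant, InferAllToAllTupleShape, only validates that all N operands
share one shape and then reuses kTuple inference for the N-tuple result. A
standalone sketch of that check, again with plain vectors standing in for
xla::Shape (AllToAllTupleDims is hypothetical, for illustration only):

    #include <cstdint>
    #include <optional>
    #include <vector>

    using Dims = std::vector<int64_t>;

    // Returns the N-element "tuple" of operand shapes if they all match,
    // std::nullopt otherwise (standing in for the InvalidArgument error).
    std::optional<std::vector<Dims>> AllToAllTupleDims(
        const std::vector<Dims>& operand_dims) {
      if (operand_dims.empty()) return std::nullopt;
      for (const Dims& d : operand_dims) {
        if (d != operand_dims[0]) return std::nullopt;  // shapes must all agree
      }
      return operand_dims;  // kTuple of the (identical) operand shapes
    }

    int main() {
      auto ok = AllToAllTupleDims({{2, 4}, {2, 4}, {2, 4}});  // 3-tuple of [2,4]
      auto bad = AllToAllTupleDims({{2, 4}, {4, 2}});         // mismatched shapes
      return (ok.has_value() && !bad.has_value()) ? 0 : 1;
    }
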