diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-08 17:10:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 17:15:13 -0700 |
commit | 3e02edc1f33fb3bfa43b5828d8ecea0dbc7738ea (patch) | |
tree | 926591cdb6e0b62c2dc8c041b1c61cbaf3c2ab85 /tensorflow/compiler/xla/service/shape_inference.cc | |
parent | aacb29a4ab88f9fa27c3301977e7f2cc289a3976 (diff) |
[XLA] Add the xla interface for AllToAll.
PiperOrigin-RevId: 207971529
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index c888bbf144..a4ea2b28f4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1779,6 +1779,51 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(operand_shape_values); } +/* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape( + const Shape& shape, int64 split_dimension, int64 concat_dimension, + int64 split_count) { + TF_RET_CHECK(split_count > 0); + if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + return InvalidArgument( + "AllToAll split_dimension %lld is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + return InvalidArgument( + "AllToAll concat_dimension %lld is out-of-bounds in shape %s.", + concat_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (shape.dimensions(split_dimension) % split_count != 0) { + return InvalidArgument( + "AllToAll split dimension size %lld must be dividable by split_count " + "%lld.", + shape.dimensions(split_dimension), split_count); + } + std::vector<int64> new_dimensions(shape.dimensions().begin(), + shape.dimensions().end()); + new_dimensions[split_dimension] /= split_count; + new_dimensions[concat_dimension] *= split_count; + return ShapeUtil::MakeShape(shape.element_type(), new_dimensions); +} + +/* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) { + // An Alltoall HLO instruction receives N operands (with the same shape) and + // returns a tuple that contains N array shapes. + TF_RET_CHECK(!operand_shapes.empty()); + for (int i = 0; i < operand_shapes.size(); i++) { + if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) { + return InvalidArgument( + "HLO all-to-all has operands with different shapes: the 0th " + "operand shape %s, but the %dth operand has shape %s.", + ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i, + ShapeUtil::HumanString(*operand_shapes[i]).c_str()); + } + } + + return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); +} + /* static */ StatusOr<Shape> ShapeInference::InferReduceShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, |