diff options
author | 2018-08-08 17:10:33 -0700 | |
---|---|---|
committer | 2018-08-08 17:15:13 -0700 | |
commit | 3e02edc1f33fb3bfa43b5828d8ecea0dbc7738ea (patch) | |
tree | 926591cdb6e0b62c2dc8c041b1c61cbaf3c2ab85 /tensorflow/compiler/xla/service/shape_inference.h | |
parent | aacb29a4ab88f9fa27c3301977e7f2cc289a3976 (diff) |
[XLA] Add the xla interface for AllToAll.
PiperOrigin-RevId: 207971529
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.h | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 33da323b3d..c185b0a1bd 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -119,11 +119,22 @@ class ShapeInference { const Shape& in, FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length); - // Infers the shape produced a cross replica sum with the given operand + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr<Shape> InferCrossReplicaSumShape( tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); + // Infers final shape of an Alltoall operation that is created by the xla + // builder. + static StatusOr<Shape> InferAllToAllShape(const Shape& shape, + int64 split_dimension, + int64 concat_dimension, + int64 split_count); + + // Infers the shape of an HLO all-to-all instruction. + static StatusOr<Shape> InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); + // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. // |