diff options
author | 2018-08-28 11:37:18 -0700 | |
---|---|---|
committer | 2018-08-28 11:41:45 -0700 | |
commit | 6de10fb253098c9ff65e9d4083c4de84f3ff5f76 (patch) | |
tree | af6c7c66f13892c9df05e55c83896ae4bd67a77e /tensorflow/compiler/xla/service/shape_inference.cc | |
parent | 13c7499d5454b870eb3604d6b0ca241685cabe18 (diff) |
[XLA] Add the xla interface for CollectivePermute.
PiperOrigin-RevId: 210576458
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index b04d2a7ba6..a04af8b0aa 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1844,6 +1844,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); } +/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape( + const Shape& shape) { + TF_RET_CHECK(ShapeUtil::IsArray(shape)); + return shape; +} + /* static */ StatusOr<Shape> ShapeInference::InferReduceShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, |