aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-28 11:37:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 11:41:45 -0700
commit6de10fb253098c9ff65e9d4083c4de84f3ff5f76 (patch)
treeaf6c7c66f13892c9df05e55c83896ae4bd67a77e /tensorflow/compiler/xla/service/shape_inference.cc
parent13c7499d5454b870eb3604d6b0ca241685cabe18 (diff)
[XLA] Add the xla interface for CollectivePermute.
PiperOrigin-RevId: 210576458
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index b04d2a7ba6..a04af8b0aa 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1844,6 +1844,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
}
+/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
+ const Shape& shape) {
+ TF_RET_CHECK(ShapeUtil::IsArray(shape));
+ return shape;
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,