aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h4
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index ad34a2aa18..1a5684e3c3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -286,6 +286,10 @@ class ShapeInference {
static StatusOr<Shape> InferSelectShape(const Shape& pred,
const Shape& on_true,
const Shape& on_false);
+ // Helper for inferring the shape of TupleSelect ops.
+ static StatusOr<Shape> InferTupleSelectShape(const Shape& pred,
+ const Shape& on_true,
+ const Shape& on_false);
// Helper for inferring shapes of binary operations which use degenerate
// dimension broadcasting (a dimension of size 1 in one operand is broadcast