diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.h | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index ad34a2aa18..1a5684e3c3 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -286,6 +286,10 @@ class ShapeInference { static StatusOr<Shape> InferSelectShape(const Shape& pred, const Shape& on_true, const Shape& on_false); + // Helper for inferring the shape of TupleSelect ops. + static StatusOr<Shape> InferTupleSelectShape(const Shape& pred, + const Shape& on_true, + const Shape& on_false); // Helper for inferring shapes of binary operations which use degenerate // dimension broadcasting (a dimension of size 1 in one operand is broadcast |