aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index d5d497176d..0aadb98a40 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -204,6 +204,13 @@ class ShapeInference {
static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
PrimitiveType new_element_type);
+ // Helper that validates the given operand shape can be bitcast converted to
+ // the target output_shape via a bitcast convert instruction -- the
+ // requirement is that the shape is identical except for the element type and
+ // the element types have identical bit-widths.
+ static StatusOr<Shape> InferBitcastConvertShape(
+ const Shape& operand_shape, PrimitiveType new_element_type);
+
// Helper that validates the input data type for a reduce-precision operation,
// and returns the result shape.
static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,