diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.h | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index d5d497176d..0aadb98a40 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -204,6 +204,13 @@ class ShapeInference { static StatusOr<Shape> InferConvertShape(const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the given operand shape can be bitcast converted to + // the target output_shape via a bitcast convert instruction -- the + // requirement is that the shape is identical except for the element type and + // the element types have identical bit-widths. + static StatusOr<Shape> InferBitcastConvertShape( + const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the input data type for a reduce-precision operation, // and returns the result shape. static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape, |