From 11698cc8e157eefe71a60931f1e721ad327e08af Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Mon, 28 Aug 2017 10:06:06 -0700 Subject: Verify the output shape of HLO instructions in the HloVerifier. This change adds verification for some but not all instruction types. To support this verification, add HLO-level methods to ShapeInference. These methods will also be useful to automatically infer shape in the HloInstruction::Create* methods. This CL also fixes some tests and transformations with malformed instructions found by the verifier. PiperOrigin-RevId: 166718979 --- tensorflow/compiler/xla/service/shape_inference.h | 26 +++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) (limited to 'tensorflow/compiler/xla/service/shape_inference.h') diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 5d55df92a9..96e3b46c7d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -21,6 +21,8 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -31,32 +33,48 @@ limitations under the License. namespace xla { // For a given operation and input shapes, infers what the resulting shape is -// for the operation. With this functionality, the user does not need to -// specify the expected result type for computations that are built up via the -// API -- the shape that results from an operation is inferred. +// for the operation. With this functionality, the user does not need to specify +// the expected result type for computations that are built up via the API -- +// the shape that results from an operation is inferred. Some methods have +// overloads for inferring shape at the HLO level. +// TODO(b/166374537): Complete HLO level inference overloads and use to +// automatically infer shape in HloInstruction::Create* methods. class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. static StatusOr InferUnaryOpShape(UnaryOperation operation, const Shape& arg); + static StatusOr InferUnaryOpShape(HloOpcode opcode, + const HloInstruction* operand); // Infers the shape produced by applying the given binary operation to the // given input shapes. static StatusOr InferBinaryOpShape( BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); + static StatusOr InferBinaryOpShape(HloOpcode opcode, + const HloInstruction* lhs, + const HloInstruction* rhs); // Infers the shape produced by applying the given ternary operation to the // given input shapes. static StatusOr InferTernaryOpShape(TernaryOperation operation, const Shape& lhs, const Shape& rhs, const Shape& ehs); + static StatusOr InferTernaryOpShape(HloOpcode opcode, + const HloInstruction* lhs, + const HloInstruction* rhs, + const HloInstruction* ehs); // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. static StatusOr InferVariadicOpShape( - VariadicOperation operation, std::vector operand_shapes); + VariadicOperation operation, + tensorflow::gtl::ArraySlice operand_shapes); + static StatusOr InferVariadicOpShape( + HloOpcode opcode, + tensorflow::gtl::ArraySlice operands); // Infers the shape produced by applying the given mapping computation shape // to the given operand shapes. -- cgit v1.2.3