aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.h
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-08-28 10:06:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-28 10:25:00 -0700
commit11698cc8e157eefe71a60931f1e721ad327e08af (patch)
tree031f0dff18003a072780d8ff6978414fe74bfb65 /tensorflow/compiler/xla/service/shape_inference.h
parentb132f5b0d82c4ae35f48607485c8be8a26ea4c00 (diff)
Verify the output shape of HLO instructions in the HloVerifier. This change adds verification for some but not all instruction types. To support this verification, add HLO-level methods to ShapeInference. These methods will also be useful to automatically infer shape in the HloInstruction::Create* methods.
This CL also fixes some tests and transformations with malformed instructions found by the verifier. PiperOrigin-RevId: 166718979
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h26
1 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 5d55df92a9..96e3b46c7d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <vector>
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -31,32 +33,48 @@ limitations under the License.
namespace xla {
// For a given operation and input shapes, infers what the resulting shape is
-// for the operation. With this functionality, the user does not need to
-// specify the expected result type for computations that are built up via the
-// API -- the shape that results from an operation is inferred.
+// for the operation. With this functionality, the user does not need to specify
+// the expected result type for computations that are built up via the API --
+// the shape that results from an operation is inferred. Some methods have
+// overloads for inferring shape at the HLO level.
+// TODO(b/166374537): Complete HLO level inference overloads and use to
+// automatically infer shape in HloInstruction::Create* methods.
class ShapeInference {
public:
// Infers the shape produced by applying the given unary operation to the
// given input shape.
static StatusOr<Shape> InferUnaryOpShape(UnaryOperation operation,
const Shape& arg);
+ static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
+ const HloInstruction* operand);
// Infers the shape produced by applying the given binary operation to the
// given input shapes.
static StatusOr<Shape> InferBinaryOpShape(
BinaryOperation operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
+ const HloInstruction* lhs,
+ const HloInstruction* rhs);
// Infers the shape produced by applying the given ternary operation to the
// given input shapes.
static StatusOr<Shape> InferTernaryOpShape(TernaryOperation operation,
const Shape& lhs, const Shape& rhs,
const Shape& ehs);
+ static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode,
+ const HloInstruction* lhs,
+ const HloInstruction* rhs,
+ const HloInstruction* ehs);
// Infers the shape produced by applying the given variadic operation to the
// given input operand shapes.
static StatusOr<Shape> InferVariadicOpShape(
- VariadicOperation operation, std::vector<const Shape*> operand_shapes);
+ VariadicOperation operation,
+ tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ static StatusOr<Shape> InferVariadicOpShape(
+ HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<const HloInstruction*> operands);
// Infers the shape produced by applying the given mapping computation shape
// to the given operand shapes.