aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_verifier.h
diff options
context:
space:
mode:
authorGravatar Brian Patton <bjp@google.com>2018-01-11 06:36:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-11 06:40:25 -0800
commit738dfa64cb2cc7771fa6bddb582abc8f32cff373 (patch)
treefc6145dac82bb1a5572cc3c38df45e2051149977 /tensorflow/compiler/xla/service/hlo_verifier.h
parent198eca145f305fa35d9f3abd0e8261c30faa7fb8 (diff)
Allow backends to specify a custom ShapeVerifier to HloVerifier.
Remove obsolete shape_size_fn_ from HloVerifier/ShapeVerifier. Adds a rank check to FFT shape inference. PiperOrigin-RevId: 181601294
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h92
1 files changed, 88 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index e35a7f3642..6368611f32 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -18,14 +18,98 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+
namespace xla {
+// Visitor which verifies that the output shape is correctly set. Verifies
+// against the inferred shape for the instruction.
+// TODO(b/26024837): Check output shape for all instruction types.
+class ShapeVerifier : public DfsHloVisitor {
+ public:
+ Status HandleElementwiseUnary(HloInstruction* hlo) override;
+ Status HandleElementwiseBinary(HloInstruction* hlo) override;
+ Status HandleClamp(HloInstruction* clamp) override;
+ Status HandleSelect(HloInstruction* select) override;
+ Status HandleConcatenate(HloInstruction* concatenate) override;
+ Status HandleConvert(HloInstruction* convert) override;
+ Status HandleBitcastConvert(HloInstruction* convert) override;
+ Status HandleCopy(HloInstruction* copy) override;
+ Status HandleDot(HloInstruction* dot) override;
+ Status HandleConvolution(HloInstruction* convolution) override;
+ Status HandleFft(HloInstruction* fft) override;
+ Status HandleCrossReplicaSum(HloInstruction* crs) override;
+ Status HandleReducePrecision(HloInstruction* reduce_precision) override;
+ Status HandleInfeed(HloInstruction*) override;
+ Status HandleOutfeed(HloInstruction*) override;
+ Status HandleRng(HloInstruction*) override;
+ Status HandleReverse(HloInstruction* reverse) override;
+ Status HandleSort(HloInstruction* sort) override;
+ Status HandleConstant(HloInstruction* constant) override;
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
+ Status HandleReduce(HloInstruction* reduce) override;
+ Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleBroadcast(HloInstruction* broadcast) override;
+ Status HandleReshape(HloInstruction* reshape) override;
+ Status HandleTranspose(HloInstruction* transpose) override;
+ Status HandleParameter(HloInstruction*) override;
+ Status HandleFusion(HloInstruction*) override;
+ Status HandleCall(HloInstruction* call) override;
+ Status HandleCustomCall(HloInstruction*) override;
+ Status HandleSlice(HloInstruction* slice) override;
+ Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
+ Status HandleDynamicUpdateSlice(
+ HloInstruction* dynamic_update_slice) override;
+ Status HandleTuple(HloInstruction* tuple) override;
+ Status HandleMap(HloInstruction* map) override;
+ Status HandleReduceWindow(HloInstruction* reduce_window) override;
+ Status HandleSelectAndScatter(HloInstruction* instruction) override;
+ Status HandleWhile(HloInstruction* xla_while) override;
+ Status HandleConditional(HloInstruction* conditional) override;
+ Status HandlePad(HloInstruction* pad) override;
+ Status HandleSend(HloInstruction* send) override;
+ Status HandleSendDone(HloInstruction* send_done) override;
+ Status HandleRecv(HloInstruction* recv) override;
+ Status HandleRecvDone(HloInstruction* recv_done) override;
+ Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
+ Status HandleBatchNormInference(
+ HloInstruction* batch_norm_inference) override;
+ Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
+
+ Status FinishVisit(HloInstruction*) override {
+ return tensorflow::Status::OK();
+ }
+
+ protected:
+ // Check the instruction's shape against the given expected shape and return
+ // an appropriate error if there is a mismatch.
+ Status CheckShape(const HloInstruction* instruction,
+ const Shape& expected_shape);
+
+ // Overload which takes a StatusOr to reduce boilerplate in the caller.
+ Status CheckShape(const HloInstruction* instruction,
+ const StatusOr<Shape>& expected_shape_status);
+
+ // Check a unary (binary, etc) instruction's shape against the inferred shape.
+ Status CheckUnaryShape(const HloInstruction* instruction);
+ Status CheckBinaryShape(const HloInstruction* instruction);
+ Status CheckTernaryShape(const HloInstruction* instruction);
+ Status CheckVariadicShape(const HloInstruction* instruction);
+
+ // Checks if the given two instructions shares the same channel id.
+ Status CheckSameChannel(const HloInstruction* instr1,
+ const HloInstruction* instr2);
+};
+
// HLO pass that verifies invariants of HLO instructions for each computation in
// the module.
class HloVerifier : public HloPassInterface {
public:
- explicit HloVerifier(const std::function<int64(const Shape&)>& shape_size_fn)
- : shape_size_fn_(shape_size_fn) {}
+ // Uses standard shape inference.
+ explicit HloVerifier() : shape_verifier_(MakeUnique<ShapeVerifier>()) {}
+ // Uses custom shape verification.
+ explicit HloVerifier(std::unique_ptr<ShapeVerifier> shape_verifier)
+ : shape_verifier_(std::move(shape_verifier)) {}
~HloVerifier() override = default;
tensorflow::StringPiece name() const override { return "verifier"; }
@@ -37,8 +121,8 @@ class HloVerifier : public HloPassInterface {
// CHECKs various invariants of a fusion instruction.
Status CheckFusionInstruction(HloInstruction* fusion) const;
- // Returns the size of a Shape in bytes.
- const std::function<int64(const Shape&)> shape_size_fn_;
+ // Verifies shapes match inferred expectations.
+ std::unique_ptr<ShapeVerifier> shape_verifier_;
};
} // namespace xla