diff options
author | 2018-01-11 06:36:19 -0800 | |
---|---|---|
committer | 2018-01-11 06:40:25 -0800 | |
commit | 738dfa64cb2cc7771fa6bddb582abc8f32cff373 (patch) | |
tree | fc6145dac82bb1a5572cc3c38df45e2051149977 /tensorflow/compiler/xla/service/hlo_verifier.h | |
parent | 198eca145f305fa35d9f3abd0e8261c30faa7fb8 (diff) |
Allow backends to specify a custom ShapeVerifier to HloVerifier.
Remove obsolete shape_size_fn_ from HloVerifier/ShapeVerifier.
Adds a rank check to FFT shape inference.
PiperOrigin-RevId: 181601294
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.h | 92 |
1 files changed, 88 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index e35a7f3642..6368611f32 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -18,14 +18,98 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" + namespace xla { +// Visitor which verifies that the output shape is correctly set. Verifies +// against the inferred shape for the instruction. +// TODO(b/26024837): Check output shape for all instruction types. +class ShapeVerifier : public DfsHloVisitor { + public: + Status HandleElementwiseUnary(HloInstruction* hlo) override; + Status HandleElementwiseBinary(HloInstruction* hlo) override; + Status HandleClamp(HloInstruction* clamp) override; + Status HandleSelect(HloInstruction* select) override; + Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleConvert(HloInstruction* convert) override; + Status HandleBitcastConvert(HloInstruction* convert) override; + Status HandleCopy(HloInstruction* copy) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleConvolution(HloInstruction* convolution) override; + Status HandleFft(HloInstruction* fft) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleReducePrecision(HloInstruction* reduce_precision) override; + Status HandleInfeed(HloInstruction*) override; + Status HandleOutfeed(HloInstruction*) override; + Status HandleRng(HloInstruction*) override; + Status HandleReverse(HloInstruction* reverse) override; + Status HandleSort(HloInstruction* sort) override; + Status HandleConstant(HloInstruction* constant) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleReshape(HloInstruction* reshape) override; + Status HandleTranspose(HloInstruction* transpose) override; + Status HandleParameter(HloInstruction*) override; + Status HandleFusion(HloInstruction*) override; + Status HandleCall(HloInstruction* call) override; + Status HandleCustomCall(HloInstruction*) override; + Status HandleSlice(HloInstruction* slice) override; + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleMap(HloInstruction* map) override; + Status HandleReduceWindow(HloInstruction* reduce_window) override; + Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleWhile(HloInstruction* xla_while) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandlePad(HloInstruction* pad) override; + Status HandleSend(HloInstruction* send) override; + Status HandleSendDone(HloInstruction* send_done) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; + Status HandleBatchNormInference( + HloInstruction* batch_norm_inference) override; + Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + + Status FinishVisit(HloInstruction*) override { + return tensorflow::Status::OK(); + } + + protected: + // Check the instruction's shape against the given expected shape and return + // an appropriate error if there is a mismatch. + Status CheckShape(const HloInstruction* instruction, + const Shape& expected_shape); + + // Overload which takes a StatusOr to reduce boilerplate in the caller. + Status CheckShape(const HloInstruction* instruction, + const StatusOr<Shape>& expected_shape_status); + + // Check a unary (binary, etc) instruction's shape against the inferred shape. + Status CheckUnaryShape(const HloInstruction* instruction); + Status CheckBinaryShape(const HloInstruction* instruction); + Status CheckTernaryShape(const HloInstruction* instruction); + Status CheckVariadicShape(const HloInstruction* instruction); + + // Checks if the given two instructions shares the same channel id. + Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2); +}; + // HLO pass that verifies invariants of HLO instructions for each computation in // the module. class HloVerifier : public HloPassInterface { public: - explicit HloVerifier(const std::function<int64(const Shape&)>& shape_size_fn) - : shape_size_fn_(shape_size_fn) {} + // Uses standard shape inference. + explicit HloVerifier() : shape_verifier_(MakeUnique<ShapeVerifier>()) {} + // Uses custom shape verification. + explicit HloVerifier(std::unique_ptr<ShapeVerifier> shape_verifier) + : shape_verifier_(std::move(shape_verifier)) {} ~HloVerifier() override = default; tensorflow::StringPiece name() const override { return "verifier"; } @@ -37,8 +121,8 @@ class HloVerifier : public HloPassInterface { // CHECKs various invariants of a fusion instruction. Status CheckFusionInstruction(HloInstruction* fusion) const; - // Returns the size of a Shape in bytes. - const std::function<int64(const Shape&)> shape_size_fn_; + // Verifies shapes match inferred expectations. + std::unique_ptr<ShapeVerifier> shape_verifier_; }; } // namespace xla |