diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.h | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 6d16586c2c..cb49cb95ba 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -155,11 +155,17 @@ class HloVerifier : public HloModulePass { public: using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>; - explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func = {}) : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { return absl::make_unique<ShapeVerifier>(layout_sensitive, allow_mixed_precision); - }) {} + }), + instruction_can_change_layout_func_( + std::move(instruction_can_change_layout_func)) { + CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); + } // Uses custom shape verification. explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) @@ -177,6 +183,10 @@ class HloVerifier : public HloModulePass { // being a DfsHloVisitor, is stateful. We want a clean object // for each run of the verifier. ShapeVerifierFactory shape_verifier_factory_; + + // Determines whether an instruction can change layouts. + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func_; }; } // namespace xla |