aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_verifier.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h14
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 6d16586c2c..cb49cb95ba 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -155,11 +155,17 @@ class HloVerifier : public HloModulePass {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
- explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func = {})
: shape_verifier_factory_([layout_sensitive, allow_mixed_precision] {
return absl::make_unique<ShapeVerifier>(layout_sensitive,
allow_mixed_precision);
- }) {}
+ }),
+ instruction_can_change_layout_func_(
+ std::move(instruction_can_change_layout_func)) {
+ CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive);
+ }
// Uses custom shape verification.
explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
@@ -177,6 +183,10 @@ class HloVerifier : public HloModulePass {
// being a DfsHloVisitor, is stateful. We want a clean object
// for each run of the verifier.
ShapeVerifierFactory shape_verifier_factory_;
+
+ // Determines whether an instruction can change layouts.
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func_;
};
} // namespace xla