aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/hlo_test_base.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/hlo_test_base.cc')
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index bdd4fd7e3d..7ab2ecda58 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -86,19 +86,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
} // namespace
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
- bool allow_mixed_precision_in_hlo_verifier)
+ bool allow_mixed_precision_in_hlo_verifier,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func)
: HloTestBase(GetTestPlatform(), GetReferencePlatform(),
verifier_layout_sensitive,
- allow_mixed_precision_in_hlo_verifier) {}
+ allow_mixed_precision_in_hlo_verifier,
+ instruction_can_change_layout_func) {}
HloTestBase::HloTestBase(se::Platform* test_platform,
se::Platform* reference_platform,
bool verifier_layout_sensitive,
- bool allow_mixed_precision_in_hlo_verifier)
+ bool allow_mixed_precision_in_hlo_verifier,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func)
: test_runner_(test_platform), reference_runner_(reference_platform) {
hlo_verifier_ = absl::make_unique<HloVerifier>(
/*layout_sensitive=*/verifier_layout_sensitive,
- /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier);
+ /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
+ instruction_can_change_layout_func);
}
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {