diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/hlo_test_base.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_test_base.cc | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index bdd4fd7e3d..7ab2ecda58 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -86,19 +86,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace HloTestBase::HloTestBase(bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func) : HloTestBase(GetTestPlatform(), GetReferencePlatform(), verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} + allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func) {} HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func) : test_runner_(test_platform), reference_runner_(reference_platform) { hlo_verifier_ = absl::make_unique<HloVerifier>( /*layout_sensitive=*/verifier_layout_sensitive, - /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func); } std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) { |