diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/hlo_test_base.h')
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_test_base.h | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 0ae4bdc104..217428befa 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -88,14 +88,18 @@ class HloTestBase : public ::testing::Test { // interpreter is the only supported backend, it will be both the test backend // and the reference backend. HloTestBase(bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func = {}); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func = {}); ~HloTestBase() override {} |