diff options
author | 2017-02-06 12:34:46 -0800 | |
---|---|---|
committer | 2017-02-06 12:51:56 -0800 | |
commit | 4596aafcc8565e83a9d6e32670a6a11db6826a55 (patch) | |
tree | e1e43ab013e4bff10f4fb4767ee1f34659159075 | |
parent | 06429f50497ea709124d240eadd9dd14041b10d2 (diff) |
[XLA] Add CodegenTestBase::set_fast_math_disabled.
This lets us control whether fast-math is enabled during tests, now that
fast-math is not controlled via flag.
Change: 146696974
-rw-r--r-- | tensorflow/compiler/xla/tests/codegen_test_base.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/codegen_test_base.h | 5 |
2 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index c407e93683..81c0568ff9 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -45,6 +45,7 @@ std::unique_ptr<Executable> CodegenTestBase::CompileToExecutable( std::unique_ptr<HloModule> hlo_module) { auto module_config = MakeUnique<HloModuleConfig>( hlo_module->entry_computation()->ComputeProgramShape()); + module_config->set_fast_math_disabled(fast_math_disabled_); return backend_->compiler() ->Compile(std::move(hlo_module), std::move(module_config), test_hlo_dumper_, backend_->default_stream_executor()) diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.h b/tensorflow/compiler/xla/tests/codegen_test_base.h index 50c0453107..ba32aac8e4 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.h +++ b/tensorflow/compiler/xla/tests/codegen_test_base.h @@ -41,6 +41,9 @@ class CodegenTestBase : public HloTestBase { void CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module, const string& pattern); + // Sets the fast-math-disabled flag on the config we use when compiling. + void set_fast_math_disabled(bool disabled) { fast_math_disabled_ = disabled; } + protected: // Compiles hlo_module to an executable, CHECK-failing if this fails. std::unique_ptr<Executable> CompileToExecutable( @@ -49,6 +52,8 @@ class CodegenTestBase : public HloTestBase { // Runs FileCheck with the given pattern over the given string and EXPECTs // that FileCheck succeeded in matching the input. void RunFileCheck(const string& input, const string& pattern); + + bool fast_math_disabled_ = false; }; } // namespace xla |