diff options
author | 2018-03-08 23:28:53 -0800 | |
---|---|---|
committer | 2018-03-08 23:32:38 -0800 | |
commit | 26b83da42fb47015aabd6ba1aa8e6d41ff8763dc (patch) | |
tree | 4e6a31809bed21884584c6f99c5aeace09db91a3 /tensorflow/compiler/xla | |
parent | 7dbe0cf7ecc4d0560ec9081b443ada693e4e6096 (diff) |
Remove a layer of templatization
With this change
- HloTestBase always calls HloRunner with an array of non-owning Literal
pointers as arguments
- HloRunner no longer has a general LiteralPtr, but just provides explicit
overloads for std::unique_ptr<Literal> and Literal*
This was prompted by a dependent change that needs to call
HloTestBase::RunAndCompare with Literal* arguments.
PiperOrigin-RevId: 188446331
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_runner.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_runner.h | 37 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_test_base.cc | 32 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_test_base.h | 12 |
4 files changed, 38 insertions, 45 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 41b079eb79..d65befaf84 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -110,7 +110,7 @@ HloRunner::HloRunner(se::Platform* platform) { HloRunner::~HloRunner() {} -StatusOr<std::unique_ptr<Literal>> HloRunner::ExecuteInternal( +StatusOr<std::unique_ptr<Literal>> HloRunner::Execute( std::unique_ptr<HloModule> module, const tensorflow::gtl::ArraySlice<Literal*> arguments, bool run_hlo_passes) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index cbaebc68be..06ce22a5b9 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -64,17 +65,27 @@ class HloRunner { const std::string& filename, const DebugOptions& debug_options); // Executes the given module with given literals as input and returns the - // result as a Literal. The LiteralPtr type accepts Literal* or - // std::unique_ptr<Literal>. + // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. - template <typename LiteralPtr> StatusOr<std::unique_ptr<Literal>> Execute( std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<LiteralPtr> arguments, + const tensorflow::gtl::ArraySlice<Literal*> arguments, bool run_hlo_passes = true); + StatusOr<std::unique_ptr<Literal>> Execute( + std::unique_ptr<HloModule> module, + const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments, + bool run_hlo_passes = true) { + // Construct a vector of plain pointers for the arguments. + std::vector<Literal*> argument_pointers; + c_transform( + arguments, std::back_inserter(argument_pointers), + [](const std::unique_ptr<Literal>& literal) { return literal.get(); }); + return Execute(std::move(module), argument_pointers, run_hlo_passes); + } + // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. // @@ -83,11 +94,6 @@ class HloRunner { Backend& backend(); private: - StatusOr<std::unique_ptr<Literal>> ExecuteInternal( - std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<Literal*> arguments, - bool run_hlo_passes = true); - struct EigenThreadPoolWrapper; std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_; @@ -95,19 +101,6 @@ class HloRunner { std::unique_ptr<Backend> backend_; }; -template <typename LiteralPtr> -StatusOr<std::unique_ptr<Literal>> HloRunner::Execute( - std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<LiteralPtr> arguments, - bool run_hlo_passes) { - // Construct a vector of plain pointers for the arguments. - std::vector<Literal*> argument_pointers; - for (const auto& argument : arguments) { - argument_pointers.push_back(&*argument); - } - return ExecuteInternal(std::move(module), argument_pointers, run_hlo_passes); -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 6723c99edb..5f62c44f25 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -140,15 +140,10 @@ StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule( return std::move(reference_module); } -template <typename LiteralPtr> StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( - std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments, + std::unique_ptr<HloModule> module, const ArraySlice<Literal*> arguments, const optional<ErrorSpec>& error, bool run_hlo_passes, const std::function<void(HloModule*)>& reference_preprocessor) { - static_assert( - std::is_same<Literal*, LiteralPtr>::value || - std::is_same<std::unique_ptr<Literal>, LiteralPtr>::value, - "The LiteralPtr type only accepts Literal* or std::unique_ptr<Literal>."); TF_RETURN_IF_ERROR( VerifyHloModule(*test_runner_.backend().platform(), module.get())); TF_ASSIGN_OR_RETURN(auto reference_module, @@ -165,9 +160,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( error); } -template <typename LiteralPtr> ::testing::AssertionResult HloTestBase::RunAndCompare( - std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments, + std::unique_ptr<HloModule> module, const ArraySlice<Literal*> arguments, const optional<ErrorSpec>& error, const std::function<void(HloModule*)>& reference_preprocessor) { auto result = @@ -179,9 +173,8 @@ template <typename LiteralPtr> return result.ValueOrDie(); } -template <typename LiteralPtr> ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments, + std::unique_ptr<HloModule> module, const ArraySlice<Literal*> arguments, const optional<ErrorSpec>& error, const std::function<void(HloModule*)>& reference_preprocessor) { auto result = @@ -198,8 +191,14 @@ template <typename LiteralPtr> const std::function<void(HloModule*)>& reference_preprocessor) { const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); - return RunAndCompare<std::unique_ptr<Literal>>( - std::move(module), fake_arguments, error, reference_preprocessor); + + std::vector<Literal*> fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr<Literal>& literal) { return literal.get(); }); + + return RunAndCompare(std::move(module), fake_argument_ptrs, error, + reference_preprocessor); } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( @@ -207,8 +206,13 @@ template <typename LiteralPtr> const std::function<void(HloModule*)>& reference_preprocessor) { const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); - return RunAndCompareNoHloPasses<std::unique_ptr<Literal>>( - std::move(module), fake_arguments, error, reference_preprocessor); + std::vector<Literal*> fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr<Literal>& literal) { return literal.get(); }); + + return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, + reference_preprocessor); } ::testing::AssertionResult HloTestBase::RunAndCompare( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 4d49b7071d..e375f13a44 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -104,8 +104,7 @@ class HloTestBase : public ::testing::Test { // Executes the given hlo module on two backends and compares results. // - // 'arguments': the input of the hlo module. The LiteralPtr type accepts - // Literal* or std::unique_ptr<Literal>. + // 'arguments': the input of the hlo module. // // 'error': if has value, expects the results to be near (within the error // bound). Otherwise, expects the results to be equal. @@ -114,20 +113,18 @@ class HloTestBase : public ::testing::Test { // backend, but it might need to be tailored so that it is able to run on the // reference backend. Note that the program shape of the module must not be // modified. - template <typename LiteralPtr> ::testing::AssertionResult RunAndCompare( std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<LiteralPtr> arguments, + const tensorflow::gtl::ArraySlice<Literal*> arguments, const tensorflow::gtl::optional<ErrorSpec>& error, const std::function<void(HloModule*)>& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; // Same as above, except that the module will be executed without Hlo // optimization. - template <typename LiteralPtr> ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<LiteralPtr> arguments, + const tensorflow::gtl::ArraySlice<Literal*> arguments, const tensorflow::gtl::optional<ErrorSpec>& error, const std::function<void(HloModule*)>& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -232,10 +229,9 @@ class HloTestBase : public ::testing::Test { // Runs the module on two platforms with or without running hlo passes and // compares the results. Returns whether the results are near or equal. If any // error happens before the results are computed, returns the error status. - template <typename LiteralPtr> StatusOr<::testing::AssertionResult> RunAndCompareInternal( std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<LiteralPtr> arguments, + const tensorflow::gtl::ArraySlice<Literal*> arguments, const tensorflow::gtl::optional<ErrorSpec>& error, bool run_hlo_passes, const std::function<void(HloModule*)>& reference_preprocessor); }; |