diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-03-08 23:28:53 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-08 23:32:38 -0800 |
commit | 26b83da42fb47015aabd6ba1aa8e6d41ff8763dc (patch) | |
tree | 4e6a31809bed21884584c6f99c5aeace09db91a3 /tensorflow/compiler/xla/service/hlo_runner.h | |
parent | 7dbe0cf7ecc4d0560ec9081b443ada693e4e6096 (diff) |
Remove a layer of templatization
With this change
- HloTestBase always calls HloRunner with an array of non-owning Literal
pointers as arguments
- HloRunner no longer has a general LiteralPtr, but just provides explicit
overloads for std::unique_ptr<Literal> and Literal*
This was prompted by a dependent change that needs to call
HloTestBase::RunAndCompare with Literal* arguments.
PiperOrigin-RevId: 188446331
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_runner.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_runner.h | 37 |
1 files changed, 15 insertions, 22 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index cbaebc68be..06ce22a5b9 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -64,17 +65,27 @@ class HloRunner { const std::string& filename, const DebugOptions& debug_options); // Executes the given module with given literals as input and returns the - // result as a Literal. The LiteralPtr type accepts Literal* or - // std::unique_ptr<Literal>. + // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. - template <typename LiteralPtr> StatusOr<std::unique_ptr<Literal>> Execute( std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<LiteralPtr> arguments, + const tensorflow::gtl::ArraySlice<Literal*> arguments, bool run_hlo_passes = true); + StatusOr<std::unique_ptr<Literal>> Execute( + std::unique_ptr<HloModule> module, + const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments, + bool run_hlo_passes = true) { + // Construct a vector of plain pointers for the arguments. + std::vector<Literal*> argument_pointers; + c_transform( + arguments, std::back_inserter(argument_pointers), + [](const std::unique_ptr<Literal>& literal) { return literal.get(); }); + return Execute(std::move(module), argument_pointers, run_hlo_passes); + } + // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. // @@ -83,11 +94,6 @@ class HloRunner { Backend& backend(); private: - StatusOr<std::unique_ptr<Literal>> ExecuteInternal( - std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<Literal*> arguments, - bool run_hlo_passes = true); - struct EigenThreadPoolWrapper; std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_; @@ -95,19 +101,6 @@ class HloRunner { std::unique_ptr<Backend> backend_; }; -template <typename LiteralPtr> -StatusOr<std::unique_ptr<Literal>> HloRunner::Execute( - std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<LiteralPtr> arguments, - bool run_hlo_passes) { - // Construct a vector of plain pointers for the arguments. - std::vector<Literal*> argument_pointers; - for (const auto& argument : arguments) { - argument_pointers.push_back(&*argument); - } - return ExecuteInternal(std::move(module), argument_pointers, run_hlo_passes); -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ |