author     Sanjoy Das <sanjoy@google.com>  2018-03-08 23:28:53 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-03-08 23:32:38 -0800
commit     26b83da42fb47015aabd6ba1aa8e6d41ff8763dc (patch)
tree       4e6a31809bed21884584c6f99c5aeace09db91a3 /tensorflow/compiler/xla/service/hlo_runner.h
parent     7dbe0cf7ecc4d0560ec9081b443ada693e4e6096 (diff)
Remove a layer of templatization
With this change:

- HloTestBase always calls HloRunner with an array of non-owning Literal pointers as arguments.
- HloRunner no longer has a general LiteralPtr, but instead provides explicit overloads for std::unique_ptr<Literal> and Literal*.

This was prompted by a dependent change that needs to call HloTestBase::RunAndCompare with Literal* arguments.

PiperOrigin-RevId: 188446331
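In the diff below, the new std::unique_ptr<Literal> overload delegates to the Literal* overload by stripping ownership with c_transform (the newly added util.h include is where XLA's c_transform wrapper over std::transform lives). A minimal standalone sketch of that same conversion, using plain std::transform and an illustrative Widget type rather than XLA's Literal:

    #include <algorithm>
    #include <iterator>
    #include <memory>
    #include <vector>

    struct Widget {};  // Stand-in for xla::Literal; illustrative only.

    // Convert owning pointers to non-owning ones, as the new Execute
    // overload does with c_transform before forwarding.
    std::vector<Widget*> AsRawPointers(
        const std::vector<std::unique_ptr<Widget>>& owned) {
      std::vector<Widget*> raw;
      std::transform(owned.begin(), owned.end(), std::back_inserter(raw),
                     [](const std::unique_ptr<Widget>& w) { return w.get(); });
      return raw;
    }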
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_runner.h')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.h | 37
1 file changed, 15 insertions(+), 22 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index cbaebc68be..06ce22a5b9 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -64,17 +65,27 @@ class HloRunner {
const std::string& filename, const DebugOptions& debug_options);
// Executes the given module with given literals as input and returns the
- // result as a Literal. The LiteralPtr type accepts Literal* or
- // std::unique_ptr<Literal>.
+ // result as a Literal.
//
// If run_hlo_passes is false, the module will be executed without Hlo
// optimization.
- template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
+ const tensorflow::gtl::ArraySlice<Literal*> arguments,
bool run_hlo_passes = true);
+ StatusOr<std::unique_ptr<Literal>> Execute(
+ std::unique_ptr<HloModule> module,
+ const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments,
+ bool run_hlo_passes = true) {
+ // Construct a vector of plain pointers for the arguments.
+ std::vector<Literal*> argument_pointers;
+ c_transform(
+ arguments, std::back_inserter(argument_pointers),
+ [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ return Execute(std::move(module), argument_pointers, run_hlo_passes);
+ }
+
// If backend is not created in the constructor, creates and returns the
// default backend. If creation fails, crashes the program.
//
@@ -83,11 +94,6 @@ class HloRunner {
Backend& backend();
private:
- StatusOr<std::unique_ptr<Literal>> ExecuteInternal(
- std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<Literal*> arguments,
- bool run_hlo_passes = true);
-
struct EigenThreadPoolWrapper;
std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
@@ -95,19 +101,6 @@ class HloRunner {
std::unique_ptr<Backend> backend_;
};
-template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
- std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
- bool run_hlo_passes) {
- // Construct a vector of plain pointers for the arguments.
- std::vector<Literal*> argument_pointers;
- for (const auto& argument : arguments) {
- argument_pointers.push_back(&*argument);
- }
- return ExecuteInternal(std::move(module), argument_pointers, run_hlo_passes);
-}
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
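Taken together, the change replaces the removed template with a pair of concrete overloads, where the owning-pointer overload converts its arguments and forwards to the non-owning one. A self-contained sketch of that pattern (the names Value and Run are illustrative, not XLA's):

    #include <iostream>
    #include <memory>
    #include <vector>

    struct Value { int v; };

    // Overload taking non-owning pointers; this is the one that does the work.
    int Run(const std::vector<Value*>& args) {
      int sum = 0;
      for (const Value* a : args) sum += a->v;
      return sum;
    }

    // Overload taking owning pointers; converts and delegates, mirroring the
    // inline Execute overload added in this change.
    int Run(const std::vector<std::unique_ptr<Value>>& args) {
      std::vector<Value*> raw;
      raw.reserve(args.size());
      for (const auto& a : args) raw.push_back(a.get());
      return Run(raw);
    }

    int main() {
      std::vector<std::unique_ptr<Value>> owned;
      owned.push_back(std::make_unique<Value>(Value{1}));
      owned.push_back(std::make_unique<Value>(Value{2}));
      std::cout << Run(owned) << "\n";  // Prints 3 via the delegating overload.
    }

Because the delegating overload is a one-line-per-argument conversion, keeping it inline in the header (as the diff does) avoids reintroducing a template parameter while still accepting both argument forms at call sites.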