diff options
author | 2017-11-17 14:16:09 -0800 | |
---|---|---|
committer | 2017-11-17 14:20:28 -0800 | |
commit | 3f888e1539db5551cfcf9ee837a0555c224e0018 (patch) | |
tree | 5f2df45e666fc15e370e6c029bf0712ee65d53ed /tensorflow/compiler/xla/service/hlo_runner.h | |
parent | d79dd4993061670c1ec5ea01db3022f28d72d0a3 (diff) |
Add a Compiler::BuildExecutable interface that compiles the given Hlo module without optimizations.
PiperOrigin-RevId: 176158846
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_runner.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_runner.h | 15 |
1 file changed, 10 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index a5732848c6..95cddafc91 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -65,17 +65,20 @@ class HloRunner { // Executes the given module with given literals as input and returns the // result as a Literal. The LiteralPtr type accepts Literal* or // std::unique_ptr<Literal>. + // If run_hlo_passes is false, the module will be executed without Hlo + // optimization. template <typename LiteralPtr> StatusOr<std::unique_ptr<Literal>> Execute( std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<LiteralPtr> literals); + const tensorflow::gtl::ArraySlice<LiteralPtr> literals, + bool run_hlo_passes = true); // Executes the given module and returns a global data handle. StatusOr<perftools::gputools::DeviceMemoryBase> Execute( std::unique_ptr<HloModule> module, tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> arguments, - Shape* result_shape); + Shape* result_shape, bool run_hlo_passes = true); // Transfers the given literal to the device and returns the data handle. StatusOr<perftools::gputools::DeviceMemoryBase> TransferToDevice( @@ -90,7 +93,8 @@ class HloRunner { StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer( std::unique_ptr<HloModule> module, tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> - arguments, + arguments, + bool run_hlo_passes = true); // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. 
@@ -112,14 +116,15 @@ class HloRunner { template <typename LiteralPtr> StatusOr<std::unique_ptr<Literal>> HloRunner::Execute( std::unique_ptr<HloModule> module, - const tensorflow::gtl::ArraySlice<LiteralPtr> literals) { + const tensorflow::gtl::ArraySlice<LiteralPtr> literals, + bool run_hlo_passes) { std::vector<perftools::gputools::DeviceMemoryBase> arguments; for (const auto& literal : literals) { TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument, TransferToDevice(*literal)); arguments.push_back(argument); } - return ExecuteAndTransfer(std::move(module), arguments); + return ExecuteAndTransfer(std::move(module), arguments, run_hlo_passes); } } // namespace xla |