author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-11-17 14:16:09 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-11-17 14:20:28 -0800
commit    3f888e1539db5551cfcf9ee837a0555c224e0018 (patch)
tree      5f2df45e666fc15e370e6c029bf0712ee65d53ed /tensorflow/compiler/xla/service/hlo_runner.h
parent    d79dd4993061670c1ec5ea01db3022f28d72d0a3 (diff)
Add a Compiler::BuildExecutable interface that compiles the given Hlo module without optimizations.
PiperOrigin-RevId: 176158846
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_runner.h')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.h  15
1 file changed, 10 insertions(+), 5 deletions(-)
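
For orientation before the diff: a minimal sketch of how a caller might use the new run_hlo_passes parameter, assuming only the signatures shown below. RunUnoptimized is an illustrative wrapper, not part of this change.

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_runner.h"

// Illustrative wrapper (not part of this change): executes a module as-is,
// skipping Hlo optimization via the new parameter.
xla::StatusOr<std::unique_ptr<xla::Literal>> RunUnoptimized(
    xla::HloRunner* runner, std::unique_ptr<xla::HloModule> module,
    tensorflow::gtl::ArraySlice<xla::Literal*> literals) {
  // run_hlo_passes defaults to true, so existing call sites keep running the
  // optimization pipeline; passing false selects the new unoptimized path.
  return runner->Execute(std::move(module), literals,
                         /*run_hlo_passes=*/false);
}
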
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index a5732848c6..95cddafc91 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -65,17 +65,20 @@ class HloRunner {
// Executes the given module with given literals as input and returns the
// result as a Literal. The LiteralPtr type accepts Literal* or
// std::unique_ptr<Literal>.
+  // If run_hlo_passes is false, the module will be executed without Hlo
+  // optimization.
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<LiteralPtr> literals);
+ const tensorflow::gtl::ArraySlice<LiteralPtr> literals,
+ bool run_hlo_passes = true);
// Executes the given module and returns a global data handle.
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
- Shape* result_shape);
+ Shape* result_shape, bool run_hlo_passes = true);
// Transfers the given literal to the device and returns the data handle.
StatusOr<perftools::gputools::DeviceMemoryBase> TransferToDevice(
@@ -90,7 +93,8 @@ class HloRunner {
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
- arguments);
+ arguments,
+ bool run_hlo_passes = true);
// If backend is not created in the constructor, creates and returns the
// default backend. If creation fails, crashes the program.
@@ -112,14 +116,15 @@ class HloRunner {
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<LiteralPtr> literals) {
+ const tensorflow::gtl::ArraySlice<LiteralPtr> literals,
+ bool run_hlo_passes) {
std::vector<perftools::gputools::DeviceMemoryBase> arguments;
for (const auto& literal : literals) {
TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument,
TransferToDevice(*literal));
arguments.push_back(argument);
}
- return ExecuteAndTransfer(std::move(module), arguments);
+ return ExecuteAndTransfer(std::move(module), arguments, run_hlo_passes);
}
} // namespace xla
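
As a usage note, a hedged sketch of a check that the optimized and unoptimized paths agree on a result. BuildAddModule() is a hypothetical helper for constructing a small HloModule and is not part of this commit; the Literal factory and comparison calls are assumed from the XLA APIs of this era.

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/core/platform/logging.h"

std::unique_ptr<xla::HloModule> BuildAddModule();  // hypothetical helper

void CheckPassesPreserveSemantics(xla::HloRunner* runner) {
  std::unique_ptr<xla::Literal> arg = xla::Literal::CreateR0<float>(41.0f);
  std::vector<xla::Literal*> args = {arg.get()};

  // Default path: Hlo passes run before compilation.
  std::unique_ptr<xla::Literal> optimized =
      runner->Execute<xla::Literal*>(BuildAddModule(), args).ValueOrDie();

  // New path: the module is compiled as-is, without Hlo optimization.
  std::unique_ptr<xla::Literal> unoptimized =
      runner->Execute<xla::Literal*>(BuildAddModule(), args,
                                     /*run_hlo_passes=*/false)
          .ValueOrDie();

  // Optimization passes must not change the computed value.
  CHECK_EQ(optimized->ToString(), unoptimized->ToString());
}
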