aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_runner.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-12 10:14:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-12 10:16:34 -0700
commitffbf77de81d0b7b4b169c92d0d9fbbdef5b8842a (patch)
tree8d68eedf28bdcac55516b6a3a176d56f6cef0fa2 /tensorflow/compiler/xla/service/hlo_runner.cc
parent8a247976484173059aedc17bfd8d770b8d1a70e1 (diff)
Introduced tool to run an HLO module in replicated fashion, by infeeding random data and outfeeding the data generated at each step.
The arguments of the computation can be either read from the session module, or randomly generated. The tool uses the raw transfer manager API to infeed and outfeed the data. PiperOrigin-RevId: 192628605
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_runner.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc189
1 file changed, 150 insertions, 39 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index ec7d8210a7..2e834a79d9 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -16,21 +16,16 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_runner.h"
-#include <set>
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/service/backend.h"
-#include "tensorflow/compiler/xla/service/executable.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
-#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -91,15 +86,6 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename,
return tools::Parse(hlo_string, config);
}
-// Define this in .cc file to avoid having to include eigen or forward declare
-// these types in the header.
-struct HloRunner::EigenThreadPoolWrapper {
- std::unique_ptr<EigenThreadPoolWrapper> pool;
- std::unique_ptr<Eigen::ThreadPoolDevice> device;
-};
-
-HloRunner::HloRunner() {}
-
HloRunner::HloRunner(se::Platform* platform) {
BackendOptions backend_options;
backend_options.set_platform(platform);
@@ -113,32 +99,14 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<Literal*> arguments,
bool run_hlo_passes) {
- if (run_hlo_passes) {
- TF_ASSIGN_OR_RETURN(
- module, backend().compiler()->RunHloPasses(
- std::move(module), backend().default_stream_executor(),
- /*device_allocator=*/nullptr));
- }
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Executable> executable,
- backend().compiler()->RunBackend(std::move(module),
- backend().default_stream_executor(),
- /*device_allocator=*/nullptr));
-
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+ CreateExecutable(std::move(module), run_hlo_passes));
se::Stream stream(backend().default_stream_executor());
stream.Init();
- ExecutableRunOptions run_options;
- run_options.set_device_ordinal(backend().default_device_ordinal());
- run_options.set_stream(&stream);
- run_options.set_allocator(backend().memory_allocator());
- run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
- run_options.set_intra_op_thread_pool(
- backend().eigen_intra_op_thread_pool_device());
-
- ServiceExecutableRunOptions service_run_options(
- run_options, backend().StreamBorrower(),
- backend().inter_op_thread_pool());
+ ServiceExecutableRunOptions service_run_options(GetServiceRunOptionsForDevice(
+ backend().default_device_ordinal(), &stream, nullptr));
+ const ExecutableRunOptions& run_options = service_run_options.run_options();
// Copy arguments to device.
std::vector<std::unique_ptr<ScopedShapedBuffer>> argument_buffers;
@@ -178,10 +146,153 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
return result_literal;
}
+StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
+ std::unique_ptr<HloModule> module,
+ const ReplicatedExecuteOptions& options) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Executable> executable,
+ CreateExecutable(std::move(module), options.run_hlo_passes));
+ TF_ASSIGN_OR_RETURN(
+ DeviceAssignment device_assignment,
+ backend().computation_placer()->AssignDevices(options.num_replicas, 1));
+ std::vector<std::unique_ptr<se::Stream>> streams;
+ std::vector<ServiceExecutableRunOptions> service_run_options;
+ std::vector<std::unique_ptr<ScopedShapedBuffer>> argument_buffers;
+ // Plus one so we can safely get &argument_buffer_ptrs[0] in case there are
+ // no arguments.
+ std::vector<const ShapedBuffer*> argument_buffer_ptrs(
+ options.num_replicas * options.arguments.size() + 1);
+ std::vector<tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
+ argument_buffer_slices;
+ int64 index = 0;
+ for (int64 i = 0; i < options.num_replicas; ++i) {
+ int64 device = device_assignment(i, 0);
+ TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
+ backend().stream_executor(device));
+ streams.push_back(absl::make_unique<se::Stream>(executor));
+ streams.back()->Init();
+ service_run_options.emplace_back(GetServiceRunOptionsForDevice(
+ device, streams.back().get(), &device_assignment));
+
+ // Copy arguments to device.
+ for (const Literal* argument : options.arguments) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<ScopedShapedBuffer> argument_buffer,
+ backend().transfer_manager()->AllocateScopedShapedBuffer(
+ argument->shape(), backend().memory_allocator(), device));
+ TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
+ executor, *argument, *argument_buffer));
+ argument_buffers.push_back(std::move(argument_buffer));
+ argument_buffer_ptrs[index++] = argument_buffers.back().get();
+ }
+ argument_buffer_slices.emplace_back(
+ &argument_buffer_ptrs[index - options.arguments.size()],
+ options.arguments.size());
+ }
+
+ std::unique_ptr<tensorflow::thread::ThreadPool> pool;
+ int64 num_threads = (options.infeed != nullptr) ? options.num_replicas : 0;
+ if (ShapeUtil::IsInitialized(options.outfeed_shape)) {
+ num_threads += options.num_replicas;
+ }
+ if (num_threads > 0) {
+ pool = absl::make_unique<tensorflow::thread::ThreadPool>(
+ tensorflow::Env::Default(), "infeed_outfeed",
+ /*num_threads=*/num_threads);
+ }
+ if (options.infeed != nullptr) {
+ for (int64 i = 0; i < options.num_replicas; ++i) {
+ int64 device = device_assignment(i, 0);
+ pool->Schedule([this, device, &options]() {
+ se::StreamExecutor* executor =
+ backend().stream_executor(device).ValueOrDie();
+ VLOG(1) << "Starting infeed on device " << device;
+ for (int64 step = 1;
+ options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
+ TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToInfeed(
+ executor, *options.infeed));
+ if (step % 100 == 0) {
+ VLOG(1) << "Infeed step " << step;
+ }
+ }
+ });
+ }
+ }
+ if (ShapeUtil::IsInitialized(options.outfeed_shape)) {
+ for (int64 i = 0; i < options.num_replicas; ++i) {
+ int64 device = device_assignment(i, 0);
+ pool->Schedule([this, device, &options]() {
+ se::StreamExecutor* executor =
+ backend().stream_executor(device).ValueOrDie();
+ VLOG(1) << "Starting outfeed on device " << device;
+ for (int64 step = 1;
+ options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
+ auto literal = absl::make_unique<Literal>();
+ TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
+ executor, options.outfeed_shape, literal.get()));
+ if (options.outfeed_values != nullptr) {
+ options.outfeed_values->push_back(std::move(literal));
+ }
+ if (step % 100 == 0) {
+ VLOG(1) << "Outfeed step " << step;
+ }
+ }
+ });
+ }
+ }
+
+ LOG(INFO) << "Replicated execution started";
+ TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<ShapedBuffer>> results,
+ executable->ExecuteOnStreams(service_run_options,
+ argument_buffer_slices));
+ LOG(INFO) << "Replicated execution terminated";
+
+ std::vector<std::unique_ptr<Literal>> exec_results;
+ for (int64 i = 0; i < options.num_replicas; ++i) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<ScopedShapedBuffer> result,
+ ScopedShapedBuffer::MakeScoped(
+ results[i].get(), backend().memory_allocator()));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ backend().transfer_manager()->TransferLiteralFromDevice(
+ streams[i]->parent(), *result));
+ exec_results.push_back(std::move(literal));
+ }
+ return std::move(exec_results);
+}
+
+StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable(
+ std::unique_ptr<HloModule> module, bool run_hlo_passes) {
+ if (run_hlo_passes) {
+ TF_ASSIGN_OR_RETURN(
+ module, backend().compiler()->RunHloPasses(
+ std::move(module), backend().default_stream_executor(),
+ backend().memory_allocator()));
+ }
+ return backend().compiler()->RunBackend(std::move(module),
+ backend().default_stream_executor(),
+ backend().memory_allocator());
+}
+
+ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
+ int64 device, se::Stream* stream, DeviceAssignment* device_assignment) {
+ ExecutableRunOptions run_options;
+ run_options.set_device_ordinal(device);
+ run_options.set_stream(stream);
+ run_options.set_allocator(backend().memory_allocator());
+ run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
+ run_options.set_intra_op_thread_pool(
+ backend().eigen_intra_op_thread_pool_device());
+ if (device_assignment != nullptr) {
+ run_options.set_device_assignment(device_assignment);
+ }
+ return ServiceExecutableRunOptions(run_options, backend().StreamBorrower(),
+ backend().inter_op_thread_pool());
+}
+
Backend& HloRunner::backend() {
if (!backend_) {
backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
- VLOG(1) << "executing on platform " << backend().platform()->Name();
+ VLOG(1) << "Executing on platform " << backend().platform()->Name();
}
return *backend_;
}