aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/plugin
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-27 16:33:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-27 16:37:09 -0700
commit50b999a8336d19400ab75aea66fe46eca2f5fe0b (patch)
tree7cba4f4af6b131c253b65ff9f2923e851184668c /tensorflow/compiler/plugin
parentd6d58a3a1785785679af56c0f8f131e7312b8226 (diff)
Merge changes from github.
PiperOrigin-RevId: 160344052
Diffstat (limited to 'tensorflow/compiler/plugin')
-rw-r--r--tensorflow/compiler/plugin/BUILD4
-rw-r--r--tensorflow/compiler/plugin/executor/BUILD34
-rw-r--r--tensorflow/compiler/plugin/executor/compiler.cc122
-rw-r--r--tensorflow/compiler/plugin/executor/compiler.h62
-rw-r--r--tensorflow/compiler/plugin/executor/device.cc60
-rw-r--r--tensorflow/compiler/plugin/executor/executable.cc147
-rw-r--r--tensorflow/compiler/plugin/executor/executable.h65
-rw-r--r--tensorflow/compiler/plugin/executor/executor.cc135
-rw-r--r--tensorflow/compiler/plugin/executor/executor.h213
-rw-r--r--tensorflow/compiler/plugin/executor/platform.cc125
-rw-r--r--tensorflow/compiler/plugin/executor/platform.h83
-rw-r--r--tensorflow/compiler/plugin/executor/platform_id.h31
-rw-r--r--tensorflow/compiler/plugin/executor/transfer_manager.cc187
-rw-r--r--tensorflow/compiler/plugin/executor/transfer_manager.h77
14 files changed, 1344 insertions, 1 deletions
diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD
index 4badd3a589..8c2e9a7c81 100644
--- a/tensorflow/compiler/plugin/BUILD
+++ b/tensorflow/compiler/plugin/BUILD
@@ -32,5 +32,7 @@ package(
cc_library(
name = "plugin",
- deps = [],
+ deps = [
+ "//tensorflow/compiler/plugin/executor:plugin_lib",
+ ],
)
diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD
new file mode 100644
index 0000000000..2e5875705f
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/BUILD
@@ -0,0 +1,34 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "plugin_lib",
+ srcs = glob([
+ "*.cc",
+ ]),
+ hdrs = glob([
+ "*.h",
+ ]),
+ deps = [
+ "//tensorflow/compiler/jit:xla_device",
+ "//tensorflow/compiler/jit:xla_jit_headers_lib",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:xla_headers_lib",
+ "//tensorflow/compiler/xla/service",
+ "//third_party/eigen3",
+ "@local_config_cuda//cuda:cuda_headers",
+ "@protobuf//:protobuf_headers",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
diff --git a/tensorflow/compiler/plugin/executor/compiler.cc b/tensorflow/compiler/plugin/executor/compiler.cc
new file mode 100644
index 0000000000..3a84f08c00
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/compiler.cc
@@ -0,0 +1,122 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdlib.h>
+#include <fstream>
+
+#include "tensorflow/compiler/plugin/executor/compiler.h"
+#include "tensorflow/compiler/plugin/executor/executable.h"
+
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
+#include "tensorflow/compiler/xla/service/hlo_cse.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
+#include "tensorflow/compiler/xla/service/inliner.h"
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace se = ::perftools::gputools;
+namespace sep = ::perftools::gputools::executorplugin;
+namespace port = ::perftools::gputools::port;
+
+namespace xla {
+namespace executorplugin {
+
+/*
+ * Run optimization passes on the module. The graph is transformed by
+ * each pass in the optimization pipeline. The service subdirectory
+ * contains useful optimization passes.
+ */
+Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) {
+ HloPassPipeline pipeline("Executor");
+ pipeline.AddPass<Inliner>();
+ pipeline.AddPass<HloSubcomputationUnification>();
+ pipeline.AddPass<HloCSE>(false);
+
+ pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
+ false, [](const Shape&, const Shape&) { return false; });
+ pipeline.AddPass<ReshapeMover>();
+ pipeline.AddPass<HloConstantFolding>();
+ pipeline.AddPass<HloCSE>(true);
+
+ pipeline.AddPass<HloDCE>();
+ pipeline.AddPass<FlattenCallGraph>();
+ return pipeline.Run(hlo_module).status();
+}
+
+StatusOr<std::unique_ptr<Executable>> ExecutorCompiler::Compile(
+ std::unique_ptr<HloModule> hlo_module,
+ se::StreamExecutor* stream_exec) {
+ TF_RET_CHECK(stream_exec != nullptr);
+
+ VLOG(1) << "Generate graph " << hlo_module->name();
+
+ TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
+
+ // Typically you would visit the HLO graph, building up a compiled equivalent
+ // In this case we are using an Hlo evaluator at execution time, so we don't
+ // need to compile anything
+
+ // Create executable from only the Hlo module
+ std::unique_ptr<Executable> executable;
+ executable.reset(new ExecutorExecutable(std::move(hlo_module)));
+
+ return std::move(executable);
+}
+
+StatusOr<std::vector<std::unique_ptr<Executable>>> ExecutorCompiler::Compile(
+ std::vector<std::unique_ptr<HloModule>> hlo_modules,
+ std::vector<se::StreamExecutor*> stream_execs) {
+
+ return tensorflow::errors::Unimplemented(
+ "Compilation of multiple HLO modules is not supported on Executor.");
+}
+
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ExecutorCompiler::CompileAheadOfTime(
+ std::vector<std::unique_ptr<HloModule>> hlo_modules,
+ const AotCompilationOptions& aot_options) {
+
+ return tensorflow::errors::InvalidArgument(
+ "AOT compilation not supported on Executor");
+}
+
+se::Platform::Id ExecutorCompiler::PlatformId() const {
+ return sep::kExecutorPlatformId;
+}
+
+HloCostAnalysis::ShapeSizeFunction
+ExecutorCompiler::ShapeSizeBytesFunction() const {
+ return ExecutorExecutable::ShapeSizeBytes;
+}
+
+
+} // namespace executorplugin
+} // namespace xla
+
+REGISTER_MODULE_INITIALIZER(executor_compiler, {
+ xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() {
+ return xla::MakeUnique<xla::executorplugin::ExecutorCompiler>();
+ });
+});
diff --git a/tensorflow/compiler/plugin/executor/compiler.h b/tensorflow/compiler/plugin/executor/compiler.h
new file mode 100644
index 0000000000..d318eefc49
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/compiler.h
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+namespace xla {
+namespace executorplugin {
+
+class ExecutorCompiler : public Compiler {
+ public:
+ ExecutorCompiler() {}
+ ~ExecutorCompiler() override {}
+
+ StatusOr<std::unique_ptr<Executable>> Compile(
+ std::unique_ptr<HloModule> hlo_module,
+ perftools::gputools::StreamExecutor* stream_exec) override;
+
+ StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
+ std::vector<std::unique_ptr<HloModule>> hlo_module,
+ std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
+
+ StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ CompileAheadOfTime(
+ std::vector<std::unique_ptr<HloModule>> module,
+ const AotCompilationOptions& options) override;
+
+ HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
+
+ perftools::gputools::Platform::Id PlatformId() const override;
+
+ private:
+ Status RunHloOptimization(HloModule* hlo_module);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorCompiler);
+};
+
+} // namespace executorplugin
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_
diff --git a/tensorflow/compiler/plugin/executor/device.cc b/tensorflow/compiler/plugin/executor/device.cc
new file mode 100644
index 0000000000..bbc39dc03f
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/device.cc
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_device_ops.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+
+const char* const DEVICE_XLA_EXEC = "XLA_EXEC";
+const char* const DEVICE_EXEC_XLA_JIT = "XLA_EXEC_JIT";
+
+constexpr std::array<DataType, 5> kExecAllTypes = {
+ {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}};
+
+class XlaExaDeviceFactory : public DeviceFactory {
+ public:
+ Status CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override;
+};
+
+Status XlaExaDeviceFactory::CreateDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) {
+ static XlaDeviceOpRegistrations* registrations =
+ RegisterXlaDeviceKernels(DEVICE_XLA_EXEC, DEVICE_EXEC_XLA_JIT);
+ (void)registrations;
+
+ std::unique_ptr<XlaDevice> device;
+ TF_RETURN_IF_ERROR(XlaDevice::Create("Executor", DEVICE_XLA_EXEC, 0,
+ DEVICE_EXEC_XLA_JIT, options,
+ name_prefix, &device));
+ devices->push_back(device.release());
+ return Status::OK();
+}
+
+REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 110);
+
+// Kernel registrations
+
+static bool OpFilter(KernelDef* kdef) { return true; }
+
+REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_EXEC, XlaDeviceLaunchOp, kExecAllTypes);
+REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_EXEC, kExecAllTypes);
+REGISTER_XLA_BACKEND(DEVICE_EXEC_XLA_JIT, kExecAllTypes, OpFilter);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/plugin/executor/executable.cc b/tensorflow/compiler/plugin/executor/executable.cc
new file mode 100644
index 0000000000..79eea9af3f
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executable.cc
@@ -0,0 +1,147 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/executable.h"
+#include "tensorflow/compiler/plugin/executor/executor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+
+namespace se = ::perftools::gputools;
+namespace sep = ::perftools::gputools::executorplugin;
+
+namespace xla {
+namespace executorplugin {
+
+ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module)
+ : Executable(std::move(hlo_module), ShapeSizeBytes) {}
+
+ExecutorExecutable::~ExecutorExecutable() {}
+
+static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor,
+ const Literal& literal) {
+ int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape()));
+ void* buf = executor->Allocate(size);
+ const void* src = literal.InternalData();
+ memcpy(buf, src, size);
+ return se::DeviceMemoryBase(buf, size);
+}
+
+static se::DeviceMemoryBase AllocateOutputBuffer(sep::ExecutorExecutor* executor,
+ const Literal& literal) {
+ const Shape& shape = literal.shape();
+ if (shape.element_type() != xla::TUPLE) {
+ return AllocateSingleOutput(executor, literal);
+ } else {
+ int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*)));
+ void** buf = reinterpret_cast<void**>(executor->Allocate(size));
+ for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) {
+ se::DeviceMemoryBase out =
+ AllocateSingleOutput(executor, literal.tuple_literals(n));
+ *buf++ = out.opaque();
+ }
+
+ return se::DeviceMemoryBase(buf, size);
+ }
+}
+
+StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ se::Stream* stream = run_options->stream();
+
+ VLOG(1) << "Execute " << module().name();
+ if (VLOG_IS_ON(2)) {
+ for (const auto& a : arguments) {
+ VLOG(2) << "-- argument " << a.opaque();
+ }
+ }
+
+ uint64 start_micros = tensorflow::Env::Default()->NowMicros();
+
+ HloComputation* computation = module().entry_computation();
+ if (computation->num_parameters() != arguments.size()) {
+ return tensorflow::errors::Internal(
+ "Mismatch between argument count and graph parameter count.");
+ }
+
+ // Create the arguments as an vector of XLA literals
+ std::vector<std::unique_ptr<Literal>> arg_literals;
+ std::vector<Literal*> arg_literals_ptrs;
+ for (int64 p = 0; p < computation->num_parameters(); p++) {
+ // Create the input literal for the parameter
+ HloInstruction* param = computation->parameter_instruction(p);
+ arg_literals.emplace_back(Literal::CreateFromShape(param->shape()));
+ arg_literals_ptrs.push_back(arg_literals.back().get());
+
+ // Copy in the data from the stream_executor buffers
+ void* buffer = arg_literals.back().get()->MutableInternalData();
+ memcpy(buffer, arguments[p].opaque(),
+ ShapeUtil::ByteSizeOf(param->shape()));
+ }
+
+ // Execute the graph using the evaluator
+ HloEvaluator evaluator;
+ std::unique_ptr<Literal> output;
+ TF_ASSIGN_OR_RETURN(output,
+ evaluator.Evaluate(computation, arg_literals_ptrs));
+
+ // Copy the result into the return buffer
+ perftools::gputools::StreamExecutor* executor(stream->parent());
+ sep::ExecutorExecutor* executorExecutor(
+ static_cast<sep::ExecutorExecutor*>(executor->implementation()));
+
+ se::DeviceMemoryBase ret =
+ AllocateOutputBuffer(executorExecutor, *(output.get()));
+
+ uint64 end_micros = tensorflow::Env::Default()->NowMicros();
+
+ {
+ tensorflow::mutex_lock lock(mutex_);
+ const double nanoseconds = (end_micros - start_micros) * 1000.0;
+ execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
+ }
+
+ return ret;
+}
+
+StatusOr<std::unique_ptr<ShapedBuffer>> ExecutorExecutable::ExecuteOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ return tensorflow::errors::Unimplemented(
+ "ExecuteOnStream is not yet supported on Executor.");
+}
+
+StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ return tensorflow::errors::Unimplemented(
+ "ExecuteAsyncOnStream is not yet supported on Executor.");
+}
+
+/*static*/ int64 ExecutorExecutable::ShapeSizeBytes(const Shape& shape) {
+ if (ShapeUtil::IsOpaque(shape)) {
+ return sizeof(void*);
+ }
+ return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+}
+
+
+} // namespace executorplugin
+} // namespace xla
diff --git a/tensorflow/compiler/plugin/executor/executable.h b/tensorflow/compiler/plugin/executor/executable.h
new file mode 100644
index 0000000000..ba3d4da21d
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executable.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_
+
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace xla {
+namespace executorplugin {
+
+class ExecutorExecutable : public Executable {
+ public:
+ ExecutorExecutable(std::unique_ptr<HloModule> hlo_module);
+ ~ExecutorExecutable() override;
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments) override;
+
+ static int64 ShapeSizeBytes(const Shape& shape);
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorExecutable);
+};
+
+} // namespace executorplugin
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_
diff --git a/tensorflow/compiler/plugin/executor/executor.cc b/tensorflow/compiler/plugin/executor/executor.cc
new file mode 100644
index 0000000000..e72c2711f7
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executor.cc
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/executor.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/compiler/xla/status_macros.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+namespace se = ::perftools::gputools;
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+host::HostStream *AsExecutorStream(Stream *stream) {
+ DCHECK(stream != nullptr);
+ return dynamic_cast<host::HostStream *>(stream->implementation());
+}
+
+ExecutorExecutor::ExecutorExecutor(const PluginConfig &plugin_config)
+ : plugin_config_(plugin_config) {}
+
+ExecutorExecutor::~ExecutorExecutor() {}
+
+void *ExecutorExecutor::Allocate(uint64 size) {
+ void *buf = new char[size];
+ return buf;
+}
+
+void *ExecutorExecutor::AllocateSubBuffer(DeviceMemoryBase *parent,
+ uint64 offset_bytes,
+ uint64 size_bytes) {
+ return parent + offset_bytes;
+}
+
+void ExecutorExecutor::Deallocate(DeviceMemoryBase *mem) {
+ if (!mem->is_sub_buffer()) {
+ delete[] static_cast<char *>(mem->opaque());
+ }
+}
+
+bool ExecutorExecutor::Memcpy(Stream *stream, void *host_dst,
+ const DeviceMemoryBase &dev_src, uint64 size) {
+ AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() {
+ port::Status ok = SynchronousMemcpy(host_dst, dev_src, size);
+ });
+ return true;
+}
+
+bool ExecutorExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
+ const void *host_src, uint64 size) {
+ AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() {
+ port::Status ok = SynchronousMemcpy(dev_dst, host_src, size);
+ });
+ return true;
+}
+
+port::Status ExecutorExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst,
+ const void *host_src,
+ uint64 size) {
+ memcpy(dev_dst->opaque(), host_src, size);
+ return port::Status::OK();
+}
+
+port::Status ExecutorExecutor::SynchronousMemcpy(void *host_dst,
+ const DeviceMemoryBase &dev_src,
+ uint64 size) {
+ memcpy(host_dst, dev_src.opaque(), size);
+ return port::Status::OK();
+}
+
+bool ExecutorExecutor::HostCallback(Stream *stream,
+ std::function<void()> callback) {
+ AsExecutorStream(stream)->EnqueueTask(callback);
+ return true;
+}
+
+bool ExecutorExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
+ AsExecutorStream(dependent)->EnqueueTask(
+ [other]() { other->BlockHostUntilDone(); });
+ AsExecutorStream(dependent)->BlockUntilDone();
+ return true;
+}
+
+bool ExecutorExecutor::StartTimer(Stream *stream, Timer *timer) {
+ dynamic_cast<host::HostTimer *>(timer->implementation())->Start(stream);
+ return true;
+}
+
+bool ExecutorExecutor::StopTimer(Stream *stream, Timer *timer) {
+ dynamic_cast<host::HostTimer *>(timer->implementation())->Stop(stream);
+ return true;
+}
+
+bool ExecutorExecutor::BlockHostUntilDone(Stream *stream) {
+ AsExecutorStream(stream)->BlockUntilDone();
+ return true;
+}
+
+DeviceDescription *ExecutorExecutor::PopulateDeviceDescription() const {
+ internal::DeviceDescriptionBuilder builder;
+
+ builder.set_device_address_bits(64);
+
+ builder.set_name("Executor");
+ builder.set_device_vendor("VectorName");
+ builder.set_platform_version("1.0");
+ builder.set_driver_version("1.0");
+ builder.set_runtime_version("1.0");
+ builder.set_pci_bus_id("1");
+ builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024);
+ builder.set_clock_rate_ghz(static_cast<float>(CLOCKS_PER_SEC) / 1e9);
+
+ auto built = builder.Build();
+ return built.release();
+}
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/compiler/plugin/executor/executor.h b/tensorflow/compiler/plugin/executor/executor.h
new file mode 100644
index 0000000000..32fdb157e4
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executor.h
@@ -0,0 +1,213 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Declares the ExecutorExecutor class, which is a CPU-only implementation of
+// the StreamExecutor interface. For now, this is used for testing and to
+// examine the performance of host-based StreamExecutor code.
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+
+#include "tensorflow/stream_executor/host/host_stream.h"
+#include "tensorflow/stream_executor/host/host_timer.h"
+
+#include "tensorflow/compiler/xla/shape_util.h"
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/rng.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+#include <list>
+#include <mutex>
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>;
+
+class ExecutorExecutor : public internal::StreamExecutorInterface {
+ public:
+ explicit ExecutorExecutor(const PluginConfig &plugin_config);
+ ~ExecutorExecutor() override;
+
+ port::Status Init(int device_ordinal, DeviceOptions device_options) override {
+ return port::Status::OK();
+ }
+
+ bool GetKernel(const MultiKernelLoaderSpec &spec,
+ KernelBase *kernel) override {
+ return false;
+ }
+ bool Launch(Stream *stream, const ThreadDim &thread_dims,
+ const BlockDim &block_dims, const KernelBase &kernel,
+ const KernelArgsArrayBase &args) override {
+ return false;
+ }
+
+ void *Allocate(uint64 size) override;
+ void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes,
+ uint64 size_bytes) override;
+ void Deallocate(DeviceMemoryBase *mem) override;
+
+ void *HostMemoryAllocate(uint64 size) override { return new char[size]; }
+ void HostMemoryDeallocate(void *mem) override {
+ delete[] static_cast<char *>(mem);
+ }
+ bool HostMemoryRegister(void *mem, uint64 size) override { return true; }
+ bool HostMemoryUnregister(void *mem) override { return true; }
+
+ bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &pop_src,
+ uint64 size) override;
+ bool Memcpy(Stream *stream, DeviceMemoryBase *pop_dst, const void *host_src,
+ uint64 size) override;
+ bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst,
+ const DeviceMemoryBase &host_src,
+ uint64 size) override {
+ return false;
+ }
+
+ bool MemZero(Stream *stream, DeviceMemoryBase *location,
+ uint64 size) override {
+ return false;
+ }
+ bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
+ uint64 size) override {
+ return false;
+ }
+ bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
+ uint64 size) override {
+ return false;
+ }
+
+ // No "synchronize all activity" implemented for this platform at the moment.
+ bool SynchronizeAllActivity() override { return false; }
+ bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
+ return false;
+ }
+
+ bool SynchronousMemSet(DeviceMemoryBase *location, int value,
+ uint64 size) override {
+ return false;
+ }
+
+ port::Status SynchronousMemcpy(DeviceMemoryBase *pop_dst,
+ const void *host_src, uint64 size) override;
+ port::Status SynchronousMemcpy(void *host_dst,
+ const DeviceMemoryBase &pop_src,
+ uint64 size) override;
+ port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst,
+ const DeviceMemoryBase &pop_src,
+ uint64 size) override {
+ return port::Status{port::error::UNIMPLEMENTED, ""};
+ }
+
+ bool HostCallback(Stream *stream, std::function<void()> callback) override;
+
+ port::Status AllocateEvent(Event *event) override {
+ return port::Status{port::error::UNIMPLEMENTED, ""};
+ }
+
+ port::Status DeallocateEvent(Event *event) override {
+ return port::Status{port::error::UNIMPLEMENTED, ""};
+ }
+
+ port::Status RecordEvent(Stream *stream, Event *event) override {
+ return port::Status{port::error::UNIMPLEMENTED, ""};
+ }
+
+ port::Status WaitForEvent(Stream *stream, Event *event) override {
+ return port::Status{port::error::UNIMPLEMENTED, ""};
+ }
+
+ Event::Status PollForEventStatus(Event *event) override {
+ return Event::Status::kError;
+ }
+
+ bool AllocateStream(Stream *stream) override { return true; }
+ void DeallocateStream(Stream *stream) override {}
+ bool CreateStreamDependency(Stream *dependent, Stream *other) override;
+
+ bool AllocateTimer(Timer *timer) override { return true; }
+ void DeallocateTimer(Timer *timer) override {}
+ bool StartTimer(Stream *stream, Timer *timer) override;
+ bool StopTimer(Stream *stream, Timer *timer) override;
+
+ bool BlockHostUntilDone(Stream *stream) override;
+
+ int PlatformDeviceCount() override { return 1; }
+
+ bool DeviceMemoryUsage(int64 *free, int64 *total) const override {
+ return false;
+ }
+
+ DeviceDescription *PopulateDeviceDescription() const override;
+
+ port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override {
+ return port::Status::OK();
+ }
+
+ bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override {
+ return true;
+ }
+
+ SharedMemoryConfig GetDeviceSharedMemoryConfig() override {
+ return SharedMemoryConfig::kDefault;
+ }
+
+ port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config) override {
+ return port::Status{port::error::UNIMPLEMENTED,
+ "Shared memory not supported"};
+ }
+
+ std::unique_ptr<internal::EventInterface> CreateEventImplementation()
+ override {
+ return nullptr;
+ }
+
+ std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
+ override {
+ return nullptr;
+ }
+
+ std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
+ override {
+ return std::unique_ptr<internal::StreamInterface>(new host::HostStream());
+ }
+
+ std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
+ return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
+ }
+
+ port::StatusOr<DeviceMemoryBase> ExecuteGraph(const xla::Shape &shape,
+ Args args);
+
+ private:
+ DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);
+
+ port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(
+ const xla::Shape &shape);
+
+ const PluginConfig plugin_config_;
+};
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
diff --git a/tensorflow/compiler/plugin/executor/platform.cc b/tensorflow/compiler/plugin/executor/platform.cc
new file mode 100644
index 0000000000..2f339f04a7
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform.cc
@@ -0,0 +1,125 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/platform.h"
+#include "tensorflow/compiler/plugin/executor/executor.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/status_macros.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+
+namespace se = ::perftools::gputools;
+namespace sep = ::perftools::gputools::executorplugin;
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+PLATFORM_DEFINE_ID(kExecutorPlatformId);
+
+ExecutorPlatform::ExecutorPlatform() : name_("Executor") {}
+
+ExecutorPlatform::~ExecutorPlatform() {}
+
+Platform::Id ExecutorPlatform::id() const { return kExecutorPlatformId; }
+
+int ExecutorPlatform::VisibleDeviceCount() const { return 1; }
+
+const string& ExecutorPlatform::Name() const { return name_; }
+
+port::StatusOr<StreamExecutor*> ExecutorPlatform::ExecutorForDevice(
+ int ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ config.plugin_config = PluginConfig();
+ config.device_options = DeviceOptions::Default();
+ return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*>
+ExecutorPlatform::ExecutorForDeviceWithPluginConfig(
+ int device_ordinal, const PluginConfig& plugin_config) {
+ StreamExecutorConfig config;
+ config.ordinal = device_ordinal;
+ config.plugin_config = plugin_config;
+ config.device_options = DeviceOptions::Default();
+ return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*> ExecutorPlatform::GetExecutor(
+ const StreamExecutorConfig& config) {
+ mutex_lock lock(executors_mutex_);
+
+ port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
+ if (status.ok()) {
+ return status.ValueOrDie();
+ }
+
+ port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
+ GetUncachedExecutor(config);
+ if (!executor.ok()) {
+ return executor.status();
+ }
+
+ StreamExecutor* naked_executor = executor.ValueOrDie().get();
+ SE_RETURN_IF_ERROR(
+ executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
+ return naked_executor;
+}
+
+port::StatusOr<std::unique_ptr<StreamExecutor>>
+ExecutorPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
+ auto executor = port::MakeUnique<StreamExecutor>(
+ this, port::MakeUnique<ExecutorExecutor>(config.plugin_config));
+ auto init_status = executor->Init(config.ordinal, config.device_options);
+ if (!init_status.ok()) {
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf(
+ "failed initializing StreamExecutor for device ordinal %d: %s",
+ config.ordinal, init_status.ToString().c_str())};
+ }
+
+ return std::move(executor);
+}
+
+void ExecutorPlatform::RegisterTraceListener(
+ std::unique_ptr<TraceListener> listener) {
+ LOG(FATAL) << "not yet implemented: register executor trace listener";
+}
+
+void ExecutorPlatform::UnregisterTraceListener(TraceListener* listener) {
+ LOG(FATAL) << "not yet implemented: unregister executor trace listener";
+}
+
+static void InitializeExecutorPlatform() {
+ std::unique_ptr<se::Platform> platform(new sep::ExecutorPlatform);
+ SE_CHECK_OK(se::MultiPlatformManager::RegisterPlatform(std::move(platform)));
+}
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
+
+REGISTER_MODULE_INITIALIZER(executor_platform, sep::InitializeExecutorPlatform());
+
+DECLARE_MODULE_INITIALIZER(multi_platform_manager);
+// Note that module initialization sequencing is not supported in the
+// open-source project, so this will be a no-op there.
+REGISTER_MODULE_INITIALIZER_SEQUENCE(executor_platform, multi_platform_manager);
diff --git a/tensorflow/compiler/plugin/executor/platform.h b/tensorflow/compiler/plugin/executor/platform.h
new file mode 100644
index 0000000000..c252a589d4
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform.h
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/stream_executor/executor_cache.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+#include "tensorflow/stream_executor/trace_listener.h"
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+class ExecutorPlatform : public Platform {
+ public:
+ ExecutorPlatform();
+ ~ExecutorPlatform() override;
+
+ Platform::Id id() const override;
+
+ // Device count is less clear-cut for CPUs than accelerators. This call
+ // currently returns the number of thread units in the host, as reported by
+ // base::NumCPUs().
+ int VisibleDeviceCount() const override;
+
+ const string& Name() const override;
+
+ port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+
+ port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
+ int ordinal, const PluginConfig& config) override;
+
+ port::StatusOr<StreamExecutor*> GetExecutor(
+ const StreamExecutorConfig& config) override;
+
+ port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+ const StreamExecutorConfig& config) override;
+
+ void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override;
+
+ void UnregisterTraceListener(TraceListener* listener) override;
+
+ private:
+ // This platform's name.
+ string name_;
+
+ // mutex that guards the ordinal-to-executor map.
+ mutable mutex executors_mutex_;
+
+ // Cache of created StreamExecutors.
+ ExecutorCache executor_cache_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(ExecutorPlatform);
+};
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
diff --git a/tensorflow/compiler/plugin/executor/platform_id.h b/tensorflow/compiler/plugin/executor/platform_id.h
new file mode 100644
index 0000000000..8d2b29a3e4
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform_id.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
+#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
+
+#include "tensorflow/stream_executor/platform.h"
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+extern const Platform::Id kExecutorPlatformId;
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.cc b/tensorflow/compiler/plugin/executor/transfer_manager.cc
new file mode 100644
index 0000000000..51c5deeea5
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/transfer_manager.cc
@@ -0,0 +1,187 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/transfer_manager.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace sep = ::perftools::gputools::executorplugin;
+
+namespace xla {
+namespace executorplugin {
+
+ExecutorTransferManager::ExecutorTransferManager() {}
+
+se::Platform::Id ExecutorTransferManager::PlatformId() const {
+ return se::executorplugin::kExecutorPlatformId;
+}
+
+Status ExecutorTransferManager::TransferLiteralFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& device_shape, const Shape& literal_shape, Literal* literal) {
+ TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape));
+
+ // Tuples are a special case and contain one or more shapes inside of them to
+ // an arbitrary nesting depth.
+ if (device_shape.element_type() == TUPLE) {
+ *literal->mutable_shape() = literal_shape;
+ TF_ASSIGN_OR_RETURN(
+ std::vector<se::DeviceMemoryBase> element_buffers,
+ ShallowCopyTupleFromDevice(executor, source, device_shape));
+ TF_RET_CHECK(element_buffers.size() ==
+ ShapeUtil::TupleElementCount(device_shape));
+ for (int64 i = 0; i < element_buffers.size(); ++i) {
+ const Shape& element_device_shape = device_shape.tuple_shapes(i);
+ const Shape& element_literal_shape = literal_shape.tuple_shapes(i);
+ Literal* element_literal = literal->add_tuple_literals();
+ // Recursively call TransferFromDevice to copy over the data in the
+ // element array.
+ TF_RETURN_IF_ERROR(TransferLiteralFromDevice(
+ executor, element_buffers[i], element_device_shape,
+ element_literal_shape, element_literal));
+ }
+ return Status::OK();
+ }
+
+ *literal->mutable_shape() = device_shape;
+ literal->Reserve(ShapeUtil::ElementsIn(device_shape));
+ TF_RETURN_IF_ERROR(TransferBufferFromDevice(
+ executor, source, ShapeUtil::ByteSizeOf(device_shape),
+ literal->MutableInternalData()));
+ if (!ShapeUtil::Equal(literal_shape, device_shape)) {
+ literal->Swap(
+ literal->Relayout(literal_shape.layout()).get());
+ }
+ TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape()));
+ return Status::OK();
+}
+
+StatusOr<std::vector<se::DeviceMemoryBase>>
+ExecutorTransferManager::ShallowCopyTupleFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& shape) {
+ TF_RET_CHECK(ShapeUtil::IsTuple(shape));
+
+ std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape),
+ nullptr);
+ int64 tuple_size = ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+ auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size,
+ element_pointers.data());
+ if (!copy_status.ok()) {
+ return AddStatus(
+ Status(static_cast<tensorflow::error::Code>(copy_status.code()),
+ copy_status.error_message()),
+ "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape));
+ }
+
+ // Create a DeviceMemoryBase from each void* pointer.
+ std::vector<se::DeviceMemoryBase> destination;
+ for (int i = 0; i < element_pointers.size(); ++i) {
+ if (element_pointers[i] == nullptr &&
+ !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
+ return FailedPrecondition("tuple contains nullptr at element %d", i);
+ }
+ int64 buffer_size =
+ ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), sizeof(void*));
+ destination.emplace_back(element_pointers[i], buffer_size);
+ }
+ return std::move(destination);
+}
+
+Status ExecutorTransferManager::TransferLiteralToDevice(
+ se::StreamExecutor* executor, const Literal& literal,
+ se::DeviceMemoryBase* destination) {
+ const Shape& shape = literal.shape();
+
+ if (ShapeUtil::IsTuple(literal.shape())) {
+ std::vector<void*> tuple_elements_on_device;
+ for (const Literal& tuple_element : literal.tuple_literals()) {
+ se::DeviceMemoryBase allocation = executor->AllocateArray<uint8>(
+ GetByteSizeRequirement(tuple_element.shape()));
+ TF_RETURN_IF_ERROR(
+ TransferLiteralToDevice(executor, tuple_element, &allocation));
+ tuple_elements_on_device.push_back(allocation.opaque());
+ }
+ return TransferBufferToDevice(
+ executor, tuple_elements_on_device.size() * sizeof(void*),
+ tuple_elements_on_device.data(), destination);
+ }
+
+ return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
+ literal.InternalData(),
+ destination);
+}
+
+Status ExecutorTransferManager::TransferLiteralToInfeed(
+ se::StreamExecutor* executor, const Literal& literal) {
+ const Shape& shape = literal.shape();
+ VLOG(1) << "transferring literal shape to infeed: "
+ << ShapeUtil::HumanString(shape);
+
+ return Status::OK();
+}
+
+Status ExecutorTransferManager::TransferBufferToInfeed(
+ se::StreamExecutor* executor, int64 size, const void* source) {
+ return Unimplemented("Transfer to Infeed");
+}
+
+Status ExecutorTransferManager::TransferLiteralFromOutfeed(
+ perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
+ Literal* literal) {
+ const Shape& shape = literal->shape();
+ VLOG(1) << "transferring literal shape from outfeed: "
+ << ShapeUtil::HumanString(shape);
+
+ return Status::OK();
+}
+
+Status ExecutorTransferManager::ResetDevices(
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+ executors) {
+ return Unimplemented("Device reset not supported");
+}
+
+int64 ExecutorTransferManager::GetByteSizeRequirement(const Shape& shape) {
+ return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+}
+
+} // namespace executorplugin
+} // namespace xla
+
+static std::unique_ptr<xla::TransferManager> CreateExecutorTransferManager() {
+ return xla::MakeUnique<xla::executorplugin::ExecutorTransferManager>();
+}
+
+static bool InitModule() {
+ xla::TransferManager::RegisterTransferManager(sep::kExecutorPlatformId,
+ &CreateExecutorTransferManager);
+ return true;
+}
+static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.h b/tensorflow/compiler/plugin/executor/transfer_manager.h
new file mode 100644
index 0000000000..7a42e5a2d7
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/transfer_manager.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
+
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+#include <vector>
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace executorplugin {
+
+class ExecutorTransferManager : public TransferManager {
+ public:
+ ExecutorTransferManager();
+
+ ~ExecutorTransferManager() override {}
+
+ se::Platform::Id PlatformId() const override;
+
+ StatusOr<std::vector<se::DeviceMemoryBase>> ShallowCopyTupleFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& shape) override;
+
+ Status TransferLiteralFromDevice(se::StreamExecutor* executor,
+ const se::DeviceMemoryBase& source,
+ const Shape& device_shape,
+ const Shape& literal_shape,
+ Literal* literal) override;
+
+ Status TransferLiteralToDevice(se::StreamExecutor* executor,
+ const Literal& literal,
+ se::DeviceMemoryBase* destination) override;
+
+ Status TransferLiteralToInfeed(se::StreamExecutor* executor,
+ const Literal& literal) override;
+
+ Status TransferBufferToInfeed(se::StreamExecutor* executor,
+ int64 size, const void* source) override;
+
+ Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
+ const Shape& literal_shape,
+ Literal* literal) override;
+
+ Status ResetDevices(
+ tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
+
+ int64 GetByteSizeRequirement(const Shape& shape) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorTransferManager);
+};
+
+} // namespace executorplugin
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_