author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-06-26 12:54:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-06-26 12:57:46 -0700
commit    f3c89936e97c99dead1ca3310246691c1b221adf (patch)
tree      3c99b66936ed59028b32609115a239f52798907d /tensorflow/compiler/plugin
parent    0b9b09a8531004b44b133a52c3fcc67bc6759bd8 (diff)
Merge changes from github.
END_PUBLIC
Note: this CL will break builds. cl/159887762 to follow to fix all the breakages.
---
Commit 2336cdf7f authored by Maxwell Paul Brickner<mbrickn@users.noreply.github.com>
Committed by gunan<gunan@google.com>:
Updated link to use HTTPS (#10998)
Howdy!
I just updated a link to use https instead of http.
Thanks!
---
Commit ad0892df1 authored by Luke Iwanski<luke@codeplay.com>
Committed by Luke Iwanski<luke@codeplay.com>:
[OpenCL] Fixes run_metadata_test for SYCL
This test is designed to test CUDA specific behavior
---
Commit 6b37a0725 authored by Todd Wang<toddwang@gmail.com>
Committed by GitHub<noreply@github.com>:
Update comments
---
Commit 1699d904a authored by John Lawson<john@codeplay.com>
Committed by Luke Iwanski<luke@codeplay.com>:
[OpenCL] Fixes CUDA specific test run on SYCL (#56)
The testBadParentValuesOnGPU should only be run on CUDA devices, as the
test checks for particular CUDA behaviour. We don't actually provide a
SYCL kernel for GatherTree and so it's not a problem that the tests
don't target SYCL.
---
Commit 3c1946230 authored by myPrecious<Moriadry@users.noreply.github.com>
Committed by Shanqing Cai<cais@google.com>:
Java API to get the size of specified input list of operations. (#10865)
* Java API to get the size of specified input list of operations
* Remove an unnecessary explanation to avoid introducing a new term to users.
---
Commit e911c7480 authored by Luke Iwanski<luke@codeplay.com>
Committed by Luke Iwanski<luke@codeplay.com>:
[OpenCL] REGISTER -> REGISTER6
---
Commit fbf6c4cec authored by superryanguo<superryanguo@gmail.com>
Committed by superryanguo<superryanguo@gmail.com>:
Simplify the Quickstart section by linking to the web version
---
Commit 72e2918cc authored by Taehoon Lee<taehoonlee@snu.ac.kr>
Committed by Taehoon Lee<taehoonlee@snu.ac.kr>:
Fix typos
---
Commit 90c4406b7 authored by Rishabh Patel<patelrishabh@users.noreply.github.com>
Committed by GitHub<noreply@github.com>:
Correct the learning rate as per the code snippet
---
Commit 03da61134 authored by Todd Wang<toddwang@gmail.com>
Committed by GitHub<noreply@github.com>:
Update ir_array.cc
---
Commit 2df6cd3ac authored by Todd Wang<toddwang@gmail.com>
Committed by GitHub<noreply@github.com>:
Another try
---
Commit af0cbace1 authored by Luke Iwanski<luke@codeplay.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
[OpenCL] Transpose to go through Eigen (#10321)
---
Commit fc7361081 authored by Luke Iwanski<luke@codeplay.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
[OpenCL] Registers RGBToHSV and HSVToRGB (#91) (#10848)
* [OpenCL] Added RGBToHSV and HSVToRGB
* Aligning '\'
---
Commit 832894ef8 authored by Luke Iwanski<luke@codeplay.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
[OpenCL] Registers AdjustContrastv2 (#10949)
* [OpenCL] Registers AdjustContrastv2 (#93)
* [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL (#96)
* [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL
* simplified to #ifndef
* Changed to "#if GOOGLE_CUDA"
* Update adjust_contrast_op_benchmark_test.cc
* Added comments
---
Commit cb4c2f8d1 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Make TransferBufferToInFeed not virtual so it compiles.
---
Commit e89f04d80 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Fix calling Literal member functions.
---
Commit 15a8df724 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Fix mac build
clone from meheff's change:
[XLA] Change return type of DeviceAssignment::Deserialize to fix build
breakage on mac.
The mac build had the following error:
error: incomplete type 'xla::DeviceAssignment' used in type trait
expression
This was due to a static method returning a StatusOr<DeviceAssignment>
inside the definition of DeviceAssignment.
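The incomplete-type problem above can be sketched in isolation. This is a simplified illustration, not the real XLA code: `StatusOr` here is a minimal stand-in for `xla::StatusOr`, and `Deserialize`'s parameter is hypothetical. The key point is that the factory is only *declared* inside the class and *defined* out of line, after `DeviceAssignment` is complete:

```cpp
#include <utility>

// Simplified stand-in for xla::StatusOr<T>, for illustration only.
template <typename T>
class StatusOr {
 public:
  explicit StatusOr(T value) : value_(std::move(value)) {}
  T& ValueOrDie() { return value_; }

 private:
  T value_;  // storing T by value requires T to be a complete type
};

class DeviceAssignment {
 public:
  // Declaring the factory here is fine: a declaration may name a
  // return type that is still incomplete. The mac breakage came from
  // standard-library machinery instantiating a type trait on
  // StatusOr<DeviceAssignment> while DeviceAssignment was incomplete.
  static StatusOr<DeviceAssignment> Deserialize(int serialized_tag);

  int replica_count() const { return replica_count_; }

 private:
  int replica_count_ = 1;
};

// Out-of-line definition: DeviceAssignment is complete at this point,
// so StatusOr<DeviceAssignment> can be instantiated safely.
StatusOr<DeviceAssignment> DeviceAssignment::Deserialize(int /*serialized_tag*/) {
  return StatusOr<DeviceAssignment>(DeviceAssignment{});
}
```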
---
Commit a54d43fa4 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Replace LiteralUtil with Literal in compiler/plugin/executor
---
Commit 88a6bb80c authored by Guenther Schmuelling<guschmue@microsoft.com>
Committed by Guenther Schmuelling<guschmue@microsoft.com>:
expand inline for debug builds to limit number of symbols
---
Commit 62fb49d31 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Fix visibility error for contrib/remote_fused_graph/pylib/BUILD.
---
Commit 4c75252f2 authored by Mark Neumann<markn@allenai.org>
Committed by Mark Neumann<markn@allenai.org>:
fix initial test values to avoid numerical instability
---
Commit b58d98353 authored by sj6077<epik03sj@gmail.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
Fix an AutoParallel bug (#10368)
* Fix the bug that auto_parallel could replicate variable snapshot name
* Use NodeName in grappler:utils instead of substr, convert variables->variable_def of grappler item
* remove variable_def from grappler item, exclude snapshot nodes from dont_replicate_nodes in auto_parallel
---
Commit a286b7db8 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Make debug_test slice integer.
---
Commit 97fcfdfa6 authored by Toby Boyd<tobyboyd@google.com>
Committed by GitHub<noreply@github.com>:
Fixed path to seq2seq.py and minor formatting
---
Commit 63c1befb8 authored by Anish Shah<shah.anish07@gmail.com>
Committed by Anish Shah<shah.anish07@gmail.com>:
Improve docs for tf.nn.depthwise_conv2d_native
---
Commit 8d42202b2 authored by Yong Tang<yong.tang.github@outlook.com>
Committed by Yong Tang<yong.tang.github@outlook.com>:
Fix mismatched delete in mkl_tfconv_op.cc
This change fixes a mismatched new[]/delete in mkl_tfconv_op.cc
(the file went through clang-format so there are some additional
changes)
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
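For context on the bug class being fixed: memory allocated with new[] must be released with delete[]; pairing it with plain delete is undefined behavior. A minimal sketch of the safe pattern, where fill_buffer is a hypothetical helper and not part of the patch:

```cpp
#include <cstddef>
#include <cstring>
#include <memory>

// Hypothetical helper illustrating the fix pattern. std::unique_ptr<char[]>
// selects delete[] automatically on scope exit, so the new[]/delete
// mismatch cannot happen.
inline std::size_t fill_buffer(std::size_t n) {
  std::unique_ptr<char[]> buf(new char[n]);  // array new; freed via delete[]
  std::memset(buf.get(), 0x2a, n);
  return static_cast<std::size_t>(buf[0]);   // 0x2a == 42
}
```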
---
Commit 26301bd55 authored by Danny Goodman<goodman.danny@gmail.com>
Committed by Danny Goodman<goodman.danny@gmail.com>:
fix error format
---
Commit b3f33ad46 authored by Yao Zhang<yaozhang@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Make changes to prepare for the fused option of batch norm to be set to None (None means using fused batch norm if possible).
PiperOrigin-RevId: 159649743
---
Commit a4a469832 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Add tests for select ops and while loops that produce tuples that contain predicates.
PiperOrigin-RevId: 159645900
---
Commit 980d3f2be authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Use C API to implement Operation.name property
This name property is used in many existing tests including those that
already run with C API enabled (math_ops_test, framework_ops_test,
session_test, session_partial_run_test, math_ops_test_gpu, etc).
PiperOrigin-RevId: 159645767
---
Commit 26239c706 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Previously we had no implementation of BatchNormInference and BatchNormTraining, which caused a linker error if anyone ever tried to call them. A dummy implementation is friendlier than a linker error.
PiperOrigin-RevId: 159645612
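The rationale can be sketched as follows; Status and the free function here are simplified stand-ins for the real XLA types, not the actual implementation:

```cpp
#include <string>

// Simplified stand-in for xla::Status, for illustration only.
struct Status {
  bool ok;
  std::string message;
};

// A dummy body turns a confusing link-time "undefined symbol" failure into
// a clear run-time Unimplemented error reported at the actual call site.
Status BatchNormInference() {
  return {false, "Unimplemented: BatchNormInference is not supported yet"};
}
```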
---
Commit f671c5caa authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
BEGIN_PUBLIC
Automated g4 rollback of changelist 159570549
PiperOrigin-RevId: 160182040
Diffstat (limited to 'tensorflow/compiler/plugin')
 tensorflow/compiler/plugin/BUILD                        |   4
 tensorflow/compiler/plugin/executor/BUILD               |  32
 tensorflow/compiler/plugin/executor/compiler.cc         | 123
 tensorflow/compiler/plugin/executor/compiler.h          |  64
 tensorflow/compiler/plugin/executor/device.cc           |  60
 tensorflow/compiler/plugin/executor/executable.cc       | 147
 tensorflow/compiler/plugin/executor/executable.h        |  65
 tensorflow/compiler/plugin/executor/executor.cc         | 135
 tensorflow/compiler/plugin/executor/executor.h          | 213
 tensorflow/compiler/plugin/executor/platform.cc         | 125
 tensorflow/compiler/plugin/executor/platform.h          |  83
 tensorflow/compiler/plugin/executor/platform_id.h       |  31
 tensorflow/compiler/plugin/executor/transfer_manager.cc | 187
 tensorflow/compiler/plugin/executor/transfer_manager.h  |  77
 14 files changed, 1345 insertions(+), 1 deletion(-)
diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD index 4badd3a589..8c2e9a7c81 100644 --- a/tensorflow/compiler/plugin/BUILD +++ b/tensorflow/compiler/plugin/BUILD @@ -32,5 +32,7 @@ package( cc_library( name = "plugin", - deps = [], + deps = [ + "//tensorflow/compiler/plugin/executor:plugin_lib", + ], ) diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD new file mode 100644 index 0000000000..9bc706abdf --- /dev/null +++ b/tensorflow/compiler/plugin/executor/BUILD @@ -0,0 +1,32 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "plugin_lib", + srcs = glob([ + "*.cc", + ]), + hdrs = glob([ + "*.h", + ]), + deps = [ + "//tensorflow/compiler/jit:xla_jit_headers_lib", + "//tensorflow/compiler/xla:xla_headers_lib", + "//tensorflow/compiler/xla/service:hlo_evaluator", + "//third_party/eigen3", + "@local_config_cuda//cuda:cuda_headers", + "@protobuf//:protobuf_headers", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/compiler/plugin/executor/compiler.cc b/tensorflow/compiler/plugin/executor/compiler.cc new file mode 100644 index 0000000000..893ff152f0 --- /dev/null +++ b/tensorflow/compiler/plugin/executor/compiler.cc @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <stdlib.h> +#include <fstream> + +#include "tensorflow/compiler/plugin/executor/compiler.h" +#include "tensorflow/compiler/plugin/executor/executable.h" + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_constant_folding.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" +#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/status_macros.h" + +#include "tensorflow/stream_executor/lib/initialize.h" +#include "tensorflow/stream_executor/lib/strcat.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace se = ::perftools::gputools; +namespace sep = ::perftools::gputools::executorplugin; +namespace port = ::perftools::gputools::port; + +namespace xla { +namespace executorplugin { + +/* + * Run optimization passes on the module. The graph is transformed by + * each pass in the optimization pipeline. The service subdirectory + * contains useful optimization passes. 
+ */ +Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module, + HloDumper dump_hlo) { + HloPassPipeline pipeline("Executor", dump_hlo); + pipeline.AddPass<Inliner>(); + pipeline.AddPass<HloSubcomputationUnification>(); + pipeline.AddPass<HloCSE>(false); + + pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>( + false, [](const Shape&, const Shape&) { return false; }); + pipeline.AddPass<ReshapeMover>(); + pipeline.AddPass<HloConstantFolding>(); + pipeline.AddPass<HloCSE>(true); + + pipeline.AddPass<HloDCE>(); + pipeline.AddPass<FlattenCallGraph>(); + return pipeline.Run(hlo_module).status(); +} + +StatusOr<std::unique_ptr<Executable>> ExecutorCompiler::Compile( + std::unique_ptr<HloModule> hlo_module, HloDumper dump_hlo, + se::StreamExecutor* stream_exec) { + TF_RET_CHECK(stream_exec != nullptr); + + VLOG(1) << "Generate graph " << hlo_module->name(); + + TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get(), dump_hlo)); + + // Typically you would visit the HLO graph, building up a compiled equivalent + // In this case we are using an Hlo evaluator at execution time, so we don't + // need to compile anything + + // Create executable from only the Hlo module + std::unique_ptr<Executable> executable; + executable.reset(new ExecutorExecutable(std::move(hlo_module))); + + return std::move(executable); +} + +StatusOr<std::vector<std::unique_ptr<Executable>>> ExecutorCompiler::Compile( + std::vector<std::unique_ptr<HloModule>> hlo_modules, + HloDumper dump_hlos, std::vector<se::StreamExecutor*> stream_execs) { + + return tensorflow::errors::Unimplemented( + "Compilation of multiple HLO modules is not supported on Executor."); +} + +StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> +ExecutorCompiler::CompileAheadOfTime( + std::vector<std::unique_ptr<HloModule>> hlo_modules, + HloDumper dump_hlo, const AotCompilationOptions& aot_options) { + + return tensorflow::errors::InvalidArgument( + "AOT compilation not supported on Executor"); +} + 
+se::Platform::Id ExecutorCompiler::PlatformId() const { + return sep::kExecutorPlatformId; +} + +HloCostAnalysis::ShapeSizeFunction +ExecutorCompiler::ShapeSizeBytesFunction() const { + return ExecutorExecutable::ShapeSizeBytes; +} + + +} // namespace executorplugin +} // namespace xla + +REGISTER_MODULE_INITIALIZER(executor_compiler, { + xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() { + return xla::MakeUnique<xla::executorplugin::ExecutorCompiler>(); + }); +}); diff --git a/tensorflow/compiler/plugin/executor/compiler.h b/tensorflow/compiler/plugin/executor/compiler.h new file mode 100644 index 0000000000..8fe591c8ab --- /dev/null +++ b/tensorflow/compiler/plugin/executor/compiler.h @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_ +#define TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_ + +#include <memory> + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +#include "tensorflow/compiler/plugin/executor/platform_id.h" + +namespace xla { +namespace executorplugin { + +class ExecutorCompiler : public Compiler { + public: + ExecutorCompiler() {} + ~ExecutorCompiler() override {} + + StatusOr<std::unique_ptr<Executable>> Compile( + std::unique_ptr<HloModule> hlo_module, + HloDumper dump_hlo, + perftools::gputools::StreamExecutor* stream_exec) override; + + StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( + std::vector<std::unique_ptr<HloModule>> hlo_module, + HloDumper dump_hlo, + std::vector<perftools::gputools::StreamExecutor*> stream_exec) override; + + StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> + CompileAheadOfTime( + std::vector<std::unique_ptr<HloModule>> module, + HloDumper dump_hlo, const AotCompilationOptions& options) override; + + HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; + + perftools::gputools::Platform::Id PlatformId() const override; + + private: + Status RunHloOptimization(HloModule* hlo_module, HloDumper dump_hlo); + + TF_DISALLOW_COPY_AND_ASSIGN(ExecutorCompiler); +}; + +} // namespace executorplugin +} // namespace xla + +#endif // TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_ diff --git a/tensorflow/compiler/plugin/executor/device.cc b/tensorflow/compiler/plugin/executor/device.cc new file mode 100644 index 0000000000..bbc39dc03f --- /dev/null +++ b/tensorflow/compiler/plugin/executor/device.cc @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_device_ops.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +const char* const DEVICE_XLA_EXEC = "XLA_EXEC"; +const char* const DEVICE_EXEC_XLA_JIT = "XLA_EXEC_JIT"; + +constexpr std::array<DataType, 5> kExecAllTypes = { + {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}}; + +class XlaExaDeviceFactory : public DeviceFactory { + public: + Status CreateDevices(const SessionOptions& options, const string& name_prefix, + std::vector<Device*>* devices) override; +}; + +Status XlaExaDeviceFactory::CreateDevices(const SessionOptions& options, + const string& name_prefix, + std::vector<Device*>* devices) { + static XlaDeviceOpRegistrations* registrations = + RegisterXlaDeviceKernels(DEVICE_XLA_EXEC, DEVICE_EXEC_XLA_JIT); + (void)registrations; + + std::unique_ptr<XlaDevice> device; + TF_RETURN_IF_ERROR(XlaDevice::Create("Executor", DEVICE_XLA_EXEC, 0, + DEVICE_EXEC_XLA_JIT, options, + name_prefix, &device)); + devices->push_back(device.release()); + return Status::OK(); +} + +REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 110); + +// Kernel registrations + +static bool OpFilter(KernelDef* kdef) { return true; } + 
+REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_EXEC, XlaDeviceLaunchOp, kExecAllTypes); +REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_EXEC, kExecAllTypes); +REGISTER_XLA_BACKEND(DEVICE_EXEC_XLA_JIT, kExecAllTypes, OpFilter); + +} // namespace tensorflow diff --git a/tensorflow/compiler/plugin/executor/executable.cc b/tensorflow/compiler/plugin/executor/executable.cc new file mode 100644 index 0000000000..79eea9af3f --- /dev/null +++ b/tensorflow/compiler/plugin/executor/executable.cc @@ -0,0 +1,147 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/plugin/executor/executable.h" +#include "tensorflow/compiler/plugin/executor/executor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace se = ::perftools::gputools; +namespace sep = ::perftools::gputools::executorplugin; + +namespace xla { +namespace executorplugin { + +ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module) + : Executable(std::move(hlo_module), ShapeSizeBytes) {} + +ExecutorExecutable::~ExecutorExecutable() {} + +static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor, + const Literal& literal) { + int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape())); + void* buf = executor->Allocate(size); + const void* src = literal.InternalData(); + memcpy(buf, src, size); + return se::DeviceMemoryBase(buf, size); +} + +static se::DeviceMemoryBase AllocateOutputBuffer(sep::ExecutorExecutor* executor, + const Literal& literal) { + const Shape& shape = literal.shape(); + if (shape.element_type() != xla::TUPLE) { + return AllocateSingleOutput(executor, literal); + } else { + int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*))); + void** buf = reinterpret_cast<void**>(executor->Allocate(size)); + for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) { + se::DeviceMemoryBase out = + AllocateSingleOutput(executor, literal.tuple_literals(n)); + *buf++ = out.opaque(); + } + + return se::DeviceMemoryBase(buf, size); + } +} + +StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments, + HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + + VLOG(1) << "Execute " << module().name(); + if 
(VLOG_IS_ON(2)) { + for (const auto& a : arguments) { + VLOG(2) << "-- argument " << a.opaque(); + } + } + + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + + HloComputation* computation = module().entry_computation(); + if (computation->num_parameters() != arguments.size()) { + return tensorflow::errors::Internal( + "Mismatch between argument count and graph parameter count."); + } + + // Create the arguments as an vector of XLA literals + std::vector<std::unique_ptr<Literal>> arg_literals; + std::vector<Literal*> arg_literals_ptrs; + for (int64 p = 0; p < computation->num_parameters(); p++) { + // Create the input literal for the parameter + HloInstruction* param = computation->parameter_instruction(p); + arg_literals.emplace_back(Literal::CreateFromShape(param->shape())); + arg_literals_ptrs.push_back(arg_literals.back().get()); + + // Copy in the data from the stream_executor buffers + void* buffer = arg_literals.back().get()->MutableInternalData(); + memcpy(buffer, arguments[p].opaque(), + ShapeUtil::ByteSizeOf(param->shape())); + } + + // Execute the graph using the evaluator + HloEvaluator evaluator; + std::unique_ptr<Literal> output; + TF_ASSIGN_OR_RETURN(output, + evaluator.Evaluate(computation, arg_literals_ptrs)); + + // Copy the result into the return buffer + perftools::gputools::StreamExecutor* executor(stream->parent()); + sep::ExecutorExecutor* executorExecutor( + static_cast<sep::ExecutorExecutor*>(executor->implementation())); + + se::DeviceMemoryBase ret = + AllocateOutputBuffer(executorExecutor, *(output.get())); + + uint64 end_micros = tensorflow::Env::Default()->NowMicros(); + + { + tensorflow::mutex_lock lock(mutex_); + const double nanoseconds = (end_micros - start_micros) * 1000.0; + execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); + } + + return ret; +} + +StatusOr<std::unique_ptr<ShapedBuffer>> ExecutorExecutable::ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + 
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, + HloExecutionProfile* hlo_execution_profile) { + return tensorflow::errors::Unimplemented( + "ExecuteOnStream is not yet supported on Executor."); +} + +StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) { + return tensorflow::errors::Unimplemented( + "ExecuteAsyncOnStream is not yet supported on Executor."); +} + +/*static*/ int64 ExecutorExecutable::ShapeSizeBytes(const Shape& shape) { + if (ShapeUtil::IsOpaque(shape)) { + return sizeof(void*); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); +} + + +} // namespace executorplugin +} // namespace xla diff --git a/tensorflow/compiler/plugin/executor/executable.h b/tensorflow/compiler/plugin/executor/executable.h new file mode 100644 index 0000000000..ba3d4da21d --- /dev/null +++ b/tensorflow/compiler/plugin/executor/executable.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_ +#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_ + +#include <cstddef> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace executorplugin { + +class ExecutorExecutable : public Executable { + public: + ExecutorExecutable(std::unique_ptr<HloModule> hlo_module); + ~ExecutorExecutable() override; + + StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> + arguments, + HloExecutionProfile* hlo_execution_profile) override; + + StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, + HloExecutionProfile* hlo_execution_profile) override; + + StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> + arguments) override; + + static int64 ShapeSizeBytes(const Shape& shape); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ExecutorExecutable); +}; + +} // namespace executorplugin +} // namespace xla + +#endif // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_ diff --git a/tensorflow/compiler/plugin/executor/executor.cc b/tensorflow/compiler/plugin/executor/executor.cc new file mode 100644 index 0000000000..e72c2711f7 --- /dev/null +++ b/tensorflow/compiler/plugin/executor/executor.cc @@ -0,0 
+1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/executor.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/compiler/xla/status_macros.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+namespace se = ::perftools::gputools;
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+host::HostStream *AsExecutorStream(Stream *stream) {
+  DCHECK(stream != nullptr);
+  return dynamic_cast<host::HostStream *>(stream->implementation());
+}
+
+ExecutorExecutor::ExecutorExecutor(const PluginConfig &plugin_config)
+    : plugin_config_(plugin_config) {}
+
+ExecutorExecutor::~ExecutorExecutor() {}
+
+void *ExecutorExecutor::Allocate(uint64 size) {
+  void *buf = new char[size];
+  return buf;
+}
+
+void *ExecutorExecutor::AllocateSubBuffer(DeviceMemoryBase *parent,
+                                          uint64 offset_bytes,
+                                          uint64 size_bytes) {
+  // Offset into the parent allocation's underlying storage; offsetting the
+  // DeviceMemoryBase pointer itself would be incorrect pointer arithmetic.
+  return reinterpret_cast<char *>(parent->opaque()) + offset_bytes;
+}
+
+void ExecutorExecutor::Deallocate(DeviceMemoryBase *mem) {
+  if (!mem->is_sub_buffer()) {
+    delete[] static_cast<char *>(mem->opaque());
+  }
+}
+
+bool ExecutorExecutor::Memcpy(Stream *stream, void *host_dst,
+                              const DeviceMemoryBase &dev_src, uint64 size) {
+  AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() {
+    port::Status ok = SynchronousMemcpy(host_dst, dev_src, size);
+  });
+  return true;
+}
+
+bool ExecutorExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
+                              const void *host_src, uint64 size) {
+  AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() {
+    port::Status ok = SynchronousMemcpy(dev_dst, host_src, size);
+  });
+  return true;
+}
+
+port::Status ExecutorExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst,
+                                                 const void *host_src,
+                                                 uint64 size) {
+  memcpy(dev_dst->opaque(), host_src, size);
+  return port::Status::OK();
+}
+
+port::Status ExecutorExecutor::SynchronousMemcpy(void *host_dst,
+                                                 const DeviceMemoryBase &dev_src,
+                                                 uint64 size) {
+  memcpy(host_dst, dev_src.opaque(), size);
+  return port::Status::OK();
+}
+
+bool ExecutorExecutor::HostCallback(Stream *stream,
+                                    std::function<void()> callback) {
+  AsExecutorStream(stream)->EnqueueTask(callback);
+  return true;
+}
+
+bool ExecutorExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
+  AsExecutorStream(dependent)->EnqueueTask(
+      [other]() { other->BlockHostUntilDone(); });
+  AsExecutorStream(dependent)->BlockUntilDone();
+  return true;
+}
+
+bool ExecutorExecutor::StartTimer(Stream *stream, Timer *timer) {
+  dynamic_cast<host::HostTimer *>(timer->implementation())->Start(stream);
+  return true;
+}
+
+bool ExecutorExecutor::StopTimer(Stream *stream, Timer *timer) {
+  dynamic_cast<host::HostTimer *>(timer->implementation())->Stop(stream);
+  return true;
+}
+
+bool ExecutorExecutor::BlockHostUntilDone(Stream *stream) {
+  AsExecutorStream(stream)->BlockUntilDone();
+  return true;
+}
+
+DeviceDescription *ExecutorExecutor::PopulateDeviceDescription() const {
+  internal::DeviceDescriptionBuilder builder;
+
+  builder.set_device_address_bits(64);
+
+  builder.set_name("Executor");
+  builder.set_device_vendor("VectorName");
+  builder.set_platform_version("1.0");
+  builder.set_driver_version("1.0");
+  builder.set_runtime_version("1.0");
+  builder.set_pci_bus_id("1");
+  builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024);
+  builder.set_clock_rate_ghz(static_cast<float>(CLOCKS_PER_SEC) / 1e9);
+
+  auto built = builder.Build();
+  return built.release();
+}
+
+}  // namespace executorplugin
+}  // namespace gputools
+}  // namespace perftools
diff --git a/tensorflow/compiler/plugin/executor/executor.h b/tensorflow/compiler/plugin/executor/executor.h
new file mode 100644
index 0000000000..32fdb157e4
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executor.h
@@ -0,0 +1,213 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Declares the ExecutorExecutor class, which is a CPU-only implementation of
+// the StreamExecutor interface. For now, this is used for testing and to
+// examine the performance of host-based StreamExecutor code.
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+
+#include "tensorflow/stream_executor/host/host_stream.h"
+#include "tensorflow/stream_executor/host/host_timer.h"
+
+#include "tensorflow/compiler/xla/shape_util.h"
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/rng.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+#include <list>
+#include <mutex>
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>;
+
+class ExecutorExecutor : public internal::StreamExecutorInterface {
+ public:
+  explicit ExecutorExecutor(const PluginConfig &plugin_config);
+  ~ExecutorExecutor() override;
+
+  port::Status Init(int device_ordinal, DeviceOptions device_options) override {
+    return port::Status::OK();
+  }
+
+  bool GetKernel(const MultiKernelLoaderSpec &spec,
+                 KernelBase *kernel) override {
+    return false;
+  }
+  bool Launch(Stream *stream, const ThreadDim &thread_dims,
+              const BlockDim &block_dims, const KernelBase &kernel,
+              const KernelArgsArrayBase &args) override {
+    return false;
+  }
+
+  void *Allocate(uint64 size) override;
+  void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes,
+                          uint64 size_bytes) override;
+  void Deallocate(DeviceMemoryBase *mem) override;
+
+  void *HostMemoryAllocate(uint64 size) override { return new char[size]; }
+  void HostMemoryDeallocate(void *mem) override {
+    delete[] static_cast<char *>(mem);
+  }
+  bool HostMemoryRegister(void *mem, uint64 size) override { return true; }
+  bool HostMemoryUnregister(void *mem) override { return true; }
+
+  bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &pop_src,
+              uint64 size) override;
+  bool Memcpy(Stream *stream, DeviceMemoryBase *pop_dst, const void *host_src,
+              uint64 size) override;
+  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst,
+                            const DeviceMemoryBase &host_src,
+                            uint64 size) override {
+    return false;
+  }
+
+  bool MemZero(Stream *stream, DeviceMemoryBase *location,
+               uint64 size) override {
+    return false;
+  }
+  bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
+              uint64 size) override {
+    return false;
+  }
+  bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
+                uint64 size) override {
+    return false;
+  }
+
+  // No "synchronize all activity" implemented for this platform at the moment.
+  bool SynchronizeAllActivity() override { return false; }
+  bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
+    return false;
+  }
+
+  bool SynchronousMemSet(DeviceMemoryBase *location, int value,
+                         uint64 size) override {
+    return false;
+  }
+
+  port::Status SynchronousMemcpy(DeviceMemoryBase *pop_dst,
+                                 const void *host_src, uint64 size) override;
+  port::Status SynchronousMemcpy(void *host_dst,
+                                 const DeviceMemoryBase &pop_src,
+                                 uint64 size) override;
+  port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst,
+                                               const DeviceMemoryBase &pop_src,
+                                               uint64 size) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  bool HostCallback(Stream *stream, std::function<void()> callback) override;
+
+  port::Status AllocateEvent(Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  port::Status DeallocateEvent(Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  port::Status RecordEvent(Stream *stream, Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  port::Status WaitForEvent(Stream *stream, Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  Event::Status PollForEventStatus(Event *event) override {
+    return Event::Status::kError;
+  }
+
+  bool AllocateStream(Stream *stream) override { return true; }
+  void DeallocateStream(Stream *stream) override {}
+  bool CreateStreamDependency(Stream *dependent, Stream *other) override;
+
+  bool AllocateTimer(Timer *timer) override { return true; }
+  void DeallocateTimer(Timer *timer) override {}
+  bool StartTimer(Stream *stream, Timer *timer) override;
+  bool StopTimer(Stream *stream, Timer *timer) override;
+
+  bool BlockHostUntilDone(Stream *stream) override;
+
+  int PlatformDeviceCount() override { return 1; }
+
+  bool DeviceMemoryUsage(int64 *free, int64 *total) const override {
+    return false;
+  }
+
+  DeviceDescription *PopulateDeviceDescription() const override;
+
+  port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override {
+    return port::Status::OK();
+  }
+
+  bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override {
+    return true;
+  }
+
+  SharedMemoryConfig GetDeviceSharedMemoryConfig() override {
+    return SharedMemoryConfig::kDefault;
+  }
+
+  port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config) override {
+    return port::Status{port::error::UNIMPLEMENTED,
+                        "Shared memory not supported"};
+  }
+
+  std::unique_ptr<internal::EventInterface> CreateEventImplementation()
+      override {
+    return nullptr;
+  }
+
+  std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
+      override {
+    return nullptr;
+  }
+
+  std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
+      override {
+    return std::unique_ptr<internal::StreamInterface>(new host::HostStream());
+  }
+
+  std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
+    return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
+  }
+
+  port::StatusOr<DeviceMemoryBase> ExecuteGraph(const xla::Shape &shape,
+                                                Args args);
+
+ private:
+  DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);
+
+  port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(
+      const xla::Shape &shape);
+
+  const PluginConfig plugin_config_;
+};
+
+}  // namespace executorplugin
+}  // namespace gputools
+}  // namespace perftools
+
+#endif  // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
diff --git a/tensorflow/compiler/plugin/executor/platform.cc b/tensorflow/compiler/plugin/executor/platform.cc
new file mode 100644
index 0000000000..2f339f04a7
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform.cc
@@ -0,0 +1,125 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/platform.h"
+#include "tensorflow/compiler/plugin/executor/executor.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/status_macros.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+
+namespace se = ::perftools::gputools;
+namespace sep = ::perftools::gputools::executorplugin;
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+PLATFORM_DEFINE_ID(kExecutorPlatformId);
+
+ExecutorPlatform::ExecutorPlatform() : name_("Executor") {}
+
+ExecutorPlatform::~ExecutorPlatform() {}
+
+Platform::Id ExecutorPlatform::id() const { return kExecutorPlatformId; }
+
+int ExecutorPlatform::VisibleDeviceCount() const { return 1; }
+
+const string& ExecutorPlatform::Name() const { return name_; }
+
+port::StatusOr<StreamExecutor*> ExecutorPlatform::ExecutorForDevice(
+    int ordinal) {
+  StreamExecutorConfig config;
+  config.ordinal = ordinal;
+  config.plugin_config = PluginConfig();
+  config.device_options = DeviceOptions::Default();
+  return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*>
+ExecutorPlatform::ExecutorForDeviceWithPluginConfig(
+    int device_ordinal, const PluginConfig& plugin_config) {
+  StreamExecutorConfig config;
+  config.ordinal = device_ordinal;
+  config.plugin_config = plugin_config;
+  config.device_options = DeviceOptions::Default();
+  return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*> ExecutorPlatform::GetExecutor(
+    const StreamExecutorConfig& config) {
+  mutex_lock lock(executors_mutex_);
+
+  port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
+  if (status.ok()) {
+    return status.ValueOrDie();
+  }
+
+  port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
+      GetUncachedExecutor(config);
+  if (!executor.ok()) {
+    return executor.status();
+  }
+
+  StreamExecutor* naked_executor = executor.ValueOrDie().get();
+  SE_RETURN_IF_ERROR(
+      executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
+  return naked_executor;
+}
+
+port::StatusOr<std::unique_ptr<StreamExecutor>>
+ExecutorPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
+  auto executor = port::MakeUnique<StreamExecutor>(
+      this, port::MakeUnique<ExecutorExecutor>(config.plugin_config));
+  auto init_status = executor->Init(config.ordinal, config.device_options);
+  if (!init_status.ok()) {
+    return port::Status{
+        port::error::INTERNAL,
+        port::Printf(
+            "failed initializing StreamExecutor for device ordinal %d: %s",
+            config.ordinal, init_status.ToString().c_str())};
+  }
+
+  return std::move(executor);
+}
+
+void ExecutorPlatform::RegisterTraceListener(
+    std::unique_ptr<TraceListener> listener) {
+  LOG(FATAL) << "not yet implemented: register executor trace listener";
+}
+
+void ExecutorPlatform::UnregisterTraceListener(TraceListener* listener) {
+  LOG(FATAL) << "not yet implemented: unregister executor trace listener";
+}
+
+static void InitializeExecutorPlatform() {
+  std::unique_ptr<se::Platform> platform(new sep::ExecutorPlatform);
+  SE_CHECK_OK(se::MultiPlatformManager::RegisterPlatform(std::move(platform)));
+}
+
+}  // namespace executorplugin
+}  // namespace gputools
+}  // namespace perftools
+
+REGISTER_MODULE_INITIALIZER(executor_platform,
+                            sep::InitializeExecutorPlatform());
+
+DECLARE_MODULE_INITIALIZER(multi_platform_manager);
+// Note that module initialization sequencing is not supported in the
+// open-source project, so this will be a no-op there.
+REGISTER_MODULE_INITIALIZER_SEQUENCE(executor_platform, multi_platform_manager);
diff --git a/tensorflow/compiler/plugin/executor/platform.h b/tensorflow/compiler/plugin/executor/platform.h
new file mode 100644
index 0000000000..c252a589d4
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform.h
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/stream_executor/executor_cache.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+#include "tensorflow/stream_executor/trace_listener.h"
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+class ExecutorPlatform : public Platform {
+ public:
+  ExecutorPlatform();
+  ~ExecutorPlatform() override;
+
+  Platform::Id id() const override;
+
+  // Device count is less clear-cut for CPUs than for accelerators. This
+  // platform currently reports a single visible device.
+  int VisibleDeviceCount() const override;
+
+  const string& Name() const override;
+
+  port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+
+  port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
+      int ordinal, const PluginConfig& config) override;
+
+  port::StatusOr<StreamExecutor*> GetExecutor(
+      const StreamExecutorConfig& config) override;
+
+  port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+      const StreamExecutorConfig& config) override;
+
+  void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override;
+
+  void UnregisterTraceListener(TraceListener* listener) override;
+
+ private:
+  // This platform's name.
+  string name_;
+
+  // Mutex that guards the ordinal-to-executor map.
+  mutable mutex executors_mutex_;
+
+  // Cache of created StreamExecutors.
+  ExecutorCache executor_cache_;
+
+  SE_DISALLOW_COPY_AND_ASSIGN(ExecutorPlatform);
+};
+
+}  // namespace executorplugin
+}  // namespace gputools
+}  // namespace perftools
+
+#endif  // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
diff --git a/tensorflow/compiler/plugin/executor/platform_id.h b/tensorflow/compiler/plugin/executor/platform_id.h
new file mode 100644
index 0000000000..8d2b29a3e4
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform_id.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
+#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
+
+#include "tensorflow/stream_executor/platform.h"
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+extern const Platform::Id kExecutorPlatformId;
+
+}  // namespace executorplugin
+}  // namespace gputools
+}  // namespace perftools
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.cc b/tensorflow/compiler/plugin/executor/transfer_manager.cc
new file mode 100644
index 0000000000..51c5deeea5
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/transfer_manager.cc
@@ -0,0 +1,187 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/transfer_manager.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace sep = ::perftools::gputools::executorplugin;
+
+namespace xla {
+namespace executorplugin {
+
+ExecutorTransferManager::ExecutorTransferManager() {}
+
+se::Platform::Id ExecutorTransferManager::PlatformId() const {
+  return se::executorplugin::kExecutorPlatformId;
+}
+
+Status ExecutorTransferManager::TransferLiteralFromDevice(
+    se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+    const Shape& device_shape, const Shape& literal_shape, Literal* literal) {
+  TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape));
+
+  // Tuples are a special case and contain one or more shapes inside of them to
+  // an arbitrary nesting depth.
+  if (device_shape.element_type() == TUPLE) {
+    *literal->mutable_shape() = literal_shape;
+    TF_ASSIGN_OR_RETURN(
+        std::vector<se::DeviceMemoryBase> element_buffers,
+        ShallowCopyTupleFromDevice(executor, source, device_shape));
+    TF_RET_CHECK(element_buffers.size() ==
+                 ShapeUtil::TupleElementCount(device_shape));
+    for (int64 i = 0; i < element_buffers.size(); ++i) {
+      const Shape& element_device_shape = device_shape.tuple_shapes(i);
+      const Shape& element_literal_shape = literal_shape.tuple_shapes(i);
+      Literal* element_literal = literal->add_tuple_literals();
+      // Recursively call TransferFromDevice to copy over the data in the
+      // element array.
+      TF_RETURN_IF_ERROR(TransferLiteralFromDevice(
+          executor, element_buffers[i], element_device_shape,
+          element_literal_shape, element_literal));
+    }
+    return Status::OK();
+  }
+
+  *literal->mutable_shape() = device_shape;
+  literal->Reserve(ShapeUtil::ElementsIn(device_shape));
+  TF_RETURN_IF_ERROR(TransferBufferFromDevice(
+      executor, source, ShapeUtil::ByteSizeOf(device_shape),
+      literal->MutableInternalData()));
+  if (!ShapeUtil::Equal(literal_shape, device_shape)) {
+    literal->Swap(literal->Relayout(literal_shape.layout()).get());
+  }
+  TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape()));
+  return Status::OK();
+}
+
+StatusOr<std::vector<se::DeviceMemoryBase>>
+ExecutorTransferManager::ShallowCopyTupleFromDevice(
+    se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+    const Shape& shape) {
+  TF_RET_CHECK(ShapeUtil::IsTuple(shape));
+
+  std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape),
+                                      nullptr);
+  int64 tuple_size = ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+  auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size,
+                                                    element_pointers.data());
+  if (!copy_status.ok()) {
+    return AddStatus(
+        Status(static_cast<tensorflow::error::Code>(copy_status.code()),
+               copy_status.error_message()),
+        "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape));
+  }
+
+  // Create a DeviceMemoryBase from each void* pointer.
+  std::vector<se::DeviceMemoryBase> destination;
+  for (int i = 0; i < element_pointers.size(); ++i) {
+    if (element_pointers[i] == nullptr &&
+        !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
+      return FailedPrecondition("tuple contains nullptr at element %d", i);
+    }
+    int64 buffer_size =
+        ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), sizeof(void*));
+    destination.emplace_back(element_pointers[i], buffer_size);
+  }
+  return std::move(destination);
+}
+
+Status ExecutorTransferManager::TransferLiteralToDevice(
+    se::StreamExecutor* executor, const Literal& literal,
+    se::DeviceMemoryBase* destination) {
+  const Shape& shape = literal.shape();
+
+  if (ShapeUtil::IsTuple(literal.shape())) {
+    std::vector<void*> tuple_elements_on_device;
+    for (const Literal& tuple_element : literal.tuple_literals()) {
+      se::DeviceMemoryBase allocation = executor->AllocateArray<uint8>(
+          GetByteSizeRequirement(tuple_element.shape()));
+      TF_RETURN_IF_ERROR(
+          TransferLiteralToDevice(executor, tuple_element, &allocation));
+      tuple_elements_on_device.push_back(allocation.opaque());
+    }
+    return TransferBufferToDevice(
+        executor, tuple_elements_on_device.size() * sizeof(void*),
+        tuple_elements_on_device.data(), destination);
+  }
+
+  return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
+                                literal.InternalData(), destination);
+}
+
+Status ExecutorTransferManager::TransferLiteralToInfeed(
+    se::StreamExecutor* executor, const Literal& literal) {
+  const Shape& shape = literal.shape();
+  VLOG(1) << "transferring literal shape to infeed: "
+          << ShapeUtil::HumanString(shape);
+
+  return Status::OK();
+}
+
+Status ExecutorTransferManager::TransferBufferToInfeed(
+    se::StreamExecutor* executor, int64 size, const void* source) {
+  return Unimplemented("Transfer to Infeed");
+}
+
+Status ExecutorTransferManager::TransferLiteralFromOutfeed(
+    perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
+    Literal* literal) {
+  const Shape& shape = literal->shape();
+  VLOG(1) << "transferring literal shape from outfeed: "
+          << ShapeUtil::HumanString(shape);
+
+  return Status::OK();
+}
+
+Status ExecutorTransferManager::ResetDevices(
+    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+        executors) {
+  return Unimplemented("Device reset not supported");
+}
+
+int64 ExecutorTransferManager::GetByteSizeRequirement(const Shape& shape) {
+  return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+}
+
+}  // namespace executorplugin
+}  // namespace xla
+
+static std::unique_ptr<xla::TransferManager> CreateExecutorTransferManager() {
+  return xla::MakeUnique<xla::executorplugin::ExecutorTransferManager>();
+}
+
+static bool InitModule() {
+  xla::TransferManager::RegisterTransferManager(sep::kExecutorPlatformId,
+                                                &CreateExecutorTransferManager);
+  return true;
+}
+static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.h b/tensorflow/compiler/plugin/executor/transfer_manager.h
new file mode 100644
index 0000000000..7a42e5a2d7
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/transfer_manager.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
+
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+#include <vector>
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace executorplugin {
+
+class ExecutorTransferManager : public TransferManager {
+ public:
+  ExecutorTransferManager();
+
+  ~ExecutorTransferManager() override {}
+
+  se::Platform::Id PlatformId() const override;
+
+  StatusOr<std::vector<se::DeviceMemoryBase>> ShallowCopyTupleFromDevice(
+      se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+      const Shape& shape) override;
+
+  Status TransferLiteralFromDevice(se::StreamExecutor* executor,
+                                   const se::DeviceMemoryBase& source,
+                                   const Shape& device_shape,
+                                   const Shape& literal_shape,
+                                   Literal* literal) override;
+
+  Status TransferLiteralToDevice(se::StreamExecutor* executor,
+                                 const Literal& literal,
+                                 se::DeviceMemoryBase* destination) override;
+
+  Status TransferLiteralToInfeed(se::StreamExecutor* executor,
+                                 const Literal& literal) override;
+
+  Status TransferBufferToInfeed(se::StreamExecutor* executor,
+                                int64 size, const void* source) override;
+
+  Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
+                                    const Shape& literal_shape,
+                                    Literal* literal) override;
+
+  Status ResetDevices(
+      tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
+
+  int64 GetByteSizeRequirement(const Shape& shape) override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorTransferManager);
+};
+
+}  // namespace executorplugin
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
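The caching logic in `ExecutorPlatform::GetExecutor` in the diff above follows a common get-or-create pattern: look the config up in the `ExecutorCache`, construct a new executor via `GetUncachedExecutor` only on a miss, insert the owning pointer into the cache, and hand back a non-owning pointer. A minimal standalone sketch of that pattern, using hypothetical names (`PlatformCache`, `FakeExecutor`) that do not appear in the diff:

```cpp
#include <map>
#include <memory>
#include <mutex>

// Hypothetical stand-in for a created executor; the real code caches
// StreamExecutor instances keyed by a StreamExecutorConfig.
struct FakeExecutor {
  int ordinal;
};

// Get-or-create cache mirroring ExecutorPlatform::GetExecutor: consult the
// cache under a lock, and only construct a new executor on a miss.
class PlatformCache {
 public:
  // Returns the cached executor for `ordinal`, creating it on first use.
  // The cache keeps ownership; callers get a non-owning pointer, just as
  // GetExecutor returns a raw StreamExecutor* backed by the ExecutorCache.
  FakeExecutor* GetExecutor(int ordinal) {
    std::lock_guard<std::mutex> lock(mu_);  // guards the ordinal->executor map
    auto it = cache_.find(ordinal);
    if (it != cache_.end()) {
      return it->second.get();  // cache hit: reuse the existing executor
    }
    std::unique_ptr<FakeExecutor> created(new FakeExecutor{ordinal});
    FakeExecutor* naked = created.get();
    cache_.emplace(ordinal, std::move(created));  // cache takes ownership
    ++creations_;
    return naked;
  }

  int creations() const { return creations_; }

 private:
  std::mutex mu_;
  std::map<int, std::unique_ptr<FakeExecutor>> cache_;
  int creations_ = 0;
};
```

Repeated calls with the same ordinal return the same pointer without reconstructing the executor, which is the property the real cache relies on to keep one `StreamExecutor` per device configuration.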