diff options
author | DavidNorman <davidn@graphcore.ai> | 2017-05-19 11:34:37 +0100 |
---|---|---|
committer | Martin Wicke <martin.wicke@gmail.com> | 2017-06-13 18:41:07 -0700 |
commit | 781516d6a4aa9e04cd6d7a9e792ee9f14a5af8a5 (patch) | |
tree | 700e0e847eb3bd547be9901d0c1e97eca1f3756d | |
parent | dbbdf8f0a47c7d4cecb692f05bca6d359f6f0891 (diff) |
Formatting code with clang-format -style=google
-rw-r--r-- | tensorflow/compiler/plugin/example/compiler.cc | 36 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/example/compiler.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/example/device.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/example/executable.cc | 38 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/example/executable.h | 18 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/example/executor.cc | 28 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/example/executor.h | 41 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/example/platform.cc | 11 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/example/transfer_manager.cc | 80 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/example/transfer_manager.h | 52 |
10 files changed, 150 insertions, 170 deletions
diff --git a/tensorflow/compiler/plugin/example/compiler.cc b/tensorflow/compiler/plugin/example/compiler.cc index 127bd8e660..562492be51 100644 --- a/tensorflow/compiler/plugin/example/compiler.cc +++ b/tensorflow/compiler/plugin/example/compiler.cc @@ -31,8 +31,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/stream_executor/lib/strcat.h" #include "tensorflow/stream_executor/lib/initialize.h" +#include "tensorflow/stream_executor/lib/strcat.h" #include "tensorflow/core/lib/core/errors.h" @@ -56,8 +56,8 @@ Status ExampleCompiler::RunHloOptimization(HloModule* hlo_module, pipeline.AddPass<HloSubcomputationUnification>(); pipeline.AddPass<HloCSE>(false); - pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(false, - [](const Shape&, const Shape&) { return false; }); + pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>( + false, [](const Shape&, const Shape&) { return false; }); pipeline.AddPass<ReshapeMover>(); pipeline.AddPass<HloConstantFolding>(); pipeline.AddPass<HloCSE>(true); @@ -68,47 +68,45 @@ Status ExampleCompiler::RunHloOptimization(HloModule* hlo_module, } StatusOr<std::unique_ptr<Executable>> ExampleCompiler::Compile( - std::unique_ptr<HloModule> hlo_module, - std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo, - se::StreamExecutor* stream_exec) { + std::unique_ptr<HloModule> hlo_module, + std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo, + se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); VLOG(1) << "Generate graph " << hlo_module->name(); TF_RETURN_IF_ERROR( - RunHloOptimization(hlo_module.get(), module_config.get(), dump_hlo)); + RunHloOptimization(hlo_module.get(), module_config.get(), dump_hlo)); // Typically you would visit the HLO graph, building up a compiled equivalent - // In this case we are using an Hlo evaluator at execution time, so we don't\ + // In this case we are using an Hlo evaluator at execution time, so we don't // need to compile anything // Create executable from only the Hlo module std::unique_ptr<Executable> executable; executable.reset( - new ExampleExecutable(std::move(hlo_module), - std::move(module_config))); + new ExampleExecutable(std::move(hlo_module), std::move(module_config))); return std::move(executable); } StatusOr<std::vector<std::unique_ptr<Executable>>> ExampleCompiler::Compile( - std::vector<std::unique_ptr<HloModule>> hlo_modules, -std::vector<std::unique_ptr<HloModuleConfig>> module_configs, - HloDumper dump_hlos, std::vector<se::StreamExecutor*> stream_execs) { - + std::vector<std::unique_ptr<HloModule>> hlo_modules, + std::vector<std::unique_ptr<HloModuleConfig>> module_configs, + HloDumper dump_hlos, std::vector<se::StreamExecutor*> stream_execs) { return tensorflow::errors::Unimplemented( - "Compilation of multiple HLO modules is not supported on Example."); + "Compilation of multiple HLO modules is not supported on Example."); } StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> ExampleCompiler::CompileAheadOfTime( - std::vector<std::unique_ptr<HloModule>> hlo_modules, -std::vector<std::unique_ptr<HloModuleConfig>> module_configs, - HloDumper dump_hlo, const AotCompilationOptions& aot_options) { + std::vector<std::unique_ptr<HloModule>> hlo_modules, + std::vector<std::unique_ptr<HloModuleConfig>> module_configs, + HloDumper dump_hlo, const AotCompilationOptions& aot_options) { TF_RET_CHECK(hlo_modules.size() == module_configs.size()); return tensorflow::errors::InvalidArgument( - "AOT compilation not supported on Example"); + "AOT compilation not supported on Example"); } int64 ExampleCompiler::ShapeSizeBytes(const Shape& shape) const { diff --git a/tensorflow/compiler/plugin/example/compiler.h b/tensorflow/compiler/plugin/example/compiler.h index 75d210a89b..f45aa7680e 100644 --- a/tensorflow/compiler/plugin/example/compiler.h +++ b/tensorflow/compiler/plugin/example/compiler.h @@ -18,7 +18,6 @@ limitations under the License. #include <memory> - #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -29,7 +28,6 @@ limitations under the License. namespace xla { namespace exampleplugin { - class ExampleCompiler : public Compiler { public: ExampleCompiler() {} @@ -57,10 +55,8 @@ class ExampleCompiler : public Compiler { perftools::gputools::Platform::Id PlatformId() const override; private: - Status RunHloOptimization(HloModule* hlo_module, - HloModuleConfig* module_config, - HloDumper dump_hlo); + HloModuleConfig* module_config, HloDumper dump_hlo); TF_DISALLOW_COPY_AND_ASSIGN(ExampleCompiler); }; diff --git a/tensorflow/compiler/plugin/example/device.cc b/tensorflow/compiler/plugin/example/device.cc index 6b2d893316..306440c5c3 100644 --- a/tensorflow/compiler/plugin/example/device.cc +++ b/tensorflow/compiler/plugin/example/device.cc @@ -16,9 +16,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" -#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" namespace tensorflow { @@ -26,8 +26,8 @@ namespace tensorflow { const char* const DEVICE_XLA_EXA = "XLA_EXA"; const char* const DEVICE_EXA_XLA_JIT = "XLA_EXA_JIT"; -constexpr std::array<DataType, 5> kExaAllTypes = - {{DT_INT32, DT_FLOAT, DT_BOOL}}; +constexpr std::array<DataType, 5> kExaAllTypes = { + {DT_INT32, DT_FLOAT, DT_BOOL}}; class XlaExaDeviceFactory : public DeviceFactory { public: @@ -54,9 +54,7 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXA, XlaExaDeviceFactory, 210); // Kernel registrations -static bool OpFilter(KernelDef* kdef) { - return true; -} +static bool OpFilter(KernelDef* kdef) { return true; } REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_EXA, XlaDeviceLaunchOp, kExaAllTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_EXA, kExaAllTypes); diff --git a/tensorflow/compiler/plugin/example/executable.cc b/tensorflow/compiler/plugin/example/executable.cc index 2178c2fbc3..6875323722 100644 --- a/tensorflow/compiler/plugin/example/executable.cc +++ b/tensorflow/compiler/plugin/example/executable.cc @@ -28,16 +28,14 @@ namespace xla { namespace exampleplugin { ExampleExecutable::ExampleExecutable( - std::unique_ptr<HloModule> hlo_module, - std::unique_ptr<HloModuleConfig> module_config) - : Executable(std::move(hlo_module), - std::move(module_config)) { -} + std::unique_ptr<HloModule> hlo_module, + std::unique_ptr<HloModuleConfig> module_config) + : Executable(std::move(hlo_module), std::move(module_config)) {} ExampleExecutable::~ExampleExecutable() {} -static se::DeviceMemoryBase -AllocateSingleOutput(sep::ExampleExecutor* executor, Literal* literal) { +static se::DeviceMemoryBase AllocateSingleOutput(sep::ExampleExecutor* executor, + Literal* literal) { int64 size(xla::ShapeUtil::ByteSizeOf(literal->shape())); void* buf = executor->Allocate(size); const void* src = LiteralUtil::InternalData(*literal); @@ -45,17 +43,17 @@ AllocateSingleOutput(sep::ExampleExecutor* executor, Literal* literal) { return se::DeviceMemoryBase(buf, size); } -static se::DeviceMemoryBase -AllocateOutputBuffer(sep::ExampleExecutor* executor, Literal* literal) { +static se::DeviceMemoryBase AllocateOutputBuffer(sep::ExampleExecutor* executor, + Literal* literal) { const Shape& shape = literal->shape(); if (shape.element_type() != xla::TUPLE) { return AllocateSingleOutput(executor, literal); } else { int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*))); void** buf = reinterpret_cast<void**>(executor->Allocate(size)); - for (int64 n=0; n<xla::ShapeUtil::TupleElementCount(shape); n++) { + for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) { se::DeviceMemoryBase out = - AllocateSingleOutput(executor, literal->mutable_tuple_literals(n)); + AllocateSingleOutput(executor, literal->mutable_tuple_literals(n)); *buf++ = out.opaque(); } @@ -63,8 +61,7 @@ AllocateOutputBuffer(sep::ExampleExecutor* executor, Literal* literal) { } } -StatusOr<se::DeviceMemoryBase> -ExampleExecutable::ExecuteOnStream( +StatusOr<se::DeviceMemoryBase> ExampleExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments, HloExecutionProfile* hlo_execution_profile) { @@ -82,7 +79,7 @@ ExampleExecutable::ExecuteOnStream( HloComputation* computation = module().entry_computation(); if (computation->num_parameters() != arguments.size()) { return tensorflow::errors::Internal( - "Mismatch between argument count and graph parameter count."); + "Mismatch between argument count and graph parameter count."); } // Create the arguments as an vector of XLA literals @@ -109,11 +106,10 @@ ExampleExecutable::ExecuteOnStream( // Copy the result into the return buffer perftools::gputools::StreamExecutor* executor(stream->parent()); sep::ExampleExecutor* exampleExecutor( - static_cast<sep::ExampleExecutor*>(executor->implementation())); + static_cast<sep::ExampleExecutor*>(executor->implementation())); se::DeviceMemoryBase ret = - AllocateOutputBuffer(exampleExecutor, output.get()); - + AllocateOutputBuffer(exampleExecutor, output.get()); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -131,17 +127,15 @@ StatusOr<std::unique_ptr<ShapedBuffer>> ExampleExecutable::ExecuteOnStream( tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, HloExecutionProfile* hlo_execution_profile) { return tensorflow::errors::Unimplemented( - "ExecuteOnStream is not yet supported on Example."); + "ExecuteOnStream is not yet supported on Example."); } -StatusOr<se::DeviceMemoryBase> -ExampleExecutable::ExecuteAsyncOnStream( +StatusOr<se::DeviceMemoryBase> ExampleExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) { return tensorflow::errors::Unimplemented( - "ExecuteAsyncOnStream is not yet supported on Example."); + "ExecuteAsyncOnStream is not yet supported on Example."); } } // namespace exampleplugin } // namespace xla - diff --git a/tensorflow/compiler/plugin/example/executable.h b/tensorflow/compiler/plugin/example/executable.h index 0d3dc3f682..99d2da77dd 100644 --- a/tensorflow/compiler/plugin/example/executable.h +++ b/tensorflow/compiler/plugin/example/executable.h @@ -41,25 +41,23 @@ class ExampleExecutable : public Executable { std::unique_ptr<HloModuleConfig> module_config); ~ExampleExecutable() override; - StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> arguments, - HloExecutionProfile* hlo_execution_profile) override; + HloExecutionProfile* hlo_execution_profile) override; StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, - HloExecutionProfile* hlo_execution_profile) override; + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, + HloExecutionProfile* hlo_execution_profile) override; StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> arguments) override; private: - TF_DISALLOW_COPY_AND_ASSIGN(ExampleExecutable); }; diff --git a/tensorflow/compiler/plugin/example/executor.cc b/tensorflow/compiler/plugin/example/executor.cc index bc56fad3ec..79beccf352 100644 --- a/tensorflow/compiler/plugin/example/executor.cc +++ b/tensorflow/compiler/plugin/example/executor.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" -#include <string.h> #include <stdlib.h> +#include <string.h> namespace se = ::perftools::gputools; @@ -33,18 +33,18 @@ host::HostStream *AsExampleStream(Stream *stream) { } ExampleExecutor::ExampleExecutor(const PluginConfig &plugin_config) - : plugin_config_(plugin_config) { -} + : plugin_config_(plugin_config) {} ExampleExecutor::~ExampleExecutor() {} void *ExampleExecutor::Allocate(uint64 size) { - void* buf = new char[size]; + void *buf = new char[size]; return buf; } void *ExampleExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, - uint64 offset_bytes, uint64 size_bytes) { + uint64 offset_bytes, + uint64 size_bytes) { return parent + offset_bytes; } @@ -55,18 +55,18 @@ void ExampleExecutor::Deallocate(DeviceMemoryBase *mem) { } bool ExampleExecutor::Memcpy(Stream *stream, void *host_dst, - const DeviceMemoryBase &dev_src, uint64 size) { - AsExampleStream(stream)->EnqueueTask( - [this, host_dst, dev_src, size]() { - port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); }); + const DeviceMemoryBase &dev_src, uint64 size) { + AsExampleStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { + port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); + }); return true; } bool ExampleExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, - const void *host_src, uint64 size) { - AsExampleStream(stream)->EnqueueTask( - [this, dev_dst, host_src, size]() { - port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); }); + const void *host_src, uint64 size) { + AsExampleStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { + port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); + }); return true; } @@ -85,7 +85,7 @@ port::Status ExampleExecutor::SynchronousMemcpy(void *host_dst, } bool ExampleExecutor::HostCallback(Stream *stream, - std::function<void()> callback) { + std::function<void()> callback) { AsExampleStream(stream)->EnqueueTask(callback); return true; } diff --git a/tensorflow/compiler/plugin/example/executor.h b/tensorflow/compiler/plugin/example/executor.h index 0c2bce1604..e0b7d27bb4 100644 --- a/tensorflow/compiler/plugin/example/executor.h +++ b/tensorflow/compiler/plugin/example/executor.h @@ -78,14 +78,22 @@ class ExampleExecutor : public internal::StreamExecutorInterface { uint64 size) override; bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst, const DeviceMemoryBase &host_src, - uint64 size) override { return false; } + uint64 size) override { + return false; + } bool MemZero(Stream *stream, DeviceMemoryBase *location, - uint64 size) override { return false; } + uint64 size) override { + return false; + } bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern, - uint64 size) override { return false; } + uint64 size) override { + return false; + } bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern, - uint64 size) override { return false; } + uint64 size) override { + return false; + } // No "synchronize all activity" implemented for this platform at the moment. bool SynchronizeAllActivity() override { return false; } @@ -94,7 +102,9 @@ class ExampleExecutor : public internal::StreamExecutorInterface { } bool SynchronousMemSet(DeviceMemoryBase *location, int value, - uint64 size) override { return false; } + uint64 size) override { + return false; + } port::Status SynchronousMemcpy(DeviceMemoryBase *pop_dst, const void *host_src, uint64 size) override; @@ -166,10 +176,14 @@ class ExampleExecutor : public internal::StreamExecutorInterface { } std::unique_ptr<internal::EventInterface> CreateEventImplementation() - override { return nullptr; } + override { + return nullptr; + } std::unique_ptr<internal::KernelInterface> CreateKernelImplementation() - override { return nullptr; } + override { + return nullptr; + } std::unique_ptr<internal::StreamInterface> GetStreamImplementation() override { @@ -180,19 +194,16 @@ class ExampleExecutor : public internal::StreamExecutorInterface { return std::unique_ptr<internal::TimerInterface>(new host::HostTimer()); } - port::StatusOr<DeviceMemoryBase> - ExecuteGraph(const xla::Shape& shape, Args args); + port::StatusOr<DeviceMemoryBase> ExecuteGraph(const xla::Shape &shape, + Args args); private: + DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape); - DeviceMemoryBase - AllocateSingleOutput(const xla::Shape& shape); - - port::StatusOr<DeviceMemoryBase> - AllocateOutputBuffer(const xla::Shape& shape); + port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer( + const xla::Shape &shape); const PluginConfig plugin_config_; - }; } // namespace exampleplugin diff --git a/tensorflow/compiler/plugin/example/platform.cc b/tensorflow/compiler/plugin/example/platform.cc index 477ee3300b..00edb62364 100644 --- a/tensorflow/compiler/plugin/example/platform.cc +++ b/tensorflow/compiler/plugin/example/platform.cc @@ -39,14 +39,12 @@ ExamplePlatform::~ExamplePlatform() {} Platform::Id ExamplePlatform::id() const { return kExamplePlatformId; } -int ExamplePlatform::VisibleDeviceCount() const { - return 1; -} +int ExamplePlatform::VisibleDeviceCount() const { return 1; } const string& ExamplePlatform::Name() const { return name_; } -port::StatusOr<StreamExecutor*> -ExamplePlatform::ExecutorForDevice(int ordinal) { +port::StatusOr<StreamExecutor*> ExamplePlatform::ExecutorForDevice( + int ordinal) { StreamExecutorConfig config; config.ordinal = ordinal; config.plugin_config = PluginConfig(); @@ -119,8 +117,7 @@ static void InitializeExamplePlatform() { } // namespace gputools } // namespace perftools -REGISTER_MODULE_INITIALIZER( - example_platform, sep::InitializeExamplePlatform()); +REGISTER_MODULE_INITIALIZER(example_platform, sep::InitializeExamplePlatform()); DECLARE_MODULE_INITIALIZER(multi_platform_manager); // Note that module initialization sequencing is not supported in the diff --git a/tensorflow/compiler/plugin/example/transfer_manager.cc b/tensorflow/compiler/plugin/example/transfer_manager.cc index 7e6f03571a..cef8a49e19 100644 --- a/tensorflow/compiler/plugin/example/transfer_manager.cc +++ b/tensorflow/compiler/plugin/example/transfer_manager.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/plugin/example/platform_id.h" #include "tensorflow/compiler/plugin/example/transfer_manager.h" +#include "tensorflow/compiler/plugin/example/platform_id.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -34,16 +34,15 @@ limitations under the License. namespace xla { namespace exampleplugin { -ExampleTransferManager::ExampleTransferManager() { -} +ExampleTransferManager::ExampleTransferManager() {} se::Platform::Id ExampleTransferManager::PlatformId() const { return se::exampleplugin::kExamplePlatformId; } Status ExampleTransferManager::TransferLiteralFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& device_shape, const Shape& literal_shape, Literal* literal) { + se::StreamExecutor* executor, const se::DeviceMemoryBase& source, + const Shape& device_shape, const Shape& literal_shape, Literal* literal) { TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape)); // Tuples are a special case and contain one or more shapes inside of them to @@ -51,8 +50,8 @@ Status ExampleTransferManager::TransferLiteralFromDevice( if (device_shape.element_type() == TUPLE) { *literal->mutable_shape() = literal_shape; TF_ASSIGN_OR_RETURN( - std::vector<se::DeviceMemoryBase> element_buffers, - ShallowCopyTupleFromDevice(executor, source, device_shape)); + std::vector<se::DeviceMemoryBase> element_buffers, + ShallowCopyTupleFromDevice(executor, source, device_shape)); TF_RET_CHECK(element_buffers.size() == ShapeUtil::TupleElementCount(device_shape)); for (int64 i = 0; i < element_buffers.size(); ++i) { @@ -62,8 +61,8 @@ Status ExampleTransferManager::TransferLiteralFromDevice( // Recursively call TransferFromDevice to copy over the data in the // element array. TF_RETURN_IF_ERROR(TransferLiteralFromDevice( - executor, element_buffers[i], element_device_shape, - element_literal_shape, element_literal)); + executor, element_buffers[i], element_device_shape, + element_literal_shape, element_literal)); } return Status::OK(); } @@ -71,11 +70,11 @@ Status ExampleTransferManager::TransferLiteralFromDevice( *literal->mutable_shape() = device_shape; LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal); TF_RETURN_IF_ERROR(TransferBufferFromDevice( - executor, source, ShapeUtil::ByteSizeOf(device_shape), - LiteralUtil::MutableInternalData(literal))); + executor, source, ShapeUtil::ByteSizeOf(device_shape), + LiteralUtil::MutableInternalData(literal))); if (!ShapeUtil::Equal(literal_shape, device_shape)) { literal->Swap( - LiteralUtil::Relayout(*literal, literal_shape.layout()).get()); + LiteralUtil::Relayout(*literal, literal_shape.layout()).get()); } TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); return Status::OK(); @@ -83,21 +82,20 @@ Status ExampleTransferManager::TransferLiteralFromDevice( StatusOr<std::vector<se::DeviceMemoryBase>> ExampleTransferManager::ShallowCopyTupleFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& shape) { + se::StreamExecutor* executor, const se::DeviceMemoryBase& source, + const Shape& shape) { TF_RET_CHECK(ShapeUtil::IsTuple(shape)); std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape), nullptr); - int64 tuple_size = - ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + int64 tuple_size = ShapeUtil::ByteSizeOf(shape, sizeof(void*)); auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size, element_pointers.data()); if (!copy_status.ok()) { return AddStatus( - Status(static_cast<tensorflow::error::Code>(copy_status.code()), - copy_status.error_message()), - "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape)); + Status(static_cast<tensorflow::error::Code>(copy_status.code()), + copy_status.error_message()), + "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape)); } // Create a DeviceMemoryBase from each void* pointer. @@ -107,62 +105,58 @@ ExampleTransferManager::ShallowCopyTupleFromDevice( !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { return FailedPrecondition("tuple contains nullptr at element %d", i); } - int64 buffer_size = ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), - sizeof(void*)); + int64 buffer_size = + ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), sizeof(void*)); destination.emplace_back(element_pointers[i], buffer_size); } return std::move(destination); } Status ExampleTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const Literal& literal, - se::DeviceMemoryBase* destination) { + se::StreamExecutor* executor, const Literal& literal, + se::DeviceMemoryBase* destination) { const Shape& shape = literal.shape(); if (ShapeUtil::IsTuple(literal.shape())) { std::vector<void*> tuple_elements_on_device; for (const Literal& tuple_element : literal.tuple_literals()) { se::DeviceMemoryBase allocation = executor->AllocateArray<uint8>( - GetByteSizeRequirement(tuple_element.shape())); + GetByteSizeRequirement(tuple_element.shape())); TF_RETURN_IF_ERROR( - TransferLiteralToDevice(executor, tuple_element, &allocation)); + TransferLiteralToDevice(executor, tuple_element, &allocation)); tuple_elements_on_device.push_back(allocation.opaque()); } return TransferBufferToDevice( - executor, tuple_elements_on_device.size() * sizeof(void*), - tuple_elements_on_device.data(), destination); + executor, tuple_elements_on_device.size() * sizeof(void*), + tuple_elements_on_device.data(), destination); } - return TransferBufferToDevice( - executor, GetByteSizeRequirement(shape), - LiteralUtil::InternalData(literal), destination); + return TransferBufferToDevice(executor, GetByteSizeRequirement(shape), + LiteralUtil::InternalData(literal), + destination); } -Status -ExampleTransferManager::TransferLiteralToInfeed(se::StreamExecutor *executor, - const Literal &literal) { - const Shape &shape = literal.shape(); +Status ExampleTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const Literal& literal) { + const Shape& shape = literal.shape(); VLOG(1) << "transferring literal shape to infeed: " - << ShapeUtil::HumanString(shape); + << ShapeUtil::HumanString(shape); return Status::OK(); } -Status -ExampleTransferManager::TransferLiteralFromOutfeed( - perftools::gputools::StreamExecutor* executor, - const Shape& literal_shape, - Literal* literal) { - const Shape &shape = literal->shape(); +Status ExampleTransferManager::TransferLiteralFromOutfeed( + perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, + Literal* literal) { + const Shape& shape = literal->shape(); VLOG(1) << "transferring literal shape from outfeed: " << ShapeUtil::HumanString(shape); return Status::OK(); } - Status ExampleTransferManager::ResetDevices( - tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> + tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> executors) { return Unimplemented("Device reset not supported"); } diff --git a/tensorflow/compiler/plugin/example/transfer_manager.h b/tensorflow/compiler/plugin/example/transfer_manager.h index ce7d941221..0f7b81e613 100644 --- a/tensorflow/compiler/plugin/example/transfer_manager.h +++ b/tensorflow/compiler/plugin/example/transfer_manager.h @@ -35,46 +35,40 @@ namespace xla { namespace exampleplugin { class ExampleTransferManager : public TransferManager { -public: + public: ExampleTransferManager(); ~ExampleTransferManager() override {} se::Platform::Id PlatformId() const override; - StatusOr<std::vector<se::DeviceMemoryBase>> - ShallowCopyTupleFromDevice( - se::StreamExecutor* executor, - const se::DeviceMemoryBase& source, - const Shape& shape) override; - - Status TransferLiteralFromDevice( - se::StreamExecutor* executor, - const se::DeviceMemoryBase& source, - const Shape& device_shape, - const Shape& literal_shape, - Literal* literal) override; - - Status TransferLiteralToDevice( - se::StreamExecutor* executor, - const Literal& literal, - se::DeviceMemoryBase* destination) override; - - Status - TransferLiteralToInfeed(se::StreamExecutor *executor, - const Literal &literal) override; - - Status TransferLiteralFromOutfeed( - se::StreamExecutor* executor, - const Shape& literal_shape, - Literal* literal) override; + StatusOr<std::vector<se::DeviceMemoryBase>> ShallowCopyTupleFromDevice( + se::StreamExecutor* executor, const se::DeviceMemoryBase& source, + const Shape& shape) override; + + Status TransferLiteralFromDevice(se::StreamExecutor* executor, + const se::DeviceMemoryBase& source, + const Shape& device_shape, + const Shape& literal_shape, + Literal* literal) override; + + Status TransferLiteralToDevice(se::StreamExecutor* executor, + const Literal& literal, + se::DeviceMemoryBase* destination) override; + + Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const Literal& literal) override; + + Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, + const Shape& literal_shape, + Literal* literal) override; Status ResetDevices( - tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override; + tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override; int64 GetByteSizeRequirement(const Shape& shape) override; -private: + private: TF_DISALLOW_COPY_AND_ASSIGN(ExampleTransferManager); }; |