author    DavidNorman <davidn@graphcore.ai>  2017-05-19 11:34:37 +0100
committer Martin Wicke <martin.wicke@gmail.com>  2017-06-13 18:41:07 -0700
commit    781516d6a4aa9e04cd6d7a9e792ee9f14a5af8a5 (patch)
tree      700e0e847eb3bd547be9901d0c1e97eca1f3756d
parent    dbbdf8f0a47c7d4cecb692f05bca6d359f6f0891 (diff)
Formatting code with clang-format -style=google
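
A minimal sketch of the invocation behind this commit, assuming clang-format
is on the PATH and the command is run from the repository root (the exact
file list is illustrative; only the -style=google flag is recorded in the
commit title):

    # Rewrite the plugin sources in place using Google C++ style.
    clang-format -i -style=google \
        tensorflow/compiler/plugin/example/*.cc \
        tensorflow/compiler/plugin/example/*.h

The -i flag makes clang-format edit each file in place rather than printing
the reformatted result to stdout.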
 tensorflow/compiler/plugin/example/compiler.cc         | 36
 tensorflow/compiler/plugin/example/compiler.h          |  6
 tensorflow/compiler/plugin/example/device.cc           | 10
 tensorflow/compiler/plugin/example/executable.cc       | 38
 tensorflow/compiler/plugin/example/executable.h        | 18
 tensorflow/compiler/plugin/example/executor.cc         | 28
 tensorflow/compiler/plugin/example/executor.h          | 41
 tensorflow/compiler/plugin/example/platform.cc         | 11
 tensorflow/compiler/plugin/example/transfer_manager.cc | 80
 tensorflow/compiler/plugin/example/transfer_manager.h  | 52
 10 files changed, 150 insertions(+), 170 deletions(-)
diff --git a/tensorflow/compiler/plugin/example/compiler.cc b/tensorflow/compiler/plugin/example/compiler.cc
index 127bd8e660..562492be51 100644
--- a/tensorflow/compiler/plugin/example/compiler.cc
+++ b/tensorflow/compiler/plugin/example/compiler.cc
@@ -31,8 +31,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -56,8 +56,8 @@ Status ExampleCompiler::RunHloOptimization(HloModule* hlo_module,
pipeline.AddPass<HloSubcomputationUnification>();
pipeline.AddPass<HloCSE>(false);
- pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(false,
- [](const Shape&, const Shape&) { return false; });
+ pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
+ false, [](const Shape&, const Shape&) { return false; });
pipeline.AddPass<ReshapeMover>();
pipeline.AddPass<HloConstantFolding>();
pipeline.AddPass<HloCSE>(true);
@@ -68,47 +68,45 @@ Status ExampleCompiler::RunHloOptimization(HloModule* hlo_module,
}
StatusOr<std::unique_ptr<Executable>> ExampleCompiler::Compile(
- std::unique_ptr<HloModule> hlo_module,
- std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
- se::StreamExecutor* stream_exec) {
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ se::StreamExecutor* stream_exec) {
TF_RET_CHECK(stream_exec != nullptr);
VLOG(1) << "Generate graph " << hlo_module->name();
TF_RETURN_IF_ERROR(
- RunHloOptimization(hlo_module.get(), module_config.get(), dump_hlo));
+ RunHloOptimization(hlo_module.get(), module_config.get(), dump_hlo));
// Typically you would visit the HLO graph, building up a compiled equivalent
- // In this case we are using an Hlo evaluator at execution time, so we don't\
+ // In this case we are using an Hlo evaluator at execution time, so we don't
// need to compile anything
// Create executable from only the Hlo module
std::unique_ptr<Executable> executable;
executable.reset(
- new ExampleExecutable(std::move(hlo_module),
- std::move(module_config)));
+ new ExampleExecutable(std::move(hlo_module), std::move(module_config)));
return std::move(executable);
}
StatusOr<std::vector<std::unique_ptr<Executable>>> ExampleCompiler::Compile(
- std::vector<std::unique_ptr<HloModule>> hlo_modules,
-std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
- HloDumper dump_hlos, std::vector<se::StreamExecutor*> stream_execs) {
-
+ std::vector<std::unique_ptr<HloModule>> hlo_modules,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+ HloDumper dump_hlos, std::vector<se::StreamExecutor*> stream_execs) {
return tensorflow::errors::Unimplemented(
- "Compilation of multiple HLO modules is not supported on Example.");
+ "Compilation of multiple HLO modules is not supported on Example.");
}
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
ExampleCompiler::CompileAheadOfTime(
- std::vector<std::unique_ptr<HloModule>> hlo_modules,
-std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
- HloDumper dump_hlo, const AotCompilationOptions& aot_options) {
+ std::vector<std::unique_ptr<HloModule>> hlo_modules,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+ HloDumper dump_hlo, const AotCompilationOptions& aot_options) {
TF_RET_CHECK(hlo_modules.size() == module_configs.size());
return tensorflow::errors::InvalidArgument(
- "AOT compilation not supported on Example");
+ "AOT compilation not supported on Example");
}
int64 ExampleCompiler::ShapeSizeBytes(const Shape& shape) const {
diff --git a/tensorflow/compiler/plugin/example/compiler.h b/tensorflow/compiler/plugin/example/compiler.h
index 75d210a89b..f45aa7680e 100644
--- a/tensorflow/compiler/plugin/example/compiler.h
+++ b/tensorflow/compiler/plugin/example/compiler.h
@@ -18,7 +18,6 @@ limitations under the License.
#include <memory>
-
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -29,7 +28,6 @@ limitations under the License.
namespace xla {
namespace exampleplugin {
-
class ExampleCompiler : public Compiler {
public:
ExampleCompiler() {}
@@ -57,10 +55,8 @@ class ExampleCompiler : public Compiler {
perftools::gputools::Platform::Id PlatformId() const override;
private:
-
Status RunHloOptimization(HloModule* hlo_module,
- HloModuleConfig* module_config,
- HloDumper dump_hlo);
+ HloModuleConfig* module_config, HloDumper dump_hlo);
TF_DISALLOW_COPY_AND_ASSIGN(ExampleCompiler);
};
diff --git a/tensorflow/compiler/plugin/example/device.cc b/tensorflow/compiler/plugin/example/device.cc
index 6b2d893316..306440c5c3 100644
--- a/tensorflow/compiler/plugin/example/device.cc
+++ b/tensorflow/compiler/plugin/example/device.cc
@@ -16,9 +16,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
-#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
namespace tensorflow {
@@ -26,8 +26,8 @@ namespace tensorflow {
const char* const DEVICE_XLA_EXA = "XLA_EXA";
const char* const DEVICE_EXA_XLA_JIT = "XLA_EXA_JIT";
-constexpr std::array<DataType, 5> kExaAllTypes =
- {{DT_INT32, DT_FLOAT, DT_BOOL}};
+constexpr std::array<DataType, 5> kExaAllTypes = {
+ {DT_INT32, DT_FLOAT, DT_BOOL}};
class XlaExaDeviceFactory : public DeviceFactory {
public:
@@ -54,9 +54,7 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXA, XlaExaDeviceFactory, 210);
// Kernel registrations
-static bool OpFilter(KernelDef* kdef) {
- return true;
-}
+static bool OpFilter(KernelDef* kdef) { return true; }
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_EXA, XlaDeviceLaunchOp, kExaAllTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_EXA, kExaAllTypes);
diff --git a/tensorflow/compiler/plugin/example/executable.cc b/tensorflow/compiler/plugin/example/executable.cc
index 2178c2fbc3..6875323722 100644
--- a/tensorflow/compiler/plugin/example/executable.cc
+++ b/tensorflow/compiler/plugin/example/executable.cc
@@ -28,16 +28,14 @@ namespace xla {
namespace exampleplugin {
ExampleExecutable::ExampleExecutable(
- std::unique_ptr<HloModule> hlo_module,
- std::unique_ptr<HloModuleConfig> module_config)
- : Executable(std::move(hlo_module),
- std::move(module_config)) {
-}
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config)
+ : Executable(std::move(hlo_module), std::move(module_config)) {}
ExampleExecutable::~ExampleExecutable() {}
-static se::DeviceMemoryBase
-AllocateSingleOutput(sep::ExampleExecutor* executor, Literal* literal) {
+static se::DeviceMemoryBase AllocateSingleOutput(sep::ExampleExecutor* executor,
+ Literal* literal) {
int64 size(xla::ShapeUtil::ByteSizeOf(literal->shape()));
void* buf = executor->Allocate(size);
const void* src = LiteralUtil::InternalData(*literal);
@@ -45,17 +43,17 @@ AllocateSingleOutput(sep::ExampleExecutor* executor, Literal* literal) {
return se::DeviceMemoryBase(buf, size);
}
-static se::DeviceMemoryBase
-AllocateOutputBuffer(sep::ExampleExecutor* executor, Literal* literal) {
+static se::DeviceMemoryBase AllocateOutputBuffer(sep::ExampleExecutor* executor,
+ Literal* literal) {
const Shape& shape = literal->shape();
if (shape.element_type() != xla::TUPLE) {
return AllocateSingleOutput(executor, literal);
} else {
int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*)));
void** buf = reinterpret_cast<void**>(executor->Allocate(size));
- for (int64 n=0; n<xla::ShapeUtil::TupleElementCount(shape); n++) {
+ for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) {
se::DeviceMemoryBase out =
- AllocateSingleOutput(executor, literal->mutable_tuple_literals(n));
+ AllocateSingleOutput(executor, literal->mutable_tuple_literals(n));
*buf++ = out.opaque();
}
@@ -63,8 +61,7 @@ AllocateOutputBuffer(sep::ExampleExecutor* executor, Literal* literal) {
}
}
-StatusOr<se::DeviceMemoryBase>
-ExampleExecutable::ExecuteOnStream(
+StatusOr<se::DeviceMemoryBase> ExampleExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
HloExecutionProfile* hlo_execution_profile) {
@@ -82,7 +79,7 @@ ExampleExecutable::ExecuteOnStream(
HloComputation* computation = module().entry_computation();
if (computation->num_parameters() != arguments.size()) {
return tensorflow::errors::Internal(
- "Mismatch between argument count and graph parameter count.");
+ "Mismatch between argument count and graph parameter count.");
}
// Create the arguments as an vector of XLA literals
@@ -109,11 +106,10 @@ ExampleExecutable::ExecuteOnStream(
// Copy the result into the return buffer
perftools::gputools::StreamExecutor* executor(stream->parent());
sep::ExampleExecutor* exampleExecutor(
- static_cast<sep::ExampleExecutor*>(executor->implementation()));
+ static_cast<sep::ExampleExecutor*>(executor->implementation()));
se::DeviceMemoryBase ret =
- AllocateOutputBuffer(exampleExecutor, output.get());
-
+ AllocateOutputBuffer(exampleExecutor, output.get());
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -131,17 +127,15 @@ StatusOr<std::unique_ptr<ShapedBuffer>> ExampleExecutable::ExecuteOnStream(
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) {
return tensorflow::errors::Unimplemented(
- "ExecuteOnStream is not yet supported on Example.");
+ "ExecuteOnStream is not yet supported on Example.");
}
-StatusOr<se::DeviceMemoryBase>
-ExampleExecutable::ExecuteAsyncOnStream(
+StatusOr<se::DeviceMemoryBase> ExampleExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
return tensorflow::errors::Unimplemented(
- "ExecuteAsyncOnStream is not yet supported on Example.");
+ "ExecuteAsyncOnStream is not yet supported on Example.");
}
} // namespace exampleplugin
} // namespace xla
-
diff --git a/tensorflow/compiler/plugin/example/executable.h b/tensorflow/compiler/plugin/example/executable.h
index 0d3dc3f682..99d2da77dd 100644
--- a/tensorflow/compiler/plugin/example/executable.h
+++ b/tensorflow/compiler/plugin/example/executable.h
@@ -41,25 +41,23 @@ class ExampleExecutable : public Executable {
std::unique_ptr<HloModuleConfig> module_config);
~ExampleExecutable() override;
-
StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
- HloExecutionProfile* hlo_execution_profile) override;
+ HloExecutionProfile* hlo_execution_profile) override;
StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- HloExecutionProfile* hlo_execution_profile) override;
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments) override;
private:
-
TF_DISALLOW_COPY_AND_ASSIGN(ExampleExecutable);
};
diff --git a/tensorflow/compiler/plugin/example/executor.cc b/tensorflow/compiler/plugin/example/executor.cc
index bc56fad3ec..79beccf352 100644
--- a/tensorflow/compiler/plugin/example/executor.cc
+++ b/tensorflow/compiler/plugin/example/executor.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
-#include <string.h>
#include <stdlib.h>
+#include <string.h>
namespace se = ::perftools::gputools;
@@ -33,18 +33,18 @@ host::HostStream *AsExampleStream(Stream *stream) {
}
ExampleExecutor::ExampleExecutor(const PluginConfig &plugin_config)
- : plugin_config_(plugin_config) {
-}
+ : plugin_config_(plugin_config) {}
ExampleExecutor::~ExampleExecutor() {}
void *ExampleExecutor::Allocate(uint64 size) {
- void* buf = new char[size];
+ void *buf = new char[size];
return buf;
}
void *ExampleExecutor::AllocateSubBuffer(DeviceMemoryBase *parent,
- uint64 offset_bytes, uint64 size_bytes) {
+ uint64 offset_bytes,
+ uint64 size_bytes) {
return parent + offset_bytes;
}
@@ -55,18 +55,18 @@ void ExampleExecutor::Deallocate(DeviceMemoryBase *mem) {
}
bool ExampleExecutor::Memcpy(Stream *stream, void *host_dst,
- const DeviceMemoryBase &dev_src, uint64 size) {
- AsExampleStream(stream)->EnqueueTask(
- [this, host_dst, dev_src, size]() {
- port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); });
+ const DeviceMemoryBase &dev_src, uint64 size) {
+ AsExampleStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() {
+ port::Status ok = SynchronousMemcpy(host_dst, dev_src, size);
+ });
return true;
}
bool ExampleExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
- const void *host_src, uint64 size) {
- AsExampleStream(stream)->EnqueueTask(
- [this, dev_dst, host_src, size]() {
- port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); });
+ const void *host_src, uint64 size) {
+ AsExampleStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() {
+ port::Status ok = SynchronousMemcpy(dev_dst, host_src, size);
+ });
return true;
}
@@ -85,7 +85,7 @@ port::Status ExampleExecutor::SynchronousMemcpy(void *host_dst,
}
bool ExampleExecutor::HostCallback(Stream *stream,
- std::function<void()> callback) {
+ std::function<void()> callback) {
AsExampleStream(stream)->EnqueueTask(callback);
return true;
}
diff --git a/tensorflow/compiler/plugin/example/executor.h b/tensorflow/compiler/plugin/example/executor.h
index 0c2bce1604..e0b7d27bb4 100644
--- a/tensorflow/compiler/plugin/example/executor.h
+++ b/tensorflow/compiler/plugin/example/executor.h
@@ -78,14 +78,22 @@ class ExampleExecutor : public internal::StreamExecutorInterface {
uint64 size) override;
bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst,
const DeviceMemoryBase &host_src,
- uint64 size) override { return false; }
+ uint64 size) override {
+ return false;
+ }
bool MemZero(Stream *stream, DeviceMemoryBase *location,
- uint64 size) override { return false; }
+ uint64 size) override {
+ return false;
+ }
bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
- uint64 size) override { return false; }
+ uint64 size) override {
+ return false;
+ }
bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
- uint64 size) override { return false; }
+ uint64 size) override {
+ return false;
+ }
// No "synchronize all activity" implemented for this platform at the moment.
bool SynchronizeAllActivity() override { return false; }
@@ -94,7 +102,9 @@ class ExampleExecutor : public internal::StreamExecutorInterface {
}
bool SynchronousMemSet(DeviceMemoryBase *location, int value,
- uint64 size) override { return false; }
+ uint64 size) override {
+ return false;
+ }
port::Status SynchronousMemcpy(DeviceMemoryBase *pop_dst,
const void *host_src, uint64 size) override;
@@ -166,10 +176,14 @@ class ExampleExecutor : public internal::StreamExecutorInterface {
}
std::unique_ptr<internal::EventInterface> CreateEventImplementation()
- override { return nullptr; }
+ override {
+ return nullptr;
+ }
std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
- override { return nullptr; }
+ override {
+ return nullptr;
+ }
std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
override {
@@ -180,19 +194,16 @@ class ExampleExecutor : public internal::StreamExecutorInterface {
return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
}
- port::StatusOr<DeviceMemoryBase>
- ExecuteGraph(const xla::Shape& shape, Args args);
+ port::StatusOr<DeviceMemoryBase> ExecuteGraph(const xla::Shape &shape,
+ Args args);
private:
+ DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);
- DeviceMemoryBase
- AllocateSingleOutput(const xla::Shape& shape);
-
- port::StatusOr<DeviceMemoryBase>
- AllocateOutputBuffer(const xla::Shape& shape);
+ port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(
+ const xla::Shape &shape);
const PluginConfig plugin_config_;
-
};
} // namespace exampleplugin
diff --git a/tensorflow/compiler/plugin/example/platform.cc b/tensorflow/compiler/plugin/example/platform.cc
index 477ee3300b..00edb62364 100644
--- a/tensorflow/compiler/plugin/example/platform.cc
+++ b/tensorflow/compiler/plugin/example/platform.cc
@@ -39,14 +39,12 @@ ExamplePlatform::~ExamplePlatform() {}
Platform::Id ExamplePlatform::id() const { return kExamplePlatformId; }
-int ExamplePlatform::VisibleDeviceCount() const {
- return 1;
-}
+int ExamplePlatform::VisibleDeviceCount() const { return 1; }
const string& ExamplePlatform::Name() const { return name_; }
-port::StatusOr<StreamExecutor*>
-ExamplePlatform::ExecutorForDevice(int ordinal) {
+port::StatusOr<StreamExecutor*> ExamplePlatform::ExecutorForDevice(
+ int ordinal) {
StreamExecutorConfig config;
config.ordinal = ordinal;
config.plugin_config = PluginConfig();
@@ -119,8 +117,7 @@ static void InitializeExamplePlatform() {
} // namespace gputools
} // namespace perftools
-REGISTER_MODULE_INITIALIZER(
- example_platform, sep::InitializeExamplePlatform());
+REGISTER_MODULE_INITIALIZER(example_platform, sep::InitializeExamplePlatform());
DECLARE_MODULE_INITIALIZER(multi_platform_manager);
// Note that module initialization sequencing is not supported in the
diff --git a/tensorflow/compiler/plugin/example/transfer_manager.cc b/tensorflow/compiler/plugin/example/transfer_manager.cc
index 7e6f03571a..cef8a49e19 100644
--- a/tensorflow/compiler/plugin/example/transfer_manager.cc
+++ b/tensorflow/compiler/plugin/example/transfer_manager.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/plugin/example/platform_id.h"
#include "tensorflow/compiler/plugin/example/transfer_manager.h"
+#include "tensorflow/compiler/plugin/example/platform_id.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -34,16 +34,15 @@ limitations under the License.
namespace xla {
namespace exampleplugin {
-ExampleTransferManager::ExampleTransferManager() {
-}
+ExampleTransferManager::ExampleTransferManager() {}
se::Platform::Id ExampleTransferManager::PlatformId() const {
return se::exampleplugin::kExamplePlatformId;
}
Status ExampleTransferManager::TransferLiteralFromDevice(
- se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
- const Shape& device_shape, const Shape& literal_shape, Literal* literal) {
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& device_shape, const Shape& literal_shape, Literal* literal) {
TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape));
// Tuples are a special case and contain one or more shapes inside of them to
@@ -51,8 +50,8 @@ Status ExampleTransferManager::TransferLiteralFromDevice(
if (device_shape.element_type() == TUPLE) {
*literal->mutable_shape() = literal_shape;
TF_ASSIGN_OR_RETURN(
- std::vector<se::DeviceMemoryBase> element_buffers,
- ShallowCopyTupleFromDevice(executor, source, device_shape));
+ std::vector<se::DeviceMemoryBase> element_buffers,
+ ShallowCopyTupleFromDevice(executor, source, device_shape));
TF_RET_CHECK(element_buffers.size() ==
ShapeUtil::TupleElementCount(device_shape));
for (int64 i = 0; i < element_buffers.size(); ++i) {
@@ -62,8 +61,8 @@ Status ExampleTransferManager::TransferLiteralFromDevice(
// Recursively call TransferFromDevice to copy over the data in the
// element array.
TF_RETURN_IF_ERROR(TransferLiteralFromDevice(
- executor, element_buffers[i], element_device_shape,
- element_literal_shape, element_literal));
+ executor, element_buffers[i], element_device_shape,
+ element_literal_shape, element_literal));
}
return Status::OK();
}
@@ -71,11 +70,11 @@ Status ExampleTransferManager::TransferLiteralFromDevice(
*literal->mutable_shape() = device_shape;
LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal);
TF_RETURN_IF_ERROR(TransferBufferFromDevice(
- executor, source, ShapeUtil::ByteSizeOf(device_shape),
- LiteralUtil::MutableInternalData(literal)));
+ executor, source, ShapeUtil::ByteSizeOf(device_shape),
+ LiteralUtil::MutableInternalData(literal)));
if (!ShapeUtil::Equal(literal_shape, device_shape)) {
literal->Swap(
- LiteralUtil::Relayout(*literal, literal_shape.layout()).get());
+ LiteralUtil::Relayout(*literal, literal_shape.layout()).get());
}
TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape()));
return Status::OK();
@@ -83,21 +82,20 @@ Status ExampleTransferManager::TransferLiteralFromDevice(
StatusOr<std::vector<se::DeviceMemoryBase>>
ExampleTransferManager::ShallowCopyTupleFromDevice(
- se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
- const Shape& shape) {
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& shape) {
TF_RET_CHECK(ShapeUtil::IsTuple(shape));
std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape),
nullptr);
- int64 tuple_size =
- ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+ int64 tuple_size = ShapeUtil::ByteSizeOf(shape, sizeof(void*));
auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size,
element_pointers.data());
if (!copy_status.ok()) {
return AddStatus(
- Status(static_cast<tensorflow::error::Code>(copy_status.code()),
- copy_status.error_message()),
- "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape));
+ Status(static_cast<tensorflow::error::Code>(copy_status.code()),
+ copy_status.error_message()),
+ "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape));
}
// Create a DeviceMemoryBase from each void* pointer.
@@ -107,62 +105,58 @@ ExampleTransferManager::ShallowCopyTupleFromDevice(
!ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
return FailedPrecondition("tuple contains nullptr at element %d", i);
}
- int64 buffer_size = ShapeUtil::ByteSizeOf(shape.tuple_shapes(i),
- sizeof(void*));
+ int64 buffer_size =
+ ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), sizeof(void*));
destination.emplace_back(element_pointers[i], buffer_size);
}
return std::move(destination);
}
Status ExampleTransferManager::TransferLiteralToDevice(
- se::StreamExecutor* executor, const Literal& literal,
- se::DeviceMemoryBase* destination) {
+ se::StreamExecutor* executor, const Literal& literal,
+ se::DeviceMemoryBase* destination) {
const Shape& shape = literal.shape();
if (ShapeUtil::IsTuple(literal.shape())) {
std::vector<void*> tuple_elements_on_device;
for (const Literal& tuple_element : literal.tuple_literals()) {
se::DeviceMemoryBase allocation = executor->AllocateArray<uint8>(
- GetByteSizeRequirement(tuple_element.shape()));
+ GetByteSizeRequirement(tuple_element.shape()));
TF_RETURN_IF_ERROR(
- TransferLiteralToDevice(executor, tuple_element, &allocation));
+ TransferLiteralToDevice(executor, tuple_element, &allocation));
tuple_elements_on_device.push_back(allocation.opaque());
}
return TransferBufferToDevice(
- executor, tuple_elements_on_device.size() * sizeof(void*),
- tuple_elements_on_device.data(), destination);
+ executor, tuple_elements_on_device.size() * sizeof(void*),
+ tuple_elements_on_device.data(), destination);
}
- return TransferBufferToDevice(
- executor, GetByteSizeRequirement(shape),
- LiteralUtil::InternalData(literal), destination);
+ return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
+ LiteralUtil::InternalData(literal),
+ destination);
}
-Status
-ExampleTransferManager::TransferLiteralToInfeed(se::StreamExecutor *executor,
- const Literal &literal) {
- const Shape &shape = literal.shape();
+Status ExampleTransferManager::TransferLiteralToInfeed(
+ se::StreamExecutor* executor, const Literal& literal) {
+ const Shape& shape = literal.shape();
VLOG(1) << "transferring literal shape to infeed: "
- << ShapeUtil::HumanString(shape);
+ << ShapeUtil::HumanString(shape);
return Status::OK();
}
-Status
-ExampleTransferManager::TransferLiteralFromOutfeed(
- perftools::gputools::StreamExecutor* executor,
- const Shape& literal_shape,
- Literal* literal) {
- const Shape &shape = literal->shape();
+Status ExampleTransferManager::TransferLiteralFromOutfeed(
+ perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
+ Literal* literal) {
+ const Shape& shape = literal->shape();
VLOG(1) << "transferring literal shape from outfeed: "
<< ShapeUtil::HumanString(shape);
return Status::OK();
}
-
Status ExampleTransferManager::ResetDevices(
- tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
executors) {
return Unimplemented("Device reset not supported");
}
diff --git a/tensorflow/compiler/plugin/example/transfer_manager.h b/tensorflow/compiler/plugin/example/transfer_manager.h
index ce7d941221..0f7b81e613 100644
--- a/tensorflow/compiler/plugin/example/transfer_manager.h
+++ b/tensorflow/compiler/plugin/example/transfer_manager.h
@@ -35,46 +35,40 @@ namespace xla {
namespace exampleplugin {
class ExampleTransferManager : public TransferManager {
-public:
+ public:
ExampleTransferManager();
~ExampleTransferManager() override {}
se::Platform::Id PlatformId() const override;
- StatusOr<std::vector<se::DeviceMemoryBase>>
- ShallowCopyTupleFromDevice(
- se::StreamExecutor* executor,
- const se::DeviceMemoryBase& source,
- const Shape& shape) override;
-
- Status TransferLiteralFromDevice(
- se::StreamExecutor* executor,
- const se::DeviceMemoryBase& source,
- const Shape& device_shape,
- const Shape& literal_shape,
- Literal* literal) override;
-
- Status TransferLiteralToDevice(
- se::StreamExecutor* executor,
- const Literal& literal,
- se::DeviceMemoryBase* destination) override;
-
- Status
- TransferLiteralToInfeed(se::StreamExecutor *executor,
- const Literal &literal) override;
-
- Status TransferLiteralFromOutfeed(
- se::StreamExecutor* executor,
- const Shape& literal_shape,
- Literal* literal) override;
+ StatusOr<std::vector<se::DeviceMemoryBase>> ShallowCopyTupleFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& shape) override;
+
+ Status TransferLiteralFromDevice(se::StreamExecutor* executor,
+ const se::DeviceMemoryBase& source,
+ const Shape& device_shape,
+ const Shape& literal_shape,
+ Literal* literal) override;
+
+ Status TransferLiteralToDevice(se::StreamExecutor* executor,
+ const Literal& literal,
+ se::DeviceMemoryBase* destination) override;
+
+ Status TransferLiteralToInfeed(se::StreamExecutor* executor,
+ const Literal& literal) override;
+
+ Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
+ const Shape& literal_shape,
+ Literal* literal) override;
Status ResetDevices(
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
+ tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
int64 GetByteSizeRequirement(const Shape& shape) override;
-private:
+ private:
TF_DISALLOW_COPY_AND_ASSIGN(ExampleTransferManager);
};