author     Mark Heffernan <meheff@google.com>          2017-11-16 14:03:51 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-11-16 14:09:31 -0800
commit     22d948d2739ecaadfb4091302f2050ba9cf0d0c1 (patch)
tree       e1d1568d8456b39c944e10dfd6d6c7da056c6821 /tensorflow/compiler/xla/service
parent     e2a60582bf28fa29c871736d10edad06e660776d (diff)
Add methods on TransferManager which transfer to/from device memory specified by ShapedBuffer rather than DeviceMemoryBase.

This is part of a broader replacement of DeviceMemoryBase with ShapedBuffer in several XLA interfaces. With this change, TransferManager no longer has to allocate memory to transfer tuples to the device. The existing methods using DeviceMemoryBase will be removed in a follow-up CL.
Various related changes:
* Make transfer_manager_test an xla_test so that it runs on all platforms.
* Make several of the TransferManager methods protected.
* Change ScopedShapedBuffer::Allocate to only allocate device memory buffers and not fill in the tuple index table. The index table is now filled in by the transfer manager, which is a cleaner separation of concerns.

PiperOrigin-RevId: 176015628
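The caller-side flow this change enables can be sketched as follows. This is a minimal illustration, not code from the commit: the function name RoundTripThroughDevice and the way the allocator and executor are obtained are assumptions made for the sketch, while the ScopedShapedBuffer::Allocate and TransferManager calls are the ones added or modified in this CL.

#include <functional>
#include <memory>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// Round-trips a literal through device memory using the new
// ShapedBuffer-based TransferManager interface (illustrative only).
StatusOr<std::unique_ptr<Literal>> RoundTripThroughDevice(
    const Literal& literal, DeviceMemoryAllocator* allocator,
    perftools::gputools::StreamExecutor* executor) {
  TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager,
                      TransferManager::GetForPlatform(allocator->platform()));

  // Allocate only the device buffers; each buffer's size comes from the
  // transfer manager. The tuple index tables are no longer written here.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ScopedShapedBuffer> device_buffer,
      ScopedShapedBuffer::Allocate(
          literal.shape(), allocator, executor->device_ordinal(),
          [transfer_manager](const Shape& shape) {
            return transfer_manager->GetByteSizeRequirement(shape);
          }));

  // TransferLiteralToDevice writes the tuple index tables itself and then
  // copies each array-shaped sub-literal into its buffer.
  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
      executor, literal, *device_buffer));

  // Read the whole ShapedBuffer back into a freshly created literal.
  return transfer_manager->TransferLiteralFromDevice(executor, *device_buffer);
}

}  // namespace xla

Compared with the DeviceMemoryBase-based methods, the caller never has to compute tuple buffer sizes or write pointer tables by hand; both are owned by the transfer manager.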
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                          20
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.cc    87
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.h     13
-rw-r--r--  tensorflow/compiler/xla/service/shaped_buffer.cc               79
-rw-r--r--  tensorflow/compiler/xla/service/shaped_buffer.h                 9
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager.cc            33
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager.h             76
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager_test.cc      161
8 files changed, 215 insertions, 263 deletions
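The new division of labor (Allocate only allocates; WriteTupleIndexTables writes the pointer tables) can be illustrated with a small host-side sketch. Assuming the generic transfer manager's tuple representation, each tuple-shaped buffer is a flat array of pointers to its immediate element buffers, which is why its byte size requirement is element count times pointer size. The shape and addresses below are made up for illustration.

#include <cstdio>
#include <vector>

int main() {
  // Device buffers for the shape (f32[4], (s32[2], f32[1])), one per shape
  // index, as ScopedShapedBuffer::Allocate now produces them: allocation
  // only, with no tuple index tables written yet. Addresses are fake.
  void* buf_root = reinterpret_cast<void*>(0x1000);  // index {}:    tuple
  void* buf_0    = reinterpret_cast<void*>(0x2000);  // index {0}:   f32[4]
  void* buf_1    = reinterpret_cast<void*>(0x3000);  // index {1}:   tuple
  void* buf_1_0  = reinterpret_cast<void*>(0x4000);  // index {1,0}: s32[2]
  void* buf_1_1  = reinterpret_cast<void*>(0x5000);  // index {1,1}: f32[1]

  // WriteTupleIndexTables later fills each tuple-shaped buffer with the
  // addresses of its immediate elements, one table per tuple subshape.
  std::vector<void*> root_table  = {buf_0, buf_1};      // written into buf_root
  std::vector<void*> inner_table = {buf_1_0, buf_1_1};  // written into buf_1

  std::printf("table at %p holds %zu pointers; table at %p holds %zu\n",
              buf_root, root_table.size(), buf_1, inner_table.size());
  return 0;
}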
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 4ff8302568..7bb4479ce0 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -568,7 +568,6 @@ cc_library(
hdrs = ["shaped_buffer.h"],
deps = [
":device_memory_allocator",
- ":transfer_manager",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -645,6 +644,7 @@ cc_library(
srcs = ["transfer_manager.cc"],
hdrs = ["transfer_manager.h"],
deps = [
+ ":shaped_buffer",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1294,24 +1294,6 @@ cc_library(
alwayslink = True, # Contains per-platform transfer manager registration
)
-tf_cc_test(
- name = "transfer_manager_test",
- srcs = ["transfer_manager_test.cc"],
- deps = [
- ":generic_transfer_manager",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service/cpu:cpu_transfer_manager",
- "//tensorflow/compiler/xla/tests:literal_test_util",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
- "//tensorflow/core:stream_executor_no_cuda",
- ],
-)
-
cc_library(
name = "hlo_cost_analysis",
srcs = ["hlo_cost_analysis.cc"],
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index b4fbed1562..74aa77b4f1 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -103,8 +104,7 @@ GenericTransferManager::ShallowCopyTupleFromDevice(
// a vector of void* pointers.
std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape),
nullptr);
- int64 tuple_size =
- ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
+ int64 tuple_size = ShapeUtil::ByteSizeOf(shape, pointer_size_);
auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size,
element_pointers.data());
if (!copy_status.ok()) {
@@ -121,9 +121,8 @@ GenericTransferManager::ShallowCopyTupleFromDevice(
!ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
return FailedPrecondition("tuple contains nullptr at element %lu", i);
}
- int64 buffer_size = ShapeUtil::ByteSizeOf(shape.tuple_shapes(i),
- /*pointer_size=*/sizeof(void*));
- destination.emplace_back(element_pointers[i], buffer_size);
+ destination.emplace_back(element_pointers[i],
+ GetByteSizeRequirement(shape.tuple_shapes(i)));
}
return std::move(destination);
}
@@ -138,11 +137,79 @@ Status GenericTransferManager::WriteTuplePointersToDevice(
for (const se::DeviceMemoryBase& element : elements) {
element_pointers.push_back(element.opaque());
}
- int64 tuple_size =
- ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
+ return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
+ element_pointers.data(), region);
+}
+
+StatusOr<std::unique_ptr<Literal>>
+GenericTransferManager::TransferLiteralFromDevice(
+ se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
+ VLOG(2) << "transferring literal from device ordinal "
+ << executor->device_ordinal() << "; device shape: "
+ << ShapeUtil::HumanStringWithLayout(device_buffer.shape())
+ << "; opaque: " << device_buffer.buffer(/*index=*/{}).opaque();
+ TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
+
+ std::unique_ptr<Literal> literal =
+ Literal::CreateFromShape(device_buffer.shape());
+
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ device_buffer.shape(),
+ [&](const Shape& subshape, const ShapeIndex& index) -> Status {
+ if (!ShapeUtil::IsTuple(subshape)) {
+ TF_RETURN_IF_ERROR(TransferBufferFromDevice(
+ executor,
+ /*source=*/device_buffer.buffer(index),
+ /*size=*/GetByteSizeRequirement(subshape),
+ /*destination=*/
+ literal->GetSubliteral(index).MutableInternalData()));
+ }
+
+ return Status::OK();
+ }));
+ return std::move(literal);
+}
+
+Status GenericTransferManager::TransferLiteralToDevice(
+ se::StreamExecutor* executor, const Literal& literal,
+ const ShapedBuffer& device_buffer) {
+ const Shape& shape = literal.shape();
+ VLOG(2) << "transferring literal shape to device: "
+ << ShapeUtil::HumanString(shape) << "; device location: "
+ << device_buffer.buffer(/*index=*/{}).opaque();
+
+ TF_RET_CHECK(ShapeUtil::Compatible(literal.shape(), device_buffer.shape()));
+ TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
+
+ TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer));
- return TransferBufferToDevice(executor, tuple_size, element_pointers.data(),
- region);
+ return ShapeUtil::ForEachSubshapeWithStatus(
+ device_buffer.shape(),
+ [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
+ se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
+ if (ShapeUtil::IsArray(device_subshape)) {
+ TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
+ device_memory.size());
+ // Element is array-shaped: transfer array data to device buffer.
+ const Literal& subliteral = literal.GetSubliteral(index);
+ std::unique_ptr<Literal> relayed_out_literal;
+ const void* source;
+ if (LayoutUtil::Equal(device_subshape.layout(),
+ subliteral.shape().layout())) {
+ source = subliteral.InternalData();
+ } else {
+ // Relayout data before transferring.
+ relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
+ /*shape_index=*/{});
+ source = relayed_out_literal->InternalData();
+ }
+ return TransferBufferToDevice(
+ executor,
+ /*size=*/GetByteSizeRequirement(device_subshape), source,
+ &device_memory);
+ }
+ return Status::OK();
+ });
}
Status GenericTransferManager::TransferLiteralToDevice(
@@ -198,7 +265,7 @@ Status GenericTransferManager::ResetDevices(
}
int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const {
- return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
+ return ShapeUtil::ByteSizeOf(shape, pointer_size_);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index ef9a50676a..50dca6aec5 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -52,6 +52,14 @@ class GenericTransferManager : public TransferManager {
perftools::gputools::StreamExecutor* executor, const Literal& literal,
perftools::gputools::DeviceMemoryBase* destination) override;
+ StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
+ perftools::gputools::StreamExecutor* executor,
+ const ShapedBuffer& device_buffer) override;
+
+ Status TransferLiteralToDevice(perftools::gputools::StreamExecutor* executor,
+ const Literal& literal,
+ const ShapedBuffer& device_buffer) override;
+
Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
const Literal& literal) override;
Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor,
@@ -71,6 +79,9 @@ class GenericTransferManager : public TransferManager {
const perftools::gputools::DeviceMemoryBase& source,
const Shape& shape) override;
+ int64 GetByteSizeRequirement(const Shape& shape) const override;
+
+ protected:
Status WriteTuplePointersToDevice(
perftools::gputools::StreamExecutor* executor,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
@@ -78,8 +89,6 @@ class GenericTransferManager : public TransferManager {
const Shape& shape,
perftools::gputools::DeviceMemoryBase* region) override;
- int64 GetByteSizeRequirement(const Shape& shape) const override;
-
private:
// The platform this transfer manager targets.
const perftools::gputools::Platform::Id platform_id_;
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index a57ebf59e7..a7539a1a11 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -21,17 +21,19 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace se = ::perftools::gputools;
namespace xla {
+using ::tensorflow::strings::Appendf;
+
/* static */ StatusOr<std::unique_ptr<ShapedBuffer>>
ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape,
const se::Platform* platform,
@@ -80,10 +82,33 @@ se::DeviceMemoryBase* ShapedBuffer::mutable_buffer(const ShapeIndex& index) {
return &buffers_[shape_index_to_buffer_entry_.element(index)];
}
+string ShapedBuffer::ToString() const {
+ string s = "ShapedBuffer(" + platform_->Name() + "):\n";
+ ShapeUtil::ForEachSubshape(
+ shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) {
+ string shape_str;
+ if (ShapeUtil::IsTuple(subshape)) {
+ shape_str = "tuple";
+ } else {
+ shape_str = ShapeUtil::HumanStringWithLayout(subshape);
+ }
+ const se::DeviceMemoryBase& memory = buffer(index);
+ Appendf(&s, " %s%p (%lld bytes) : %s\n",
+ string(index.size() * 2, ' ').c_str(), memory.opaque(),
+ memory.size(), shape_str.c_str());
+ });
+ return s;
+}
+
+std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
+ out << buffer.ToString();
+ return out;
+}
+
/* static */ StatusOr<std::unique_ptr<ScopedShapedBuffer>>
-ScopedShapedBuffer::Allocate(const Shape& shape,
- DeviceMemoryAllocator* allocator,
- int device_ordinal) {
+ScopedShapedBuffer::Allocate(
+ const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal,
+ const std::function<int64(const Shape&)>& shape_size_fn) {
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument("Shape must have a layout: %s",
ShapeUtil::HumanStringWithLayout(shape).c_str());
@@ -93,51 +118,17 @@ ScopedShapedBuffer::Allocate(const Shape& shape,
WrapUnique(new ScopedShapedBuffer(shape, allocator, device_ordinal));
// Allocate an appropriate sized buffer for each element in the shape
- // including the tuple pointer arrays. Gather tuple element addresses in
- // 'element_addresses'. These will be written in the respective tuple's array
- // of pointers on the device.
- TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager,
- TransferManager::GetForPlatform(allocator->platform()));
- ShapeTree<std::vector<se::DeviceMemoryBase>> element_addresses(shape);
+ // including the tuple pointer arrays.
for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) {
const ShapeIndex& index = pair.first;
size_t& buffer_entry = pair.second;
- TF_ASSIGN_OR_RETURN(
- se::DeviceMemoryBase memory_base,
- shaped_buffer->allocator_->Allocate(
- shaped_buffer->device_ordinal(),
- transfer_manager->GetByteSizeRequirement(
- ShapeUtil::GetSubshape(shaped_buffer->shape(), index))));
+ TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase memory_base,
+ shaped_buffer->allocator_->Allocate(
+ shaped_buffer->device_ordinal(),
+ shape_size_fn(ShapeUtil::GetSubshape(
+ shaped_buffer->shape(), index))));
shaped_buffer->buffers_.push_back(memory_base);
buffer_entry = shaped_buffer->buffers_.size() - 1;
-
- // If this is a tuple element, then push the address on to the
- // vector of tuple element addresses.
- if (!index.empty()) {
- ShapeIndex parent_index = index;
- parent_index.pop_back();
- element_addresses.mutable_element(parent_index)->push_back(memory_base);
- }
- }
-
- // Fill in the tuple pointer arrays with the addresses of their respective
- // elements.
- TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
- allocator->platform()->ExecutorForDevice(
- shaped_buffer->device_ordinal()));
- for (const auto& pair : element_addresses) {
- const ShapeIndex& index = pair.first;
- const std::vector<se::DeviceMemoryBase>& addresses = pair.second;
- const Shape& subshape = ShapeUtil::GetSubshape(shape, index);
-
- if (addresses.empty()) {
- TF_RET_CHECK(!ShapeUtil::IsTuple(subshape) ||
- ShapeUtil::TupleElementCount(subshape) == 0);
- continue;
- }
- TF_RET_CHECK(ShapeUtil::IsTuple(subshape));
- TF_RETURN_IF_ERROR(transfer_manager->WriteTuplePointersToDevice(
- executor, addresses, subshape, shaped_buffer->mutable_buffer(index)));
}
return std::move(shaped_buffer);
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
index b440948700..fa88caa13f 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.h
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -17,6 +17,8 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
#include <memory>
+#include <ostream>
+#include <string>
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/shape_tree.h"
@@ -79,6 +81,8 @@ class ShapedBuffer {
void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer,
const ShapeIndex& shape_index);
+ string ToString() const;
+
protected:
// The shape of the device buffer with layout.
const Shape shape_;
@@ -99,6 +103,8 @@ class ShapedBuffer {
ShapeTree<size_t> shape_index_to_buffer_entry_;
};
+std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);
+
// ShapedBuffer derived class which allocates all internal buffers on
// construction and deallocates the memory when the object is
// destructed.
@@ -109,7 +115,8 @@ class ScopedShapedBuffer : public ShapedBuffer {
// buffers (if any) are allocated and initialized to the backend-specific
// representation of an array of pointers to the tuple elements.
static StatusOr<std::unique_ptr<ScopedShapedBuffer>> Allocate(
- const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal);
+ const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal,
+ const std::function<int64(const Shape&)>& shape_size_fn);
// Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the
// deallocation of the device memory held in the shaped buffer. All device
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index fef131d19f..d5f53ad56f 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -72,6 +72,39 @@ TransferManager::GetPlatformTransferManagers() {
return it->second.manager.get();
}
+Status TransferManager::WriteTupleIndexTables(
+ perftools::gputools::StreamExecutor* executor,
+ const ShapedBuffer& device_buffer) {
+ VLOG(2) << "Writing tuple index tables to ShapedBuffer rooted at "
+ << device_buffer.buffer(/*index=*/{}).opaque()
+ << "; shape: " << ShapeUtil::HumanString(device_buffer.shape());
+
+ TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
+
+ return ShapeUtil::ForEachSubshapeWithStatus(
+ device_buffer.shape(),
+ [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
+ if (ShapeUtil::IsTuple(device_subshape)) {
+ se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
+ TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
+ device_memory.size());
+
+ std::vector<se::DeviceMemoryBase> elements;
+ ShapeIndex element_index = index;
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(device_subshape);
+ ++i) {
+ element_index.push_back(i);
+ elements.push_back(device_buffer.buffer(element_index));
+ element_index.pop_back();
+ }
+ return WriteTuplePointersToDevice(executor, elements, device_subshape,
+ &device_memory);
+ }
+
+ return Status::OK();
+ });
+}
+
Status TransferManager::TransferBufferFromDevice(
se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
int64 size, void* destination) {
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index d7f85f5765..fdc123e54e 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -47,6 +48,8 @@ class TransferManager {
// executor. device_shape is the shape, including layout, of the data on the
// device, while literal_shape will be the shape for the literal. device_shape
// and literal_shape must be compatible, but need not have the same layout.
+ // TODO(b/66694934): Remove TransferLiteral* methods which accept bare
+ // DeviceMemoryBase.
virtual Status TransferLiteralFromDevice(
perftools::gputools::StreamExecutor* executor,
const perftools::gputools::DeviceMemoryBase& region,
@@ -59,6 +62,20 @@ class TransferManager {
perftools::gputools::StreamExecutor* executor, const Literal& literal,
perftools::gputools::DeviceMemoryBase* region) = 0;
+ // Transfers the data held in the given ShapedBuffer into the provided literal
+ // using the provided executor. literal_shape will be the shape for the
+ // literal. The shape of the ShapedBuffer and literal_shape must be
+ // compatible, but need not have the same layout.
+ virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
+ perftools::gputools::StreamExecutor* executor,
+ const ShapedBuffer& device_buffer) = 0;
+
+ // Transfers the given literal into the previously allocated device memory
+ // represented by the given ShapedBuffer using the given executor.
+ virtual Status TransferLiteralToDevice(
+ perftools::gputools::StreamExecutor* executor, const Literal& literal,
+ const ShapedBuffer& device_buffer) = 0;
+
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.
virtual Status TransferLiteralToInfeed(
@@ -97,15 +114,11 @@ class TransferManager {
const perftools::gputools::DeviceMemoryBase& source,
const Shape& shape) = 0;
- // Writes the given device-memory pointers in 'elements' to the given region
- // to construct a tuple in the platform-specific tuple representation. This
- // can handle nested tuples as well. In the nested case, the element
- // DeviceMemoryBase points to another array of pointers on the device.
- virtual Status WriteTuplePointersToDevice(
- perftools::gputools::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
- elements,
- const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0;
+ // Given an allocated ShapedBuffer, constructs the tuple index table(s) in
+ // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
+ // ShapedBuffer is array-shaped this method does nothing.
+ Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor,
+ const ShapedBuffer& device_buffer);
// Returns all buffer pointers that the tuple `source` refers to. Unlike
// ShallowCopyTupleFromDevice, this function gather buffer pointers in nested
@@ -121,23 +134,6 @@ class TransferManager {
// region for a host-to-device transfer.
virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0;
- // Transfer a memory block of the given size from the device source into the
- // 'destination' buffer.
- //
- // size is the size to transfer to destination in bytes.
- virtual Status TransferBufferFromDevice(
- perftools::gputools::StreamExecutor* executor,
- const perftools::gputools::DeviceMemoryBase& source, int64 size,
- void* destination);
-
- // Transfer a memory block of the given size from 'source' buffer to the given
- // destination of the device.
- //
- // size is the size to transfer from source in bytes.
- virtual Status TransferBufferToDevice(
- perftools::gputools::StreamExecutor* executor, int64 size,
- const void* source, perftools::gputools::DeviceMemoryBase* destination);
-
typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)();
/////
@@ -157,6 +153,34 @@ class TransferManager {
static StatusOr<TransferManager*> GetForPlatform(
const perftools::gputools::Platform* platform);
+ protected:
+ // Transfer a memory block of the given size from the device source into the
+ // 'destination' buffer.
+ //
+ // size is the size to transfer to destination in bytes.
+ virtual Status TransferBufferFromDevice(
+ perftools::gputools::StreamExecutor* executor,
+ const perftools::gputools::DeviceMemoryBase& source, int64 size,
+ void* destination);
+
+ // Transfer a memory block of the given size from 'source' buffer to the given
+ // destination of the device.
+ //
+ // size is the size to transfer from source in bytes.
+ virtual Status TransferBufferToDevice(
+ perftools::gputools::StreamExecutor* executor, int64 size,
+ const void* source, perftools::gputools::DeviceMemoryBase* destination);
+
+ // Writes the given device-memory pointers in 'elements' to the given region
+ // to construct a tuple in the platform-specific tuple representation. This
+ // can handle nested tuples as well. In the nested case, the element
+ // DeviceMemoryBase points to another array of pointers on the device.
+ virtual Status WriteTuplePointersToDevice(
+ perftools::gputools::StreamExecutor* executor,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ elements,
+ const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0;
+
private:
// The mutex that guards the platform-to-transfer manager map.
static tensorflow::mutex platform_transfer_manager_mutex_;
diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc
deleted file mode 100644
index c25a0861e9..0000000000
--- a/tensorflow/compiler/xla/service/transfer_manager_test.cc
+++ /dev/null
@@ -1,161 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/literal_test_util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace se = ::perftools::gputools;
-
-namespace xla {
-
-namespace {
-
-class CpuTransferManagerTest : public ::testing::Test {
- protected:
- CpuTransferManagerTest()
- : transfer_manager_(se::host::kHostPlatformId,
- /*pointer_size=*/sizeof(void*)) {
- se::Platform* platform =
- se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
- .ValueOrDie();
- stream_exec_ =
- platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie();
- }
-
- ~CpuTransferManagerTest() override {}
-
- se::StreamExecutor* stream_exec_;
- GenericTransferManager transfer_manager_;
-};
-
-TEST_F(CpuTransferManagerTest, TransferR0U32ToDevice) {
- std::vector<uint8> storage(sizeof(uint32), '\x00');
- se::DeviceMemoryBase memptr(storage.data(), storage.size());
- std::unique_ptr<Literal> literal = Literal::CreateR0<uint32>(42);
- TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal,
- &memptr));
-
- CHECK_EQ(42, *reinterpret_cast<uint32*>(&storage[0]));
-}
-
-TEST_F(CpuTransferManagerTest, TransferR1F32ToDevice) {
- std::vector<uint8> storage(4 * sizeof(float), '\x00');
- se::DeviceMemoryBase memptr(storage.data(), storage.size());
- std::unique_ptr<Literal> literal =
- Literal::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
- TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal,
- &memptr));
-
- CHECK_EQ(1.25f, *reinterpret_cast<float*>(&storage[0]));
- CHECK_EQ(2.5f, *reinterpret_cast<float*>(&storage[sizeof(float)]));
- CHECK_EQ(-17.0f, *reinterpret_cast<float*>(&storage[2 * sizeof(float)]));
- CHECK_EQ(-20.125f, *reinterpret_cast<float*>(&storage[3 * sizeof(float)]));
-}
-
-TEST_F(CpuTransferManagerTest, TransferR1U8ToDevice) {
- std::vector<uint8> storage(16, '\x00');
- se::DeviceMemoryBase memptr(storage.data(), storage.size());
- const char* str = "0123456789abcdef";
- std::unique_ptr<Literal> literal = Literal::CreateR1U8(str);
- TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal,
- &memptr));
-
- CHECK_EQ('0', storage[0]);
- CHECK_EQ('8', storage[8]);
- CHECK_EQ('f', storage[15]);
-}
-
-TEST_F(CpuTransferManagerTest, TransferR0U32FromDevice) {
- std::vector<uint32> storage(1, 42);
- se::DeviceMemoryBase memptr(storage.data(),
- storage.size() * sizeof(storage[0]));
- Literal literal;
- const Shape shape = ShapeUtil::MakeShape(U32, {});
- TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
- stream_exec_, memptr, shape, shape, &literal));
-
- LiteralTestUtil::ExpectR0Equal<uint32>(42, literal);
-}
-
-TEST_F(CpuTransferManagerTest, TransferR1F32FromDevice) {
- std::vector<float> storage{1.25f, 2.5f, -17.0f, -20.125f};
- se::DeviceMemoryBase memptr(storage.data(),
- storage.size() * sizeof(storage[0]));
- Literal literal;
- const Shape shape = ShapeUtil::MakeShape(F32, {4});
- TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
- stream_exec_, memptr, shape, shape, &literal));
-
- LiteralTestUtil::ExpectR1Equal<float>({1.25, 2.5, -17.0, -20.125}, literal);
-}
-
-TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) {
- std::vector<uint8> storage{'k', 'l', 'm', 'n'};
- se::DeviceMemoryBase memptr(storage.data(),
- storage.size() * sizeof(storage[0]));
- Literal literal;
- const Shape shape = ShapeUtil::MakeShape(U8, {4});
- TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
- stream_exec_, memptr, shape, shape, &literal));
- CHECK_EQ("klmn", literal.u8s_string());
-}
-
-TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) {
- std::vector<uint64> storage{1, 5, 42};
- int64 size = storage.size() * sizeof(storage[0]);
- se::DeviceMemoryBase memptr(storage.data(), size);
-
- std::vector<uint64> dest(3, 0);
- TF_CHECK_OK(transfer_manager_.TransferBufferFromDevice(stream_exec_, memptr,
- size, dest.data()));
- ASSERT_EQ(1, dest[0]);
- ASSERT_EQ(5, dest[1]);
- ASSERT_EQ(42, dest[2]);
-}
-
-TEST_F(CpuTransferManagerTest, TransferBufferToDevice) {
- int64 size = 3 * sizeof(uint64);
- std::vector<uint8> storage(size, 0);
- se::DeviceMemoryBase memptr(storage.data(), size);
-
- std::vector<uint64> dest{1, 5, 42};
- TF_CHECK_OK(transfer_manager_.TransferBufferToDevice(stream_exec_, size,
- dest.data(), &memptr));
- std::vector<uint64>* storage64 =
- reinterpret_cast<std::vector<uint64>*>(&storage);
- ASSERT_EQ(1, (*storage64)[0]);
- ASSERT_EQ(5, (*storage64)[1]);
- ASSERT_EQ(42, (*storage64)[2]);
-}
-
-// TODO(b/24679870): add similar tests for GPUs
-
-} // namespace
-
-} // namespace xla