From fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 15 Dec 2017 23:38:01 -0800 Subject: Merged commit includes the following changes: 179277894 by gunan: Run buildifier on build file. -- 179275101 by meheff: Replace DeviceMemoryBase with ShapedBuffer in XLA interfaces. Executable, TransferManager, and AllocationTracker now use ShapedBuffer to hold device memory addresses holding XLA data. Most of the change is straight-forward with the exception of AllocationTracker which was mostly rewritten (and simplified) and some refactoring in the CPU executable. Also, have ShapedBuffer hold on-host and on-device Shapes which are the shapes of the representation of the data on the host and device, respectively. This is necessary because with cl/178624364 the on-host and on-device shape may no longer be equal. -- 179265385 by A. Unique TensorFlower: Return error rather than CHECK fail in Executable::ExecuteOnStreamWrapper -- 179264551 by dandelion: Internal fixes. -- PiperOrigin-RevId: 179277894 --- tensorflow/compiler/jit/kernels/xla_launch_op.cc | 17 +- tensorflow/compiler/xla/client/local_client.cc | 14 +- tensorflow/compiler/xla/literal_util.cc | 21 ++ tensorflow/compiler/xla/literal_util.h | 4 + .../compiler/xla/service/allocation_tracker.cc | 228 ++++++++---------- .../compiler/xla/service/allocation_tracker.h | 179 +++++--------- .../compiler/xla/service/cpu/cpu_executable.cc | 200 ++++++--------- .../compiler/xla/service/cpu/cpu_executable.h | 30 ++- .../xla/service/cpu/parallel_cpu_executable.cc | 146 +++-------- .../xla/service/cpu/parallel_cpu_executable.h | 18 +- tensorflow/compiler/xla/service/executable.cc | 14 +- tensorflow/compiler/xla/service/executable.h | 23 +- .../xla/service/generic_transfer_manager.cc | 132 ++-------- .../xla/service/generic_transfer_manager.h | 24 +- .../compiler/xla/service/gpu/gpu_executable.cc | 142 +++-------- .../compiler/xla/service/gpu/gpu_executable.h | 16 +- tensorflow/compiler/xla/service/hlo_runner.cc | 98 +++----- tensorflow/compiler/xla/service/hlo_runner.h | 45 +--- tensorflow/compiler/xla/service/interpreter/BUILD | 1 + .../compiler/xla/service/interpreter/executable.cc | 86 +++---- .../compiler/xla/service/interpreter/executable.h | 11 +- tensorflow/compiler/xla/service/local_service.cc | 4 +- tensorflow/compiler/xla/service/service.cc | 267 +++++++++------------ tensorflow/compiler/xla/service/service.h | 19 +- tensorflow/compiler/xla/service/shaped_buffer.cc | 120 +++------ tensorflow/compiler/xla/service/shaped_buffer.h | 100 ++++---- .../compiler/xla/service/transfer_manager.cc | 107 ++++++--- tensorflow/compiler/xla/service/transfer_manager.h | 100 ++++---- tensorflow/compiler/xla/tests/copy_test.cc | 11 +- tensorflow/compiler/xla/tests/dynamic_ops_test.cc | 16 +- tensorflow/compiler/xla/tests/fusion_test.cc | 3 +- tensorflow/compiler/xla/tests/hlo_test_base.cc | 22 +- tensorflow/compiler/xla/tests/hlo_test_base.h | 21 +- .../xla/tests/local_client_execute_test.cc | 36 ++- .../compiler/xla/tests/local_client_test_base.cc | 2 +- .../compiler/xla/tests/multioutput_fusion_test.cc | 13 +- .../compiler/xla/tests/transfer_manager_test.cc | 41 +++- 37 files changed, 866 insertions(+), 1465 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 39a770ab7b..4f3f17df9c 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -287,10 +287,17 @@ void 
XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( const_cast(t->tensor_data().data()), t->tensor_data().size()); - arg_buffers[i] = - xla::ShapedBuffer::MakeArrayShapedBuffer( - shape, client->platform(), client->default_device_ordinal(), dmem) - .ConsumeValueOrDie(); + const xla::Shape on_device_shape = + client->backend().transfer_manager()->HostShapeToDeviceShape(shape); + CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) + << "On-device shape " + << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) + << " not the same as on-host shape " + << xla::ShapeUtil::HumanStringWithLayout(shape); + arg_buffers[i] = xla::MakeUnique( + /*on_host_shape=*/shape, /*on_device_shape=*/shape, client->platform(), + client->default_device_ordinal()); + arg_buffers[i]->set_buffer(dmem, /*index=*/{}); arg_ptrs[i] = arg_buffers[i].get(); OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t)); @@ -313,7 +320,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // Computation output should always be a tuple. if (VLOG_IS_ON(2)) { - VLOG(2) << "Result tuple shape: " << output->shape().DebugString(); + VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index b051955f0f..7900246a49 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -78,14 +78,14 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( } for (int i = 0; i < arguments.size(); ++i) { if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( - arguments[i]->shape())) { + arguments[i]->on_host_shape())) { return InvalidArgument( "argument does not match shape or layout of computation parameter " "%d: expected %s, got %s", i, ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) .c_str(), - ShapeUtil::HumanString(arguments[i]->shape()).c_str()); + ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } } @@ -281,13 +281,9 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, if (allocator == nullptr) { allocator = backend().memory_allocator(); } - TF_ASSIGN_OR_RETURN( - auto scoped_buffer, - ScopedShapedBuffer::Allocate( - literal.shape(), allocator, device_ordinal, - [this](const Shape& shape) { - return backend().transfer_manager()->GetByteSizeRequirement(shape); - })); + TF_ASSIGN_OR_RETURN(auto scoped_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + literal.shape(), allocator, device_ordinal)); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 42c9d21149..3ae356bc11 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -404,6 +404,27 @@ std::unique_ptr Literal::Relayout( return outer_result; } +std::unique_ptr Literal::Relayout( + const Shape& shape_with_layout) const { + CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) + << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) + << " not compatible with literal shape " + << ShapeUtil::HumanString(shape()); + std::unique_ptr result = CreateFromShape(shape_with_layout); + ShapeUtil::ForEachSubshape( + 
result->shape(), + [this, &result](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(subshape)) { + DimensionVector base(ShapeUtil::Rank(subshape), 0); + DimensionVector copy_size(subshape.dimensions().begin(), + subshape.dimensions().end()); + TF_CHECK_OK(result->GetSubliteral(index).Copy(GetSubliteral(index), + base, base, copy_size)); + } + }); + return result; +} + StatusOr> Literal::Reshape( tensorflow::gtl::ArraySlice dimensions) const { if (ShapeUtil::IsTuple(shape())) { diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 2981f9f875..9b9972725b 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -286,6 +286,10 @@ class Literal { std::unique_ptr Relayout(const Layout& new_layout, const ShapeIndex& shape_index = {}) const; + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + // Creates a new literal by reshaping this literal to have the given // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index ad2fee2d39..b69a6e730f 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -27,191 +27,163 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace se = ::perftools::gputools; namespace xla { -AllocationTracker::AllocationTracker() : next_handle_(1) {} - -GlobalDataHandle AllocationTracker::Register(Backend* backend, - int device_ordinal, - se::DeviceMemoryBase device_memory, - const Shape& shape, - const string& tag) { - tensorflow::mutex_lock lock(allocation_mutex_); +StatusOr AllocationTracker::Register( + std::unique_ptr shaped_buffer, const string& tag) { + tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Register"; - return RegisterInternal(backend, device_ordinal, device_memory, shape, tag, - /*initial_ref_count=*/1); + return RegisterInternal(std::move(shaped_buffer), tag); } -GlobalDataHandle AllocationTracker::RegisterInternal( - Backend* backend, int device_ordinal, se::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag, int initial_ref_count) { +StatusOr AllocationTracker::RegisterInternal( + std::unique_ptr shaped_buffer, const string& tag) { VLOG(2) << "RegisterInternal(" << "tag: \"" << tag << "\" " - << "device_ordinal: " << device_ordinal << " " - << "device_memory: " << device_memory.opaque() << " " - << "shape: " << shape.ShortDebugString() << ")"; - TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); - - int64 handle; - HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal); - auto handle_it = handle_map.find(device_memory.opaque()); - if (handle_it != handle_map.end()) { - handle = handle_it->second; - auto& allocation = FindOrDie(handle_to_allocation_, handle); - int ref_count = allocation->ref_count(); - CHECK_GT(ref_count, 0); - VLOG(2) << "ref_count: " << ref_count << " -> " << - (ref_count + initial_ref_count); - allocation->increment_ref_count(initial_ref_count); - } else { - handle = next_handle_++; 
- VLOG(2) << "ref_count: " << initial_ref_count; - InsertOrDie(&handle_map, device_memory.opaque(), handle); - auto inserted = handle_to_allocation_.emplace( - handle, MakeUnique(backend, device_ordinal, device_memory, - shape, tag, initial_ref_count)); - CHECK(inserted.second); + << "shaped_buffer: " << *shaped_buffer; + if (shaped_buffer->platform() != backend_->platform()) { + return InvalidArgument( + "AllocationTracker for platform %s cannot register buffer from " + "platform %s", + backend_->platform()->Name().c_str(), + shaped_buffer->platform()->Name().c_str()); } + int64 handle = next_handle_++; + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal()); + } GlobalDataHandle result; result.set_handle(handle); + + handle_to_shaped_buffer_[handle] = std::move(shaped_buffer); + VLOG(2) << "handle: " << handle; return result; } tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); - TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data)); - std::set deallocated_buffers; - TF_RETURN_IF_ERROR( - DeallocateShape(allocation->backend(), allocation->device_ordinal(), - allocation->mutable_device_memory(), allocation->shape(), - &deallocated_buffers)); - return tensorflow::Status::OK(); -} - -tensorflow::Status AllocationTracker::DeallocateShape( - Backend* backend, int device_ordinal, se::DeviceMemoryBase* device_memory, - const Shape& shape, std::set* deallocated_buffers) { - VLOG(2) << "DeallocateShape(" - << "shape: \"" << shape.ShortDebugString() << "\" " - << "device_memory: " << device_memory->opaque() << ")"; - if (ContainsKey(*deallocated_buffers, device_memory->opaque())) { - // Buffer has already been deallocated. Nothing to do. - VLOG(2) << "already deallocated"; - return tensorflow::Status::OK(); + tensorflow::mutex_lock lock(mutex_); + VLOG(2) << "Unregister(" + << "handle: " << data.handle() << ")"; + TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal())); } - // Add buffer to deallocated set so we do not try to deallocate it again - // if it is encountered again while traversing a tuple. - deallocated_buffers->insert(device_memory->opaque()); - - HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal); - auto handle_it = handle_map.find(device_memory->opaque()); - if (handle_it != handle_map.end()) { - int64 handle = handle_it->second; - auto& allocation = FindOrDie(handle_to_allocation_, handle); - int ref_count = allocation->ref_count(); - VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count - 1; - allocation->decrement_ref_count(); - if (allocation->ref_count() > 0) { - // Buffer is referred to by another allocation. Don't deallocate it. - return tensorflow::Status::OK(); - } - handle_map.erase(device_memory->opaque()); - } + // Keep a nullptr as a tombstone for unregistered handles. 
This enables better + // error messages. That is, "handle has been deallocated" versus "handle does + // not exist". + handle_to_shaped_buffer_.at(data.handle()).reset(); - if (ShapeUtil::IsTuple(shape)) { - // Traverse into tuple recursively deallocating buffers. - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - backend->stream_executor(device_ordinal)); - TF_ASSIGN_OR_RETURN(std::vector elements, - backend->transfer_manager()->ShallowCopyTupleFromDevice( - executor, *device_memory, shape)); - - TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size()) - << "tuple has unexpected number of elements: " << elements.size() - << " != " << ShapeUtil::TupleElementCount(shape); - for (size_t i = 0; i < elements.size(); ++i) { - VLOG(2) << "recursing onto the tuple elements"; - TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal, &elements[i], - shape.tuple_shapes(i), - deallocated_buffers)); - } - } - - return backend->memory_allocator()->Deallocate(device_ordinal, device_memory); + return tensorflow::Status::OK(); } StatusOr> AllocationTracker::DeconstructTuple( const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); - TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data)); + tensorflow::mutex_lock lock(mutex_); - if (!ShapeUtil::IsTuple(allocation->shape())) { + TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { return InvalidArgument("global data handle %lld is not a tuple", data.handle()); } + // If the on-host representation is a tuple, then the on-device one should be + // as well. + TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape())); - if (ShapeUtil::IsNestedTuple(allocation->shape())) { + if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { return Unimplemented("deconstructing nested tuples not yet supported"); } - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - allocation->backend()->stream_executor(allocation->device_ordinal())); - TF_ASSIGN_OR_RETURN( - std::vector element_bases, - allocation->backend()->transfer_manager()->ShallowCopyTupleFromDevice( - executor, allocation->device_memory(), allocation->shape())); - std::vector element_handles; - element_handles.reserve(element_bases.size()); - for (int i = 0; i < element_bases.size(); ++i) { - element_handles.push_back(RegisterInternal( - allocation->backend(), allocation->device_ordinal(), element_bases[i], - ShapeUtil::GetSubshape(allocation->shape(), {i}), - tensorflow::strings::StrCat(allocation->tag(), ".element_", i), - /*initial_ref_count=*/2)); + for (int i = 0; + i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape()); + ++i) { + auto element_buffer = MakeUnique( + ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i), + ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i), + shaped_buffer->platform(), shaped_buffer->device_ordinal()); + element_buffer->set_buffer(shaped_buffer->buffer(/*index=*/{i}), + /*index=*/{}); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle element_handle, + RegisterInternal(std::move(element_buffer), "deconstructed tuple")); + + element_handles.push_back(element_handle); } return std::move(element_handles); } -StatusOr AllocationTracker::Resolve( +StatusOr AllocationTracker::Resolve( const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); + tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } -StatusOr 
AllocationTracker::ResolveInternal( +StatusOr AllocationTracker::ResolveInternal( const GlobalDataHandle& data) { VLOG(2) << "resolve:" << data.handle(); - auto it = handle_to_allocation_.find(data.handle()); - if (it == handle_to_allocation_.end()) { + auto it = handle_to_shaped_buffer_.find(data.handle()); + if (it == handle_to_shaped_buffer_.end()) { return NotFound("no allocation record for global data handle: %lld", data.handle()); } - Allocation* allocation = it->second.get(); + ShapedBuffer* shaped_buffer = it->second.get(); - if (allocation->is_deallocated()) { + if (shaped_buffer == nullptr) { return InvalidArgument("global data handle %lld was previously deallocated", data.handle()); } - return allocation; + return shaped_buffer; } -AllocationTracker::HandleMap& AllocationTracker::GetOrCreateOpaqueToHandleMap( - int device_ordinal) { - if (opaque_to_handle_.size() <= device_ordinal) { - opaque_to_handle_.resize(device_ordinal + 1); +void AllocationTracker::AddAllocationOrIncrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { + AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; + auto it = allocation_map.find(device_memory.opaque()); + if (it == allocation_map.end()) { + allocation_map[device_memory.opaque()] = {device_memory, device_ordinal, + /*ref_count=*/1}; + } else { + it->second.ref_count++; } - return opaque_to_handle_[device_ordinal]; +} + +Status AllocationTracker::DecrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { + AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; + auto it = allocation_map.find(device_memory.opaque()); + TF_RET_CHECK(it != allocation_map.end()); + Allocation& allocation = it->second; + TF_RET_CHECK(allocation.ref_count >= 1); + if (allocation.ref_count == 1) { + TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate( + device_ordinal, &device_memory)); + allocation_map.erase(it); + } else { + allocation.ref_count--; + } + return tensorflow::Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index ebbf35b6fe..8b25cbb482 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -28,147 +28,92 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" namespace xla { -// A global allocation in device space, tracked by the XLA service. 
-class Allocation { - public: - Allocation(Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag, int initial_ref_count) - : backend_(backend), - device_ordinal_(device_ordinal), - device_memory_(device_memory), - shape_(shape), - tag_(tag), - ref_count_(initial_ref_count) {} - - Backend* backend() const { return backend_; } - int device_ordinal() const { return device_ordinal_; } - perftools::gputools::DeviceMemoryBase device_memory() const { - return device_memory_; - } - const Shape& shape() const { return shape_; } - const string& tag() const { return tag_; } - - bool is_deallocated() const { - CHECK_GE(ref_count_, 0); - return ref_count_ == 0; - } - int ref_count() const { - CHECK_GE(ref_count_, 0); - return ref_count_; - } - void increment_ref_count(int inc) { - CHECK_GT(ref_count_, 0); - CHECK_LE(ref_count_, INT_MAX - inc); - ref_count_ += inc; - } - void decrement_ref_count() { - CHECK_GT(ref_count_, 0); - --ref_count_; - } - perftools::gputools::DeviceMemoryBase* mutable_device_memory() { - return &device_memory_; - } - - private: - // The backend that the memory is allocated on. - Backend* backend_; - - // The device that the memory is allocated on. - int device_ordinal_; - - // The pointer to this allocation. - perftools::gputools::DeviceMemoryBase device_memory_; - - // The shape of this allocation. - Shape shape_; - - // An informal description of this allocation shown in tools. - string tag_; - - // This is the number of Allocation objects which refer to this memory - // allocation. - int ref_count_; - - // Return a string representation of this allocation for debugging or logging - // purposes. - string ToString() const; -}; - // Tracks allocations for the XLA service; allocations can be registered // with shape/device/tag and resolved from a handle for later use. class AllocationTracker { public: - AllocationTracker(); + // The allocator is used for deallocating memory when allocations are + // deregistered. All registered allocations must have the same platform as the + // allocator. + AllocationTracker(Backend* backend) : backend_(backend), next_handle_(1) {} - // Registers device memory with a given shape, device identifier, and tag, and - // returns a corresponding handle that can be used for talking to XLA - // clients. - GlobalDataHandle Register(Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag); + // Registers a shaped buffer of device memory, and returns a corresponding + // handle that can be used for talking to XLA clients. + StatusOr Register( + std::unique_ptr shaped_buffer, const string& tag); // Unregister the allocation for the given data handle. - tensorflow::Status Unregister(const GlobalDataHandle& data); + Status Unregister(const GlobalDataHandle& data); // Returns a vector of global data handles that point to the tuple elements. StatusOr> DeconstructTuple( const GlobalDataHandle& Data); - // Resolve a handle from an XLA client to an allocation, or provide an - // error status to say whether it was not found (or found, but found - // deallocated). - StatusOr Resolve(const GlobalDataHandle& data); + // Resolve a handle from an XLA client to a shaped buffer, or provide an error + // status to say whether it was not found (or found, but found deallocated). + StatusOr Resolve(const GlobalDataHandle& data); private: - // Internal helper which resolves the given GlobalDataHandle to an Allocation. 
- StatusOr ResolveInternal(const GlobalDataHandle& data) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - GlobalDataHandle RegisterInternal( - Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, const Shape& shape, - const string& tag, int initial_ref_count) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - // Helper function which deallocates the memory buffer containing the given - // shape referred to by device_memory. Tuples are traversed recursively - // deallocating all nested buffers. The parameter deallocated_buffers contains - // the set of buffers deallocated so far stored as opaque values (void *) from - // DeviceMemoryBase. Keeping track of deallocated buffers prevents - // double-freeing of buffers which may be referred to more than once in a - // nested tuple. - tensorflow::Status DeallocateShape( - Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase* device_memory, const Shape& shape, - std::set* deallocated_buffers) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - // Returns the opaque_to_handle_ map for the given device_ordinal, creating - // a new map if there is not one for the device_ordinal. - using HandleMap = std::map; - HandleMap& GetOrCreateOpaqueToHandleMap(int device_ordinal) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - tensorflow::mutex allocation_mutex_; // Guards the allocation mapping. + // Data structure encapsulating a single memory allocation on the device. + struct Allocation { + // The pointer to this allocation. + perftools::gputools::DeviceMemoryBase device_memory; + + // The device that the memory is allocated on. + int device_ordinal; + + // This is the number of times this memory allocation is referred to by + // registered data handles. + int ref_count; + }; + + // Internal helper which resolves the given GlobalDataHandle to a + // ShapedBuffer. + StatusOr ResolveInternal(const GlobalDataHandle& data) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Internal helper which registers a shaped buffer. + StatusOr RegisterInternal( + std::unique_ptr shaped_buffer, const string& tag) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Adds the given device address to the allocation tracker or, if it already + // exists, increments its reference count. + void AddAllocationOrIncrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Decrements the reference count of the given device memory. Then, if it is + // zero, deallocates the memory. + Status DecrementRefCount(perftools::gputools::DeviceMemoryBase device_memory, + int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // A map from device memory opaque value to allocation. One such map is + // maintained per device ordinal. + using AllocationMap = tensorflow::gtl::FlatMap; + + tensorflow::mutex mutex_; + + // Backend to use with this tracker. The backend supplies the memory allocator + // to use when deallocating memory. + Backend* backend_; // The next handle to assign to an allocation, guarded by the same mutex as // the mapping as they'll be mutated at the same time. - int64 next_handle_ GUARDED_BY(allocation_mutex_); + int64 next_handle_ GUARDED_BY(mutex_); - // A map from DeviceMemoryBase to handle for each device_ordinal. - std::vector opaque_to_handle_ GUARDED_BY(allocation_mutex_); + // A map from device ordinal to AllocationMap. 
+ tensorflow::gtl::FlatMap opaque_to_allocation_map_ + GUARDED_BY(mutex_); - // Mapping from GlobalDataHandle handle to the corresponding registered - // Allocation object. - std::map> handle_to_allocation_ - GUARDED_BY(allocation_mutex_); + // A map from data handle to ShapedBuffer. + tensorflow::gtl::FlatMap> + handle_to_shaped_buffer_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index e956f478b8..028f827337 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -73,28 +73,6 @@ CpuExecutable::CpuExecutable( reinterpret_cast(cantFail(sym.getAddress())); } -// Given a pointer to an output buffer (following the CPU JIT calling -// conventions), mark addresses that are "live". The initial pointer itself is -// trivially live. If the shape of the buffer is a tuple, this analysis looks -// into the tuple's elements and marks them live as well (since tuples keep -// pointers to buffers) and also works recursively. address is an in-memory -// buffer address that contains some runtime XLA object. shape is its -// shape. marked_addresses is the set of live addresses to populate. -static void MarkLiveAddressesInOutput( - const void* address, const Shape& shape, - std::unordered_set* marked_addresses) { - marked_addresses->insert(address); - const uintptr_t* address_buffer = static_cast(address); - if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const uintptr_t* element_address = address_buffer + i; - const void* element = reinterpret_cast(*element_address); - MarkLiveAddressesInOutput( - element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses); - } - } -} - Status CpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, std::vector* buffers) { @@ -148,20 +126,6 @@ Status CpuExecutable::ExecuteComputeFunction( tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { - std::vector argument_buffers; - argument_buffers.reserve(arguments.size()); - for (const auto* argument : arguments) { - argument_buffers.push_back(argument->buffer(/*index=*/{})); - } - return ExecuteComputeFunction(run_options, argument_buffers, buffers, - hlo_execution_profile); -} - -Status CpuExecutable::ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: // // void function(void* result, const void* run_options, void** args_array, @@ -177,8 +141,8 @@ Status CpuExecutable::ExecuteComputeFunction( // determined by buffer analysis. 
// std::vector args_array; - for (se::DeviceMemoryBase arg_mem : arguments) { - args_array.push_back(arg_mem.opaque()); + for (const ShapedBuffer* argument : arguments) { + args_array.push_back(argument->root_buffer().opaque()); } uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -246,11 +210,23 @@ Status CpuExecutable::ExecuteComputeFunction( } static void LogLiveAddresses( - const std::unordered_set& marked_addresses) { + tensorflow::gtl::ArraySlice buffers, + const std::vector& buffers_in_result) { + if (!VLOG_IS_ON(3)) { + return; + } + + CHECK_EQ(buffers.size(), buffers_in_result.size()); + std::vector live_out_buffers; + for (int i = 0; i < buffers.size(); ++i) { + if (buffers_in_result[i]) { + live_out_buffers.push_back(buffers[i].opaque()); + } + } VLOG(3) << "Live addresses in output marking found " - << marked_addresses.size() << " addresses:\n" + << live_out_buffers.size() << " addresses:\n" << tensorflow::str_util::Join( - marked_addresses, ", ", [](string* out, const void* address) { + live_out_buffers, ", ", [](string* out, const void* address) { tensorflow::strings::StrAppend( out, tensorflow::strings::Printf("%p", address)); }); @@ -259,13 +235,12 @@ static void LogLiveAddresses( static Status DeallocateTempBuffers( DeviceMemoryAllocator* allocator, se::Stream* stream, tensorflow::gtl::ArraySlice buffers, - const std::unordered_set& marked_addresses) { - // Keep those marked live because they are referenced by the output of the - // computation and are needed by the service. They will be deallocated by the - // service. + const std::vector& buffers_in_result) { + // Keep the buffers that are live in the result because they are needed by the + // service. They will be deallocated by the service. for (size_t i = 0; i < buffers.size(); ++i) { se::DeviceMemoryBase alloc = buffers[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) { + if (!buffers_in_result[i] && !alloc.is_null()) { VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" << alloc.opaque() << "]"; TF_RETURN_IF_ERROR( @@ -276,33 +251,43 @@ static Status DeallocateTempBuffers( return Status::OK(); } -StatusOr CpuExecutable::ExecuteOnStream( +StatusOr> CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { + tensorflow::gtl::ArraySlice + allocated_buffers, + std::vector* buffers_in_result) { se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. 
- std::unordered_set marked_addresses; - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase top_level_output = buffers[result_slice.index()]; - MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), - &marked_addresses); - - LogLiveAddresses(marked_addresses); - TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, - marked_addresses)); + auto result_buffer = MakeUnique( + /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), + stream->parent()->platform(), stream->parent()->device_ordinal()); - return top_level_output; + // Copy DeviceMemoryBase values which contain the array(s) of the result into + // the respective location in ShapedBuffer which is returned to the caller. + TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer such as + // a tuple element. The source instruction should have a + // non-parameter buffer assigned. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *device_memory = buffer; + (*buffers_in_result)[buffer_index] = true; + return Status::OK(); + })); + return std::move(result_buffer); } StatusOr> CpuExecutable::ExecuteOnStream( @@ -317,67 +302,26 @@ StatusOr> CpuExecutable::ExecuteOnStream( DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - auto result_buffer = - MakeUnique(result_shape(), stream->parent()->platform(), - stream->parent()->device_ordinal()); - TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); TF_RETURN_IF_ERROR(ExecuteComputeFunction( &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - // Copy DeviceMemoryBase values which contain the array(s) of the result into - // the respective location in ShapedBuffer which is returned to the caller. std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR( - result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffers, &buffers_in_result, &result_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. - CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer - // such as a tuple element. - - // The source instruction should have a non-parameter buffer - // assigned. 
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *buffer_entry = result_buffer->mutable_buffers()->size(); - result_buffer->mutable_buffers()->push_back(buffer); - buffers_in_result[buffer_index] = true; - return Status::OK(); - })); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_buffer, + CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); // Free all buffers not in the result. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate( - stream->parent()->device_ordinal(), &alloc)); - } - } + TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, + buffers_in_result)); return std::move(result_buffer); } -StatusOr -CpuExecutable::ExecuteAsyncOnStream( +StatusOr> CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { if (hlo_profiling_enabled()) { return Unimplemented( "Asynchronous execution on stream with hlo profiling is not yet " @@ -393,29 +337,25 @@ CpuExecutable::ExecuteAsyncOnStream( TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. - std::unordered_set marked_addresses; - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase top_level_output = buffers[result_slice.index()]; - MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), - &marked_addresses); + std::vector buffers_in_result(assignment_->Allocations().size(), false); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_buffer, + CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - LogLiveAddresses(marked_addresses); + LogLiveAddresses(buffers, buffers_in_result); host_stream->EnqueueTask([this, run_options, arguments, buffers, - marked_addresses, memory_allocator, stream]() { + buffers_in_result, memory_allocator, stream]() { // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. 
TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, buffers, /*hlo_execution_profile=*/nullptr)); TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, - marked_addresses)); + buffers_in_result)); }); - return top_level_output; + return std::move(result_buffer); } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 17ee2d673e..50443a5995 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -55,21 +55,14 @@ class CpuExecutable : public Executable { std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -108,13 +101,6 @@ class CpuExecutable : public Executable { // Calls the generated function performing the computation with the given // arguments using the supplied buffers. - Status ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - tensorflow::gtl::ArraySlice - buffers, - HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -122,6 +108,18 @@ class CpuExecutable : public Executable { buffers, HloExecutionProfile* hlo_execution_profile); + // Create a ShapedBuffer for holding the result of the computation. The + // addresses (DeviceMemoryBases) are set according to buffer assignment. + // 'buffers_in_result' should point to a vector of the same size as + // 'allocated_buffers'. An element in buffers_in_result is set to true if the + // corresponding buffer is live out of the computation (and thus contained in + // the returned ShapedBuffer). + StatusOr> CreateResultShapedBuffer( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + allocated_buffers, + std::vector* buffers_in_result); + // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. 
const PointsToSet& GetRootPointsToSet() const; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index 0077e344e2..d1b88b27f0 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -376,19 +376,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { - std::vector argument_buffers(arguments.size()); - for (int i = 0; i < arguments.size(); ++i) { - argument_buffers[i] = arguments[i]->buffer(/*index=*/{}); - } - return ExecuteComputeFunctions(run_options, argument_buffers, buffers, - hlo_execution_profile); -} - -Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile) { // Allocate profiling counters for each hlo instruction that we would like to // profile. std::vector* profile_counters = nullptr; @@ -428,8 +415,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( // just copy the existing buffer into the map containing instruction // results.. if (instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie(&results, instruction, - arguments[instruction->parameter_number()].opaque()); + InsertOrDie( + &results, instruction, + arguments[instruction->parameter_number()]->root_buffer().opaque()); } else if (instruction->opcode() == HloOpcode::kConstant) { unsigned char* aligned_data = FindOrDie(aligned_constants_, instruction).get(); @@ -461,69 +449,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( return Status::OK(); } -StatusOr -ParallelCpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - VLOG(3) << "ExecuteOnStream arg size: " << arguments.size(); - if (!arguments.empty()) { - VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque(); - } - - // Allocate the temporary buffers required for the computation. - se::StreamExecutor* stream_executor = stream->parent(); - int device_ordinal = stream_executor->device_ordinal(); - int64 buffer_count = assignment_->Allocations().size(); - VLOG(3) << "temp buffer count: " << buffer_count; - - std::vector device_allocations( - assignment_->Allocations().size()); - TF_RETURN_IF_ERROR(AllocateBuffers(memory_allocator, - stream->parent()->device_ordinal(), - &device_allocations)); - - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - const BufferAllocation::Index result_index = result_slice.index(); - VLOG(3) << "result index: " << result_index; - - TF_RETURN_IF_ERROR(ExecuteComputeFunctions( - run_options, arguments, device_allocations, hlo_execution_profile)); - - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. 
- std::unordered_set marked_addresses; - MarkLiveAddressesInOutput(device_allocations[result_index].opaque(), - result_shape(), &marked_addresses); - - VLOG(3) << "Live addresses in output marking found " - << marked_addresses.size() << " addresses:\n" - << tensorflow::str_util::Join( - marked_addresses, ", ", [](string* out, const void* address) { - tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); - - // Computation is done - deallocate temp buffers. Keep those marked - // live because they are referenced by the output of the computation - // and are needed by the service. They will be deallocated by the - // service. - for (size_t i = 0; i < device_allocations.size(); ++i) { - auto alloc = device_allocations[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && - alloc.opaque() != nullptr) { - VLOG(3) << "ParallelCpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate(device_ordinal, &alloc)); - } - } - - return device_allocations[result_index]; -} - StatusOr> ParallelCpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -536,9 +461,9 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - auto result_buffer = - MakeUnique(result_shape(), stream->parent()->platform(), - stream->parent()->device_ordinal()); + auto result_buffer = MakeUnique( + /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), + stream->parent()->platform(), stream->parent()->device_ordinal()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); @@ -549,37 +474,30 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( // Copy DeviceMemoryBase values which into the respective location in // ShapedBuffer which is returned to the caller. std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR( - result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffers, &buffers_in_result, &result_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = - this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. - CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer - // such as a tuple element. - - // The source instruction should have a non-parameter buffer - // assigned. 
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *buffer_entry = result_buffer->mutable_buffers()->size(); - result_buffer->mutable_buffers()->push_back(buffer); - buffers_in_result[buffer_index] = true; - return Status::OK(); - })); + TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + + // The points to set is unambiguous so the set should be a singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer such as a + // tuple element. The source instruction should have a non-parameter + // buffer assigned. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *device_memory = buffer; + buffers_in_result[buffer_index] = true; + return Status::OK(); + })); // Free all buffers not in the result. for (size_t i = 0; i < buffers.size(); ++i) { @@ -595,10 +513,10 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( return std::move(result_buffer); } -StatusOr +StatusOr> ParallelCpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on CPU."); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index d65e3f42f3..90ac94ef92 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -59,21 +59,14 @@ class ParallelCpuExecutable : public Executable { std::unique_ptr hlo_profile_index_map); ~ParallelCpuExecutable() override {} - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -108,13 +101,6 @@ class ParallelCpuExecutable : public Executable { // Calls the generated functions in 'function_names_', performing the // computation with the given arguments using the supplied buffers. 
- Status ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - tensorflow::gtl::ArraySlice - buffers, - HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunctions( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index ad5d5ead00..c50aaec572 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -26,23 +26,23 @@ limitations under the License. namespace xla { -StatusOr> +StatusOr>> Executable::ExecuteOnStreams( tensorflow::gtl::ArraySlice run_options, tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> + tensorflow::gtl::ArraySlice> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); + std::vector> return_values(run_options.size()); + if (run_options.size() == 1) { - TF_ASSIGN_OR_RETURN(auto result, + TF_ASSIGN_OR_RETURN(return_values[0], ExecuteOnStream(&run_options[0], arguments[0], /*hlo_execution_profile=*/nullptr)); - return std::vector({result}); + return std::move(return_values); } - std::vector return_values( - run_options.size()); for (size_t i = 0; i < run_options.size(); ++i) { // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched @@ -54,7 +54,7 @@ Executable::ExecuteOnStreams( TF_RET_CHECK(options.stream() != nullptr); TF_RETURN_IF_ERROR(options.stream()->BlockHostUntilDone()); } - return return_values; + return std::move(return_values); } Status Executable::DumpSessionModule() { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index cb9ee47dc6..23864dda78 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -61,16 +61,7 @@ class Executable { // If the hlo_execution_profile is provided as non-nullptr, profiling will be // enabled. // - // Returns the device memory region that a successful execution would - // populate. - virtual StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) = 0; - - // Overload of ExecuteOnStream which returns and takes arguments as - // ShapedBuffers. Used for LocalService execution. + // Returns a shaped buffer containing the result of the computation. virtual StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -78,21 +69,19 @@ class Executable { // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. - virtual StatusOr ExecuteAsyncOnStream( + virtual StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) = 0; + tensorflow::gtl::ArraySlice arguments) = 0; // Same as ExecuteOnStream(), but runs this executable on multiple // streams. arguments[i] contains the arguments to the execution on // run_options[i]->stream() and the returned value is at index i of the // returned vector. 
- virtual StatusOr> - ExecuteOnStreams( + virtual StatusOr>> ExecuteOnStreams( tensorflow::gtl::ArraySlice run_options, tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> + tensorflow::gtl::ArraySlice> arguments); // Populates `hlo_execution_profile` from `executor`. This is implicit in any @@ -224,7 +213,7 @@ StatusOr Executable::ExecuteOnStreamWrapper( if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; stream->ThenStopTimer(timer.get()); - SE_CHECK_OK(stream->BlockHostUntilDone()); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); VLOG(1) << "done with block-host-until-done"; // Merge in run-time profile information from execution_profile. diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 74aa77b4f1..271a856efd 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -51,83 +51,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { return platform_id_; } -Status GenericTransferManager::TransferLiteralFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& device_shape, const Shape& literal_shape, Literal* literal) { - VLOG(2) << "transferring literal shape from device: " - << ShapeUtil::HumanString(literal_shape) - << "; device location: " << source.opaque(); - TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape)); - - // Tuples are a special case and contain one or more shapes inside of them to - // an arbitrary nesting depth. - if (device_shape.element_type() == TUPLE) { - *literal->mutable_shape() = literal_shape; - TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - ShallowCopyTupleFromDevice(executor, source, device_shape)); - TF_RET_CHECK(element_buffers.size() == - ShapeUtil::TupleElementCount(device_shape)); - for (int64 i = 0; i < element_buffers.size(); ++i) { - const Shape& element_device_shape = device_shape.tuple_shapes(i); - const Shape& element_literal_shape = literal_shape.tuple_shapes(i); - Literal* element_literal = literal->add_tuple_literals(); - // Recursively call TransferFromDevice to copy over the data in the - // element array. - TF_RETURN_IF_ERROR(TransferLiteralFromDevice( - executor, element_buffers[i], /*device_shape=*/element_device_shape, - /*literal_shape=*/element_literal_shape, element_literal)); - } - return Status::OK(); - } - - *literal->mutable_shape() = device_shape; - literal->Reserve(ShapeUtil::ElementsIn(device_shape)); - TF_RETURN_IF_ERROR(TransferBufferFromDevice( - executor, source, /*size=*/ShapeUtil::ByteSizeOf(device_shape), - /*destination=*/literal->MutableInternalData())); - if (!ShapeUtil::Equal(literal_shape, device_shape)) { - *literal = std::move(*literal->Relayout(literal_shape.layout())); - } - TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); - return Status::OK(); -} - -StatusOr> -GenericTransferManager::ShallowCopyTupleFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsTuple(shape)); - - // For devices which use the GenericTransferManager, a tuple is stored as an - // array of pointers to buffers. Copy the contents of the tuple buffer into - // a vector of void* pointers. 
- std::vector element_pointers(ShapeUtil::TupleElementCount(shape), - nullptr); - int64 tuple_size = ShapeUtil::ByteSizeOf(shape, pointer_size_); - auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size, - element_pointers.data()); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape)); - } - - // Create a DeviceMemoryBase from each void* pointer. - std::vector destination; - for (size_t i = 0; i < element_pointers.size(); ++i) { - if (element_pointers[i] == nullptr && - !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { - return FailedPrecondition("tuple contains nullptr at element %lu", i); - } - destination.emplace_back(element_pointers[i], - GetByteSizeRequirement(shape.tuple_shapes(i))); - } - return std::move(destination); -} - -Status GenericTransferManager::WriteTuplePointersToDevice( +Status GenericTransferManager::WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, const Shape& shape, perftools::gputools::DeviceMemoryBase* region) { @@ -145,16 +69,19 @@ StatusOr> GenericTransferManager::TransferLiteralFromDevice( se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { VLOG(2) << "transferring literal from device ordinal " - << executor->device_ordinal() << "; device shape: " - << ShapeUtil::HumanStringWithLayout(device_buffer.shape()) - << "; opaque: " << device_buffer.buffer(/*index=*/{}).opaque(); + << executor->device_ordinal() << "; device buffer: " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + // The on-host and on-device shape should always be the same for the generic + // transfer manager. + TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), + device_buffer.on_host_shape())); + std::unique_ptr literal = - Literal::CreateFromShape(device_buffer.shape()); + Literal::CreateFromShape(device_buffer.on_host_shape()); TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { if (!ShapeUtil::IsTuple(subshape)) { TF_RETURN_IF_ERROR(TransferBufferFromDevice( @@ -175,16 +102,22 @@ Status GenericTransferManager::TransferLiteralToDevice( const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " - << ShapeUtil::HumanString(shape) << "; device location: " - << device_buffer.buffer(/*index=*/{}).opaque(); + << ShapeUtil::HumanString(shape) + << "; device buffer: " << device_buffer; + + // The on-host and on-device shape should always be the same for the generic + // transfer manager. 
+ TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), + device_buffer.on_host_shape())); - TF_RET_CHECK(ShapeUtil::Compatible(literal.shape(), device_buffer.shape())); + TF_RET_CHECK( + ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape())); TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_host_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); if (ShapeUtil::IsArray(device_subshape)) { @@ -212,33 +145,6 @@ Status GenericTransferManager::TransferLiteralToDevice( }); } -Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const Literal& literal, - se::DeviceMemoryBase* destination) { - const Shape& shape = literal.shape(); - VLOG(2) << "transferring literal shape to device: " - << ShapeUtil::HumanString(shape) - << "; device location: " << destination->opaque(); - - if (ShapeUtil::IsTuple(literal.shape())) { - std::vector tuple_elements_on_device; - for (const Literal& tuple_element : literal.tuple_literals()) { - se::DeviceMemoryBase allocation = executor->AllocateArray( - GetByteSizeRequirement(tuple_element.shape())); - TF_RETURN_IF_ERROR( - TransferLiteralToDevice(executor, tuple_element, &allocation)); - tuple_elements_on_device.push_back(allocation.opaque()); - } - return TransferBufferToDevice( - executor, tuple_elements_on_device.size() * sizeof(void*), - tuple_elements_on_device.data(), destination); - } - - return TransferBufferToDevice(executor, - /*size=*/GetByteSizeRequirement(shape), - /*source=*/literal.InternalData(), destination); -} - Status GenericTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const Literal& literal) { return Unimplemented("Generic transfer to Infeed"); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 50dca6aec5..63a7c820cf 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -42,16 +42,6 @@ class GenericTransferManager : public TransferManager { perftools::gputools::Platform::Id PlatformId() const override; - Status TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& device_shape, const Shape& literal_shape, - Literal* literal) override; - - Status TransferLiteralToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - perftools::gputools::DeviceMemoryBase* destination) override; - StatusOr> TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; @@ -62,9 +52,6 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; - Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, - int64 size, const void* source) override; - Status TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) override; @@ -73,16 +60,13 @@ class GenericTransferManager : public TransferManager { tensorflow::gtl::ArraySlice executors) override; - StatusOr> - 
ShallowCopyTupleFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& shape) override; - int64 GetByteSizeRequirement(const Shape& shape) const override; protected: - Status WriteTuplePointersToDevice( + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; + + Status WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index b802ae9c7a..366d87e9c3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -203,84 +203,6 @@ Status GpuExecutable::ExecuteThunks( return Status::OK(); } -StatusOr GpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - - BufferAllocations::Builder buffer_allocations_builder; - for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); - ++i) { - const BufferAllocation& allocation = assignment_->GetAllocation(i); - if (allocation.is_entry_computation_parameter()) { - buffer_allocations_builder.RegisterBuffer( - i, arguments[allocation.parameter_number()]); - } - } - se::StreamExecutor* executor = stream->parent(); - TF_ASSIGN_OR_RETURN( - auto buffer_allocations, - buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), - memory_allocator)); - - bool block_host_until_done = - !memory_allocator->AllowsAsynchronousDeallocation(); - TF_RETURN_IF_ERROR(ExecuteThunks(run_options, *buffer_allocations, - block_host_until_done, - hlo_execution_profile)); - - HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice output_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase output_buffer_address = - buffer_allocations->GetDeviceAddress(output_slice.index()); - - if (ShapeUtil::IsTuple(root->shape())) { - std::set referred_by_output; - if (GetRootPointsToSet().IsAmbiguous()) { - // The points-to set of the root is ambiguous so we need to examine the - // result data to determine which buffers are contained in the result. - TF_ASSIGN_OR_RETURN( - TransferManager * transfer_manager, - TransferManager::GetForPlatform(executor->platform())); - TF_ASSIGN_OR_RETURN(referred_by_output, - transfer_manager->GatherBufferPointersFromTuple( - executor, output_buffer_address, root->shape())); - } else { - // The points-to set of the root is unambiguous so it's known statically - // which buffers are in the result. Gather these buffers using the root's - // points-to set. - TF_RETURN_IF_ERROR(GetRootPointsToSet().ForEachElementWithStatus( - [&referred_by_output, &buffer_allocations, this]( - const ShapeIndex& /*index*/, - const PointsToSet::BufferList& buffers) { - // The points to set is unambiguous so the set should be a - // singleton. That is, we know exactly which instruction produced - // the array at this element. 
- CHECK_EQ(1, buffers.size()); - HloInstruction* hlo = buffers[0]->instruction(); - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice(hlo, buffers[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - referred_by_output.insert( - buffer_allocations->GetDeviceAddress(slice.index())); - return Status::OK(); - })); - } - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown(referred_by_output, *assignment_)); - } else { - // If the computation result is not a tuple, we can delete all temporary - // buffers that are not the output. - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown({output_buffer_address}, *assignment_)); - } - return output_buffer_address; -} - StatusOr> GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -298,7 +220,7 @@ StatusOr> GpuExecutable::ExecuteOnStream( if (allocation.is_entry_computation_parameter()) { auto param_no = allocation.parameter_number(); buffer_allocations_builder.RegisterBuffer( - i, arguments[param_no]->buffer(/*index=*/{})); + i, arguments[param_no]->root_buffer()); } } se::StreamExecutor* executor = run_options->stream()->parent(); @@ -316,50 +238,46 @@ StatusOr> GpuExecutable::ExecuteOnStream( HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); auto device_ordinal = executor->device_ordinal(); auto shaped_buffer = MakeUnique( - root->shape(), executor->platform(), device_ordinal); + root->shape(), root->shape(), executor->platform(), device_ordinal); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer. std::set buffers_in_result; - TF_RETURN_IF_ERROR( - shaped_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = this->GetRootPointsToSet().element(index); - // The points-to set is unambiguous so the set should be a - // singleton. That is, we know exactly which instruction - // produced the array at this element. - CHECK_EQ(1, sources.size()); - auto src_hlo = sources[0]->instruction(); - - VLOG(4) << "Looking at: " << sources[0]; - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src_hlo, sources[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - perftools::gputools::DeviceMemoryBase src_base = - buffer_allocations->GetDeviceAddress(slice.index()); - CHECK(!src_base.is_null() || src_base.size() == 0); - shaped_buffer->mutable_buffers()->push_back(src_base); - *buffer_entry = shaped_buffer->mutable_buffers()->size() - 1; - - buffers_in_result.insert(src_base); - return Status::OK(); - })); + TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( + [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( + const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + // The points-to set is unambiguous so the set should be a + // singleton. That is, we know exactly which instruction + // produced the array at this element. 
+ CHECK_EQ(1, sources.size()); + auto src_hlo = sources[0]->instruction(); + + VLOG(4) << "Looking at: " << sources[0]; + + // The source instruction should have a non-parameter buffer + // assigned. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src_hlo, sources[0]->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + perftools::gputools::DeviceMemoryBase src_base = + buffer_allocations->GetDeviceAddress(slice.index()); + CHECK(!src_base.is_null() || src_base.size() == 0); + *device_memory = src_base; + buffers_in_result.insert(src_base); + return Status::OK(); + })); TF_RETURN_IF_ERROR( buffer_allocations->TearDown(buffers_in_result, *assignment_)); return std::move(shaped_buffer); } -StatusOr GpuExecutable::ExecuteAsyncOnStream( +StatusOr> GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on GPU."); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index e7307e07c0..00da64dfad 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -72,24 +72,16 @@ class GpuExecutable : public Executable { // empty, in which case compilation is left up to the GPU driver. const std::vector& cubin() const { return cubin_; } - // Both overloads of ExecuteOnStream will fail if the compute capability of - // the stream doesn't match the compute capability passed to this object's - // constructor. - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - + // ExecuteOnStream will fail if the compute capability of the stream doesn't + // match the compute capability passed to this object's constructor. StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; const Status EqualOrFail(const Executable& executable) { // TODO(b/62952745) Implement equality test on GPU executable. diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index a6101bbe60..7b3a8cef97 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -112,17 +112,12 @@ HloRunner::HloRunner(se::Platform* platform) { VLOG(1) << "Created HloRunner for platform: " << platform->Name(); } -HloRunner::~HloRunner() { - // Deallocate all the memory allocated during the tests. 
- for (auto& allocation : allocations_) { - backend().default_stream_executor()->Deallocate(&allocation); - } -} +HloRunner::~HloRunner() {} -StatusOr HloRunner::Execute( +StatusOr> HloRunner::ExecuteInternal( std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments, - Shape* result_shape, bool run_hlo_passes) { + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes) { if (run_hlo_passes) { TF_ASSIGN_OR_RETURN( module, backend().compiler()->RunHloPasses( @@ -137,6 +132,7 @@ StatusOr HloRunner::Execute( stream.Init(); ExecutableRunOptions run_options; + run_options.set_device_ordinal(backend().default_device_ordinal()); run_options.set_stream(&stream); run_options.set_allocator(backend().memory_allocator()); run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); @@ -146,73 +142,35 @@ StatusOr HloRunner::Execute( ServiceExecutableRunOptions service_run_options( run_options, backend().StreamBorrower(), backend().inter_op_thread_pool()); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase result, - executable->ExecuteOnStream(&service_run_options, arguments, - /*hlo_execution_profile=*/nullptr)); - TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); - allocations_.push_back(result); - - *result_shape = executable->result_shape(); - - if (ShapeUtil::IsTuple(*result_shape)) { - // We must record element buffers of tuples as well to avoid leaks. - DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); + // Copy arguments to device. + std::vector> argument_buffers; + std::vector argument_buffer_ptrs; + for (Literal* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - backend().transfer_manager()->ShallowCopyTupleFromDevice( - backend().default_stream_executor(), result, *result_shape)); - - // A tuple may contain the same buffer in more than one element. Keep track - // of the buffers already added to avoid duplicates in allocations_. - std::set added_opaques; - for (auto element_buffer : element_buffers) { - if (added_opaques.count(element_buffer.opaque()) == 0) { - CHECK(element_buffer.opaque() != nullptr); - added_opaques.insert(element_buffer.opaque()); - allocations_.push_back(element_buffer); - } - } + std::unique_ptr argument_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + argument->shape(), run_options.allocator(), + run_options.device_ordinal())); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + stream.parent(), *argument, *argument_buffer)); + argument_buffers.push_back(std::move(argument_buffer)); + argument_buffer_ptrs.push_back(argument_buffers.back().get()); } - return result; -} - -StatusOr HloRunner::TransferToDevice( - const Literal& literal) { - // Allocate memory on the device using the stream executor. 
- int64 allocation_size = - backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); - se::DeviceMemoryBase allocation = - backend().default_stream_executor()->AllocateArray( - allocation_size); - allocations_.push_back(allocation); - - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, &allocation)); - - return allocation; -} - -StatusOr> HloRunner::TransferFromDevice( - const Shape& shape, se::DeviceMemoryBase device_base) { - auto literal = MakeUnique(); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), device_base, shape, shape, - literal.get())); - return std::move(literal); -} + TF_ASSIGN_OR_RETURN( + std::unique_ptr result, + executable->ExecuteOnStream(&service_run_options, argument_buffer_ptrs, + /*hlo_execution_profile=*/nullptr)); -StatusOr> HloRunner::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - Shape result_shape; + // Create a ScopedShapedBuffer of the result to manage deallocation. This will + // deallocate all the device memory when it goes out of scope. TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase device_base, - Execute(std::move(module), arguments, &result_shape, run_hlo_passes)); - return TransferFromDevice(result_shape, device_base); + std::unique_ptr scoped_result, + ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator())); + + return backend().transfer_manager()->TransferLiteralFromDevice( + stream.parent(), *scoped_result); } Backend& HloRunner::backend() { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index a65c66fd4b..d4b221fb52 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -78,30 +78,7 @@ class HloRunner { template StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals, - bool run_hlo_passes = true); - - // Executes the given module and returns a global data handle. - StatusOr Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape, bool run_hlo_passes = true); - - // Transfers the given literal to the device and returns the data handle. - StatusOr TransferToDevice( - const Literal& literal); - - // Transfers the array referred to by the given handle from the device and - // returns as a Literal. - StatusOr> TransferFromDevice( - const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); - - // Executes the given module and return the result as a Literal. 
- StatusOr> ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes = true); // If backend is not created in the constructor, creates and returns the @@ -112,9 +89,12 @@ class HloRunner { Backend& backend(); private: - struct EigenThreadPoolWrapper; + StatusOr> ExecuteInternal( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true); - std::vector allocations_; + struct EigenThreadPoolWrapper; std::unique_ptr thread_pool_wrapper_; @@ -124,15 +104,14 @@ class HloRunner { template StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals, + const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes) { - std::vector arguments; - for (const auto& literal : literals) { - TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument, - TransferToDevice(*literal)); - arguments.push_back(argument); + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + for (const auto& argument : arguments) { + argument_pointers.push_back(&*argument); } - return ExecuteAndTransfer(std::move(module), arguments, run_hlo_passes); + return ExecuteInternal(std::move(module), argument_pointers, run_hlo_passes); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 2704a805a9..0819ab3b90 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -92,6 +92,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 293cc2007e..b01fcccdb4 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -47,44 +48,18 @@ InterpreterExecutable::InterpreterExecutable( InterpreterExecutable::~InterpreterExecutable() {} -static se::DeviceMemoryBase AllocateSingleOutput( - sep::InterpreterExecutor* executor, const Literal& literal) { - int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape())); - void* buf = executor->Allocate(size); - const void* src = literal.InternalData(); - memcpy(buf, src, size); - return se::DeviceMemoryBase(buf, size); -} - -static se::DeviceMemoryBase AllocateOutputBuffer( - sep::InterpreterExecutor* executor, const Literal& literal) { - const Shape& shape = literal.shape(); - if (shape.element_type() != xla::TUPLE) { - return AllocateSingleOutput(executor, literal); - } else { - int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*))); - void** buf = reinterpret_cast(executor->Allocate(size)); - void** buf_rc = buf; - for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) { - se::DeviceMemoryBase out = - AllocateSingleOutput(executor, literal.tuple_literals(n)); - *buf++ = out.opaque(); - } - - return se::DeviceMemoryBase(buf_rc, size); - } -} - -StatusOr InterpreterExecutable::ExecuteOnStream( +StatusOr> InterpreterExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); + se::StreamExecutor* executor = stream->parent(); + const se::Platform* platform = executor->platform(); VLOG(1) << "Execute " << module().name(); if (VLOG_IS_ON(2)) { for (const auto& a : arguments) { - VLOG(2) << "-- argument " << a.opaque(); + VLOG(2) << "-- argument " << *a; } } @@ -96,32 +71,32 @@ StatusOr InterpreterExecutable::ExecuteOnStream( "Mismatch between argument count and graph parameter count."); } - // Create the arguments as an vector of XLA literals + TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, + TransferManager::GetForPlatform(platform)); + + // Transform the ShapedBuffer arguments into literals which the evaluator + // consumes. std::vector> arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - // Create the input literal for the parameter - HloInstruction* param = computation->parameter_instruction(p); - arg_literals.emplace_back(Literal::CreateFromShape(param->shape())); - - // Copy in the data from the stream_executor buffers - void* buffer = arg_literals.back()->MutableInternalData(); - memcpy(buffer, arguments[p].opaque(), - ShapeUtil::ByteSizeOf(param->shape())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr arg_literal, + transfer_manager->TransferLiteralFromDevice(executor, *arguments[p])); + arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. 
HloEvaluator evaluator; TF_ASSIGN_OR_RETURN( - std::unique_ptr output, + std::unique_ptr result_literal, evaluator.Evaluate>(*computation, arg_literals)); - // Copy the result into the return buffer - perftools::gputools::StreamExecutor* executor(stream->parent()); - sep::InterpreterExecutor* interpreter_executor( - static_cast(executor->implementation())); - - se::DeviceMemoryBase ret = - AllocateOutputBuffer(interpreter_executor, *(output.get())); + // Transform the result literal back into a ShapedBuffer. + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + transfer_manager->AllocateShapedBuffer( + result_literal->shape(), run_options->allocator(), + run_options->device_ordinal())); + TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( + executor, *result_literal, *result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -131,20 +106,13 @@ StatusOr InterpreterExecutable::ExecuteOnStream( execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); } - return ret; -} - -StatusOr> InterpreterExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - return tensorflow::errors::Unimplemented( - "ExecuteOnStream is not yet supported on Interpreter."); + return std::move(result); } -StatusOr InterpreterExecutable::ExecuteAsyncOnStream( +StatusOr> +InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { return tensorflow::errors::Unimplemented( "ExecuteAsyncOnStream is not yet supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 0e87eb90bf..410110a1ad 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -43,21 +43,14 @@ class InterpreterExecutable : public Executable { InterpreterExecutable(std::unique_ptr hlo_module); ~InterpreterExecutable() override; - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; static int64 ShapeSizeBytes(const Shape& shape); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 06f43bd3cb..4071b948a5 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -118,10 +118,8 @@ StatusOr> LocalService::CompileExecutable( TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, execute_backend_->stream_executor(device_ordinal)); - std::vector argument_buffers( - argument_layouts.size()); return BuildExecutable(versioned_handle, std::move(module_config), - argument_buffers, execute_backend_.get(), executor); + execute_backend_.get(), executor); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 9d78e6a2b2..e77a46128b 100644 
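The interpreter executable above now round-trips data through the TransferManager: ShapedBuffer arguments are copied back to host Literals, the HloEvaluator runs on those literals, and the result literal is written into a freshly allocated ShapedBuffer. A minimal sketch of that pattern, using only calls visible in this patch; the helper name, parameters, and include set are illustrative rather than part of the change.

#include <memory>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {
namespace se = ::perftools::gputools;

// Illustrative only: copies a device-resident argument back to the host as a
// Literal, then copies a host Literal into a freshly allocated device buffer,
// using the ShapedBuffer-based TransferManager interface from this change.
StatusOr<std::unique_ptr<ShapedBuffer>> RoundTripLiteral(
    se::StreamExecutor* executor, DeviceMemoryAllocator* allocator,
    int device_ordinal, const ShapedBuffer& argument) {
  TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager,
                      TransferManager::GetForPlatform(executor->platform()));

  // Device -> host: the returned Literal has the buffer's on-host shape.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Literal> literal,
      transfer_manager->TransferLiteralFromDevice(executor, argument));

  // Host -> device: allocate a ShapedBuffer (including any tuple index
  // tables) and copy the literal into it.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> device_buffer,
                      transfer_manager->AllocateShapedBuffer(
                          literal->shape(), allocator, device_ordinal));
  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
      executor, *literal, *device_buffer));
  return std::move(device_buffer);
}

}  // namespace xla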
--- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -60,41 +60,32 @@ namespace xla { namespace { -// Copies the contents of an Allocation into a Literal proto. -tensorflow::Status LiteralFromAllocation(const Allocation* allocation, - const Shape& literal_shape, - Literal* literal) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - allocation->backend()->stream_executor(allocation->device_ordinal())); - return allocation->backend()->transfer_manager()->TransferLiteralFromDevice( - executor, allocation->device_memory(), allocation->shape(), literal_shape, - literal); -} - // Records the arguments used to invoke a computation in a SessionModule // proto. tensorflow::Status RecordArguments( - const tensorflow::gtl::ArraySlice arg_allocations, + const tensorflow::gtl::ArraySlice arguments, + se::StreamExecutor* executor, TransferManager* transfer_manager, SessionModule* module) { module->clear_arguments(); - for (const Allocation* allocation : arg_allocations) { - Literal argument; - TF_RETURN_IF_ERROR( - LiteralFromAllocation(allocation, allocation->shape(), &argument)); - *module->add_arguments() = argument.ToProto(); + for (const ShapedBuffer* argument : arguments) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager->TransferLiteralFromDevice(executor, *argument)); + *module->add_arguments() = literal->ToProto(); } return tensorflow::Status::OK(); } // Records the result of a computation in a SessionModule proto. -tensorflow::Status RecordResult(const Allocation* result_allocation, +tensorflow::Status RecordResult(const ShapedBuffer& result, + se::StreamExecutor* executor, + TransferManager* transfer_manager, SessionModule* module) { module->clear_result(); - Literal result; - TF_RETURN_IF_ERROR(LiteralFromAllocation( - result_allocation, result_allocation->shape(), &result)); - *module->mutable_result() = result.ToProto(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager->TransferLiteralFromDevice(executor, result)); + *module->mutable_result() = literal->ToProto(); return tensorflow::Status::OK(); } @@ -152,7 +143,9 @@ int ServiceOptions::intra_op_parallelism_threads() const { Service::Service(const ServiceOptions& options, std::unique_ptr execute_backend) - : options_(options), execute_backend_(std::move(execute_backend)) { + : options_(options), + allocation_tracker_(execute_backend.get()), + execute_backend_(std::move(execute_backend)) { CHECK_GT(options_.number_of_replicas(), 0); if (execute_backend_) { if (execute_backend_->device_count() > 0) { @@ -235,35 +228,33 @@ tensorflow::Status Service::ValidateResultShapeWithLayout( return ShapeUtil::ValidateShape(shape_with_layout); } -StatusOr> Service::ResolveAndValidateArguments( +StatusOr> Service::ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - const Backend* backend, int device_ordinal) { - std::vector allocations; + int device_ordinal) { + std::vector shaped_buffers; for (size_t i = 0; i < arguments.size(); ++i) { - auto allocation_status = allocation_tracker_.Resolve(*arguments[i]); - if (!allocation_status.ok()) { - return Status(allocation_status.status().code(), - StrCat(allocation_status.status().error_message(), ", ", + auto buffer_status = allocation_tracker_.Resolve(*arguments[i]); + if (!buffer_status.ok()) { + return Status(buffer_status.status().code(), + StrCat(buffer_status.status().error_message(), ", ", "failed to resolve allocation for parameter ", i)); } - const Allocation* allocation = 
allocation_status.ValueOrDie(); + const ShapedBuffer* shaped_buffer = buffer_status.ValueOrDie(); // Verify allocation is same platform and device as the execution. - if (allocation->backend() != backend || - allocation->device_ordinal() != device_ordinal) { + if (shaped_buffer->platform() != execute_backend_->platform() || + shaped_buffer->device_ordinal() != device_ordinal) { return InvalidArgument( - "argument %lu is on device %s but computation will be executed " + "argument %lu is on device %s:%d but computation will be executed " "on device %s", - i, - allocation->backend() - ->device_name(allocation->device_ordinal()) - .c_str(), - backend->device_name(device_ordinal).c_str()); + i, shaped_buffer->platform()->Name().c_str(), + shaped_buffer->device_ordinal(), + execute_backend_->device_name(device_ordinal).c_str()); } - allocations.push_back(allocation); + shaped_buffers.push_back(shaped_buffer); } - return allocations; + return shaped_buffers; } StatusOr> Service::CreateModuleConfig( @@ -325,11 +316,11 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { - argument_shapes.push_back(&arg->shape()); + argument_shapes.push_back(&arg->on_host_shape()); } return CreateModuleConfig(program_shape, argument_shapes, &execution_options); } @@ -398,8 +389,6 @@ StatusOr>> Service::BuildExecutables( StatusOr> Service::BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, se::StreamExecutor* executor) { VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, versioned_handle.ToString().c_str()); @@ -447,8 +436,6 @@ StatusOr> Service::BuildExecutable( StatusOr> Service::BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile) { std::shared_ptr executable = @@ -471,8 +458,8 @@ StatusOr> Service::BuildAndCacheExecutable( HloModuleConfig original_module_config = *module_config; TF_ASSIGN_OR_RETURN( std::unique_ptr executable_unique_ptr, - BuildExecutable(versioned_handle, std::move(module_config), arguments, - backend, executor)); + BuildExecutable(versioned_handle, std::move(module_config), backend, + executor)); if (profile != nullptr) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -489,9 +476,7 @@ StatusOr> Service::BuildAndCacheExecutable( StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice< - std::vector> - arguments, + tensorflow::gtl::ArraySlice> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile) { @@ -547,7 +532,7 @@ Service::ExecuteParallelAndRegisterResult( // Asynchronously launch the computation. 
TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase result, + std::unique_ptr result, executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); if (replica == 0 && profile != nullptr) { @@ -557,9 +542,10 @@ Service::ExecuteParallelAndRegisterResult( // All replicas share the same device address for the result allocation, // so only one of the replicas need to register the result handle. if (replica == 0) { - result_handles.push_back(allocation_tracker_.Register( - backend, replicas[0]->device_ordinal(), result, - executables[i]->result_shape(), result_tags[i])); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle handle, + allocation_tracker_.Register(std::move(result), result_tags[i])); + result_handles.push_back(handle); } } } @@ -627,8 +613,7 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile) { // Set up streams. @@ -653,6 +638,7 @@ StatusOr Service::ExecuteAndRegisterResult( for (const Pool::SmartPtr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); + options.set_device_ordinal(stream->parent()->device_ordinal()); options.set_allocator(backend->memory_allocator()); options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( @@ -662,25 +648,23 @@ StatusOr Service::ExecuteAndRegisterResult( backend->inter_op_thread_pool()); } - perftools::gputools::DeviceMemoryBase result; + std::unique_ptr result; if (options_.number_of_replicas() == 1) { TF_ASSIGN_OR_RETURN( - result, executable->ExecuteOnStreamWrapper( - &run_options[0], profile, arguments)); + result, + executable->ExecuteOnStreamWrapper>( + &run_options[0], profile, arguments)); } else { // TODO(b/69985541): Support profiling also on this path. - std::vector< - tensorflow::gtl::ArraySlice> + std::vector> repeated_arguments(options_.number_of_replicas(), arguments); TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( run_options, repeated_arguments)); TF_RET_CHECK(!results.empty()); - result = results[0]; + result = std::move(results[0]); } - return allocation_tracker_.Register(backend, executor->device_ordinal(), - result, executable->result_shape(), - result_tag); + return allocation_tracker_.Register(std::move(result), result_tag); } tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, @@ -694,7 +678,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - std::vector> all_arguments; + std::vector> all_arguments; std::vector> all_executors; std::vector versioned_handles; std::vector> module_configs; @@ -751,19 +735,14 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // In the case of partitioned computations, assume all arguments go on the // zeroth core. 
TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(request.arguments(), executors[0]->device_ordinal())); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } // Create an HloModuleConfig object for the computation, given the shape of // the program and the argument allocations. TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, + CreateModuleConfig(*program_shape, arguments, request.execution_options())); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -866,35 +845,30 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, user_computation->ComputeProgramShape(versioned_handle.version)); TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(arg->arguments(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arguments, arg->execution_options())); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } - TF_ASSIGN_OR_RETURN( std::shared_ptr executable, BuildAndCacheExecutable(versioned_handle, std::move(module_config), - arguments, execute_backend_.get(), + execute_backend_.get(), execute_backend_->default_stream_executor(), result->mutable_profile())); if (executable->dumping()) { executable->session_module()->set_execution_platform( execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR( - RecordArguments(arg_allocations, executable->session_module())); + TF_RETURN_IF_ERROR(RecordArguments( + arguments, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->session_module())); } TF_ASSIGN_OR_RETURN( @@ -905,10 +879,11 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, "result of " + user_computation->name(), result->mutable_profile())); if (executable->dumping()) { - TF_ASSIGN_OR_RETURN(const Allocation* result_allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, allocation_tracker_.Resolve(result->output())); - TF_RETURN_IF_ERROR( - RecordResult(result_allocation, executable->session_module())); + TF_RETURN_IF_ERROR(RecordResult( + *result_buffer, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->session_module())); TF_RETURN_IF_ERROR(executable->DumpSessionModule()); } @@ -934,31 +909,24 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, user_computation->ComputeProgramShape(versioned_handle.version)); TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(arg->arguments(), execute_backend_->default_device_ordinal())); - 
TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arguments, arg->execution_options())); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } - ExecutionProfile profile; TF_ASSIGN_OR_RETURN( std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - arguments, execute_backend_.get(), - execute_backend_->default_stream_executor(), - &profile)); + BuildAndCacheExecutable( + versioned_handle, std::move(module_config), execute_backend_.get(), + execute_backend_->default_stream_executor(), &profile)); TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); @@ -973,7 +941,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, streams.push_back(std::move(stream)); } - perftools::gputools::DeviceMemoryBase result_data; + std::unique_ptr result_buffer; for (const Pool::SmartPtr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); @@ -986,19 +954,19 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, options, execute_backend_->StreamBorrower()); TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase this_result_data, + std::unique_ptr this_result_buffer, executable->ExecuteAsyncOnStream(&service_options, arguments)); // Take the first result. - if (result_data == nullptr) { - result_data = this_result_data; + if (result_buffer == nullptr) { + result_buffer = std::move(this_result_buffer); } } - auto output = allocation_tracker_.Register( - execute_backend_.get(), execute_backend_->default_device_ordinal(), - result_data, executable->result_shape(), - "result of " + user_computation->name()); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle output, + allocation_tracker_.Register(std::move(result_buffer), + "result of " + user_computation->name())); *result->mutable_execution() = execution_tracker_.Register( execute_backend_.get(), std::move(streams), profile, output); @@ -1025,23 +993,35 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, TransferToClientResponse* result) { - TF_ASSIGN_OR_RETURN(const Allocation* allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.Resolve(arg->data())); - const Shape* literal_shape; + const Shape* return_shape; if (arg->has_shape_with_layout()) { if (!LayoutUtil::HasLayout(arg->shape_with_layout())) { return InvalidArgument("shape_with_layout must have layout if present."); } - literal_shape = &arg->shape_with_layout(); + return_shape = &arg->shape_with_layout(); } else { - literal_shape = &allocation->shape(); + return_shape = &shaped_buffer->on_host_shape(); } - Literal literal; - TF_RETURN_IF_ERROR( - LiteralFromAllocation(allocation, *literal_shape, &literal)); - *result->mutable_literal() = literal.ToProto(); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + execute_backend_->stream_executor(shaped_buffer->device_ordinal())); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_literal, + 
execute_backend_->transfer_manager()->TransferLiteralFromDevice( + executor, *shaped_buffer)); + + if (LayoutUtil::LayoutsInShapesEqual(*return_shape, + result_literal->shape())) { + *result->mutable_literal() = result_literal->ToProto(); + } else { + *result->mutable_literal() = + result_literal->Relayout(*return_shape)->ToProto(); + } return tensorflow::Status::OK(); } @@ -1052,12 +1032,9 @@ namespace { std::unique_ptr CloneShapedBufferOnDevice( const ShapedBuffer& shaped_buffer, int device_ordinal) { auto clone = MakeUnique( - shaped_buffer.shape(), shaped_buffer.platform(), device_ordinal); - ShapeUtil::ForEachSubshape( - shaped_buffer.shape(), [&clone, &shaped_buffer](const Shape& /*subshape*/, - const ShapeIndex& index) { - clone->AddBufferAtIndex(shaped_buffer.buffer(index), index); - }); + shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), + shaped_buffer.platform(), device_ordinal); + clone->buffers() = shaped_buffer.buffers(); return clone; } @@ -1082,22 +1059,8 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, int master_device_ordinal = replicas[0]->device_ordinal(); TF_ASSIGN_OR_RETURN( std::unique_ptr shaped_buffer, - ShapedBuffer::Allocate( - execute_backend_->transfer_manager()->HostShapeToDeviceShape(shape), - execute_backend_->memory_allocator(), master_device_ordinal, - [this](const Shape& shape) { - return execute_backend_->transfer_manager()->GetByteSizeRequirement( - shape); - })); - - // The allocation tracker only keeps track of the top-level buffer of the - // shape so pass in the buffer at shape index {}. - // TODO(b/37515654): Allocation tracker should hold a ShapedBuffer. - *result->mutable_data() = allocation_tracker_.Register( - execute_backend_.get(), master_device_ordinal, - shaped_buffer->buffer(/*index=*/{}), shape, - StrCat("TransferToServer literal of shape ", - ShapeUtil::HumanString(shape))); + execute_backend_->transfer_manager()->AllocateShapedBuffer( + shape, execute_backend_->memory_allocator(), master_device_ordinal)); // Transfer the data to the replicas. for (se::StreamExecutor* executor : replicas) { @@ -1117,6 +1080,12 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, executor, literal, *clone)); } } + TF_ASSIGN_OR_RETURN( + *result->mutable_data(), + allocation_tracker_.Register(std::move(shaped_buffer), + StrCat("TransferToServer literal of shape ", + ShapeUtil::HumanString(shape)))); + return tensorflow::Status::OK(); } @@ -1282,9 +1251,9 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, tensorflow::Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { - TF_ASSIGN_OR_RETURN(const Allocation* allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.Resolve(arg->data())); - *result->mutable_shape() = allocation->shape(); + *result->mutable_shape() = buffer->on_host_shape(); return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 47f4f0ade5..f962d0cdc7 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -250,7 +250,7 @@ class Service : public ServiceInterface { // class. 
StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options); protected: @@ -265,10 +265,10 @@ class Service : public ServiceInterface { // Resolves the given argument handles in the allocation tracker and returns // the corresponding allocations. The function also verifies that each - // allocation matches the given backend and device ordinal. - StatusOr> ResolveAndValidateArguments( + // allocation matches the execution platform and device ordinal. + StatusOr> ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - const Backend* backend, int device_ordinal); + int device_ordinal); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. @@ -281,8 +281,6 @@ class Service : public ServiceInterface { StatusOr> BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, perftools::gputools::StreamExecutor* executor); // Same as BuildExecutable() above, but builds a list of Executables for the @@ -299,8 +297,6 @@ class Service : public ServiceInterface { StatusOr> BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile); @@ -310,8 +306,7 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile); @@ -320,9 +315,7 @@ class Service : public ServiceInterface { // from the tracker are returned. 
StatusOr> ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice< - std::vector> - arguments, + tensorflow::gtl::ArraySlice> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index aa0a24a283..c679d401c3 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -34,86 +34,32 @@ namespace xla { using ::tensorflow::strings::Appendf; -/* static */ StatusOr> -ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape, - const se::Platform* platform, - int device_ordinal, - const se::DeviceMemoryBase& buffer) { - if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Shape must be an array: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - auto shaped_buffer = - MakeUnique(shape, platform, device_ordinal); - *shaped_buffer->mutable_shape_index_to_buffer_entry()->mutable_element({}) = - 0; - *shaped_buffer->mutable_buffers() = {buffer}; - return std::move(shaped_buffer); -} - -/* static */ StatusOr> ShapedBuffer::Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, - const std::function& shape_size_fn) { - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - auto shaped_buffer = WrapUnique( - new ShapedBuffer(shape, allocator->platform(), device_ordinal)); - - // Allocate an appropriate sized buffer for each element in the shape - // including the tuple pointer arrays. - for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) { - const ShapeIndex& index = pair.first; - size_t& buffer_entry = pair.second; - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase memory_base, - allocator->Allocate(shaped_buffer->device_ordinal(), - shape_size_fn(ShapeUtil::GetSubshape( - shaped_buffer->shape(), index)))); - shaped_buffer->buffers_.push_back(memory_base); - buffer_entry = shaped_buffer->buffers_.size() - 1; - } - - return std::move(shaped_buffer); -} - -ShapedBuffer::ShapedBuffer(const Shape& shape, const se::Platform* platform, - int device_ordinal) - : shape_(shape), +ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, + const se::Platform* platform, int device_ordinal) + : on_host_shape_(on_host_shape), + on_device_shape_(on_device_shape), platform_(platform), device_ordinal_(device_ordinal), - shape_index_to_buffer_entry_(shape) {} + buffers_(on_device_shape) {} void ShapedBuffer::clear() { - for (se::DeviceMemoryBase& memory_base : buffers_) { + for (auto& pair : buffers_) { // A default constructed DeviceMemoryBase is a null pointer. 
- memory_base = se::DeviceMemoryBase(); + pair.second = se::DeviceMemoryBase(); } } -void ShapedBuffer::AddBufferAtIndex( - const perftools::gputools::DeviceMemoryBase& buffer, - const ShapeIndex& shape_index) { - *mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) = - buffers().size(); - mutable_buffers()->push_back(buffer); -} - -const se::DeviceMemoryBase& ShapedBuffer::buffer( - const ShapeIndex& index) const { - return buffers_[shape_index_to_buffer_entry_.element(index)]; -} - -se::DeviceMemoryBase* ShapedBuffer::mutable_buffer(const ShapeIndex& index) { - return &buffers_[shape_index_to_buffer_entry_.element(index)]; -} - string ShapedBuffer::ToString() const { - string s = "ShapedBuffer(" + platform_->Name() + "):\n"; + string s = tensorflow::strings::StrCat( + "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), + "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), + ", on-device shape=" + + ShapeUtil::HumanStringWithLayout(on_device_shape()), + ":\n"); ShapeUtil::ForEachSubshape( - shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { + on_device_shape(), + [this, &s](const Shape& subshape, const ShapeIndex& index) { string shape_str; if (ShapeUtil::IsTuple(subshape)) { shape_str = "tuple"; @@ -133,34 +79,24 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { return out; } -/* static */ StatusOr> -ScopedShapedBuffer::Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, - const std::function& shape_size_fn) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr unscoped_buffer, - ShapedBuffer::Allocate(shape, allocator, device_ordinal, shape_size_fn)); - return MakeScoped(unscoped_buffer.get(), allocator); -} - /* static */ StatusOr> ScopedShapedBuffer::MakeScoped( ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator) { auto scoped_buffer = WrapUnique(new ScopedShapedBuffer( - shaped_buffer->shape(), allocator, shaped_buffer->device_ordinal())); + shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(), + allocator, shaped_buffer->device_ordinal())); scoped_buffer->buffers_ = shaped_buffer->buffers(); - scoped_buffer->shape_index_to_buffer_entry_ = - shaped_buffer->shape_index_to_buffer_entry(); - shaped_buffer->clear(); return std::move(scoped_buffer); } -ScopedShapedBuffer::ScopedShapedBuffer(const Shape& shape, +ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, DeviceMemoryAllocator* allocator, int device_ordinal) - : ShapedBuffer(shape, allocator->platform(), device_ordinal), + : ShapedBuffer(on_host_shape, on_device_shape, allocator->platform(), + device_ordinal), allocator_(allocator) {} ScopedShapedBuffer::~ScopedShapedBuffer() { @@ -168,7 +104,8 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. 
std::set deallocated_opaques; - for (se::DeviceMemoryBase& memory_base : buffers_) { + for (auto& pair : buffers_) { + se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && deallocated_opaques.count(memory_base.opaque()) == 0) { deallocated_opaques.insert(memory_base.opaque()); @@ -179,13 +116,10 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { } std::unique_ptr ScopedShapedBuffer::release() { - auto shaped_buffer = - MakeUnique(shape(), platform(), device_ordinal()); - - *shaped_buffer->mutable_buffers() = buffers(); - *shaped_buffer->mutable_shape_index_to_buffer_entry() = - shape_index_to_buffer_entry(); + auto shaped_buffer = MakeUnique( + on_host_shape(), on_device_shape(), platform(), device_ordinal()); + shaped_buffer->buffers() = buffers(); clear(); return shaped_buffer; diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index ca8bfff674..f570ebb9cb 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -31,69 +31,68 @@ limitations under the License. namespace xla { // Class which encapsulates a buffer or set of buffers containing data of a -// particular XLA shape. Used for zero-copy execution interface for a -// XLA client running in the same process as the service (LocalClient), +// particular XLA shape. class ShapedBuffer { public: - // Convenience method which creates a ShapedBuffer of array shape (not a - // tuple). Its single buffer pointer is set to the given value "buffer". The - // given buffer must be large enough to store the given shape as given by - // ShapeUtil::ByteSizeOf. - static StatusOr> MakeArrayShapedBuffer( - const Shape& shape, const perftools::gputools::Platform* platform, - int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer); - - // Return a newly allocated ShapedBuffer of an arbitrary shape. Array buffers - // (leaves in the shape) are allocated and uninitialized. Tuple buffers (if - // any) are allocated and initialized to the backend-specific representation - // of an array of pointers to the tuple elements. - static StatusOr> Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, - const std::function& shape_size_fn); - - ShapedBuffer(const Shape& shape, + // Construct a ScopedShapedBuffer with null DeviceMemoryBases at each + // index. The shape of the data on the host and the device may differ because + // the device may have a different representation for different data + // types. Therefore, both the on-host and on-device shape are required. The + // on-device shape determines the number of device allocations + // (DeviceMemoryBase) held by the ShapedBuffer. + ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, const perftools::gputools::Platform* platform, int device_ordinal); - const Shape& shape() const { return shape_; } + // Returns the shape of the on-host representation of the data held by this + // ShapedBuffer. + const Shape& on_host_shape() const { return on_host_shape_; } + + // Returns the shape of the on-device representation of the data held by this + // ShapedBuffer. + const Shape& on_device_shape() const { return on_device_shape_; } + const perftools::gputools::Platform* platform() const { return platform_; } int device_ordinal() const { return device_ordinal_; } + // Return the root buffer of the shape (shape index {}). 
+ const perftools::gputools::DeviceMemoryBase& root_buffer() const { + return buffer(/*index=*/{}); + } + // Returns the buffer at the given shape index where index is defined as in // ShapeUtil::GetSubshape. const perftools::gputools::DeviceMemoryBase& buffer( - const ShapeIndex& index) const; - perftools::gputools::DeviceMemoryBase* mutable_buffer( - const ShapeIndex& index); - - // Returns the underlying structure which stores the buffer pointers. - const std::vector& buffers() const { - return buffers_; + const ShapeIndex& index) const { + return buffers_.element(index); } - std::vector* mutable_buffers() { - return &buffers_; + + // Sets the device memory buffer at the given index. + void set_buffer(const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& index) { + *buffers_.mutable_element(index) = buffer; } - // Returns the tree of indices which map to buffer pointers. - const ShapeTree& shape_index_to_buffer_entry() const { - return shape_index_to_buffer_entry_; + // Returns the underlying ShapeTree containing all the device addresses in the + // ShapedBuffer. + const ShapeTree& buffers() const { + return buffers_; } - ShapeTree* mutable_shape_index_to_buffer_entry() { - return &shape_index_to_buffer_entry_; + ShapeTree& buffers() { + return buffers_; } // Set all device memory pointers in the object to null. void clear(); - // Adds a new buffer at the given shape index. - void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer, - const ShapeIndex& shape_index); - string ToString() const; protected: - // The shape of the device buffer with layout. - const Shape shape_; + // The shape of the data when represented on the host. + const Shape on_host_shape_; + + // The shape of the data on the device. + const Shape on_device_shape_; // The platform the memory is allocated on. const perftools::gputools::Platform* platform_; @@ -101,14 +100,8 @@ class ShapedBuffer { // The device the memory is allocated on. const int device_ordinal_; - // The list of DeviceMemoryBase pointers representing this shape. - // Note that there can be a many to one relationship between tuple elements - // and buffers. To account for this, shape_index_to_buffer_entry_ allows us - // to make from a position in a shape to an index into this list. - std::vector buffers_; - - // The tree of indices into buffers_. - ShapeTree shape_index_to_buffer_entry_; + // The tree of device buffers. Its shape is on_device_shape(). + ShapeTree buffers_; }; std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); @@ -118,17 +111,16 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); // destructed. class ScopedShapedBuffer : public ShapedBuffer { public: - // Identical to ShapedBuffer::Allocate. - static StatusOr> Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, - const std::function& shape_size_fn); - // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the // deallocation of the device memory held in the shaped buffer. All device // memory pointers in the given ShapedBuffer are set to null. static StatusOr> MakeScoped( ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator); + // Create a ScopedShapedBuffer with null DeviceMemoryBases at each index. + ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, + DeviceMemoryAllocator* allocator, int device_ordinal); + // Return the allocator used to allocate the device memory held in this // ScopedShapedBuffer. 
DeviceMemoryAllocator* memory_allocator() const { return allocator_; } @@ -143,8 +135,6 @@ class ScopedShapedBuffer : public ShapedBuffer { virtual ~ScopedShapedBuffer(); protected: - ScopedShapedBuffer(const Shape& shape, DeviceMemoryAllocator* allocator, - int device_ordinal); ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; void operator=(const ScopedShapedBuffer&) = delete; diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index d5f53ad56f..2f36e2b16e 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -40,6 +40,45 @@ TransferManager::GetPlatformTransferManagers() { return r; } +Status TransferManager::TransferArrayToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const perftools::gputools::DeviceMemoryBase& dest) { + const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); + TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) + << "On-device representation of " + << ShapeUtil::HumanString(literal.shape()) + << " is not an array: " << ShapeUtil::HumanString(on_device_shape); + if (dest.size() < GetByteSizeRequirement(on_device_shape)) { + return FailedPrecondition( + "Allocation on device not large enough for array: " + "%lld < %lld", + dest.size(), GetByteSizeRequirement(on_device_shape)); + } + ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, + executor->platform(), executor->device_ordinal()); + shaped_buffer.set_buffer(dest, /*index=*/{}); + return TransferLiteralToDevice(executor, literal, shaped_buffer); +} + +StatusOr> TransferManager::TransferArrayFromDevice( + perftools::gputools::StreamExecutor* executor, const Shape& shape, + const perftools::gputools::DeviceMemoryBase& source) { + TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) + << "Shape " << ShapeUtil::HumanString(shape) + << " has a differently shaped representation on-device: " + << ShapeUtil::HumanString(HostShapeToDeviceShape(shape)); + if (source.size() < GetByteSizeRequirement(shape)) { + return FailedPrecondition( + "Allocation on device not large enough for array: " + "%lld < %lld", + source.size(), GetByteSizeRequirement(shape)); + } + ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, + executor->platform(), executor->device_ordinal()); + shaped_buffer.set_buffer(source, /*index=*/{}); + return TransferLiteralFromDevice(executor, shaped_buffer); +} + /* static */ void TransferManager::RegisterTransferManager( se::Platform::Id platform_id, TransferManagerCreationFunction creation_function) { @@ -75,14 +114,12 @@ TransferManager::GetPlatformTransferManagers() { Status TransferManager::WriteTupleIndexTables( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) { - VLOG(2) << "Writing tuple index tables to ShapedBuffer rooted at " - << device_buffer.buffer(/*index=*/{}).opaque() - << "; shape: " << ShapeUtil::HumanString(device_buffer.shape()); + VLOG(2) << "Writing tuple index tables for " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); return ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_device_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { if (ShapeUtil::IsTuple(device_subshape)) { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); @@ -97,7 +134,7 @@ Status TransferManager::WriteTupleIndexTables( 
elements.push_back(device_buffer.buffer(element_index)); element_index.pop_back(); } - return WriteTuplePointersToDevice(executor, elements, device_subshape, + return WriteSingleTupleIndexTable(executor, elements, device_subshape, &device_memory); } @@ -143,31 +180,43 @@ Status TransferManager::TransferBufferToDevice( return Status::OK(); } -StatusOr> -TransferManager::GatherBufferPointersFromTuple( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsTuple(shape)); - - std::set buffer_pointers; - buffer_pointers.insert(source); - - TF_ASSIGN_OR_RETURN(std::vector tuple_elements, - ShallowCopyTupleFromDevice(executor, source, shape)); - for (auto i = 0; i < tuple_elements.size(); ++i) { - const Shape& element_shape = shape.tuple_shapes(i); - if (ShapeUtil::IsTuple(element_shape)) { - TF_ASSIGN_OR_RETURN( - std::set buffer_pointers_in_element, - GatherBufferPointersFromTuple(executor, tuple_elements[i], - element_shape)); - buffer_pointers.insert(buffer_pointers_in_element.begin(), - buffer_pointers_in_element.end()); - } else { - buffer_pointers.insert(tuple_elements[i]); - } +StatusOr> TransferManager::AllocateShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal) { + if (!LayoutUtil::HasLayout(on_host_shape)) { + return InvalidArgument( + "Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(on_host_shape).c_str()); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); + const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); + TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); + + auto shaped_buffer = WrapUnique(new ShapedBuffer( + on_host_shape, on_device_shape, allocator->platform(), device_ordinal)); + + // Allocate an appropriate sized buffer for each element in the shape + // including the tuple pointer arrays. + for (auto& pair : shaped_buffer->buffers()) { + const ShapeIndex& index = pair.first; + se::DeviceMemoryBase& memory_base = pair.second; + const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); + TF_ASSIGN_OR_RETURN(memory_base, + allocator->Allocate(shaped_buffer->device_ordinal(), + GetByteSizeRequirement(subshape))); } - return std::move(buffer_pointers); + + return std::move(shaped_buffer); +} + +StatusOr> +TransferManager::AllocateScopedShapedBuffer(const Shape& on_host_shape, + DeviceMemoryAllocator* allocator, + int device_ordinal) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr unscoped_buffer, + AllocateShapedBuffer(on_host_shape, allocator, device_ordinal)); + return ScopedShapedBuffer::MakeScoped(unscoped_buffer.get(), allocator); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index be9b769ac8..9f2b5c4aec 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -44,24 +44,6 @@ class TransferManager { // Returns the ID of the platform that this transfer manager acts on. virtual perftools::gputools::Platform::Id PlatformId() const = 0; - // Transfers the region into the provided literal using the provided - // executor. device_shape is the shape, including layout, of the data on the - // device, while literal_shape will be the shape for the literal. device_shape - // and literal_shape must be compatible, but need not have the same layout. - // TODO(b/66694934): Remove TransferLiteral* methods which accept bare - // DeviceMemoryBase. 
- virtual Status TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& region, - const Shape& device_shape, const Shape& literal_shape, - Literal* literal) = 0; - - // Transfers the given literal into the provided region output parameter, - // using the given executor. - virtual Status TransferLiteralToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - perftools::gputools::DeviceMemoryBase* region) = 0; - // Returns the shape of the on-device representation for the given shape on // the host. This is intended for use with ShapedBuffer where buffers are // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user @@ -70,37 +52,39 @@ class TransferManager { return host_shape; } - // Transfers the data held in the given ShapedBuffer into the provided literal - // using the provided executor. literal_shape will be the shape for the - // literal. The shape of the ShapedBuffer and DeviceShape(literal_shape) must - // be compatible, but need not have the same layout. + // Returns a literal containing the data held in the given ShapedBuffer. + // using the provided executor. The optional literal_shape will be the shape + // for the literal. The shape of the ShapedBuffer and + // DeviceShape(literal_shape) must be compatible, but need not have the same + // layout. virtual StatusOr> TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0; // Transfers the given literal into the previously allocated device memory - // represented by the given ShapedBuffer using the given executor. + // represented by the given ShapedBuffer using the given executor. The shape + // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, + // but need not have the same layout virtual Status TransferLiteralToDevice( perftools::gputools::StreamExecutor* executor, const Literal& literal, const ShapedBuffer& device_buffer) = 0; + // Convenience methods for transferring an array to or from the device at a + // known address. This avoids having to construct a ShapedBuffer just to + // transfer an array at a known address. + Status TransferArrayToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const perftools::gputools::DeviceMemoryBase& dest); + StatusOr> TransferArrayFromDevice( + perftools::gputools::StreamExecutor* executor, const Shape& shape, + const perftools::gputools::DeviceMemoryBase& source); + // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed( perftools::gputools::StreamExecutor* executor, const Literal& literal) = 0; - // Transfer a memory block of the given size from 'source' buffer to the - // Infeed interface of the device using the given executor. - // - // size is the size to transfer from source in bytes. - // - // source is the source data that must be in the target-dependent layout that - // the Infeed HLO used in the computation expects. - virtual Status TransferBufferToInfeed( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source) = 0; - // Transfers the given literal from the Outfeed interface of the device, // using the given executor. 
virtual Status TransferLiteralFromOutfeed( @@ -112,37 +96,26 @@ class TransferManager { tensorflow::gtl::ArraySlice executor) = 0; - // Shallow copy a tuple from the device and create a DeviceMemoryBase object - // for each element in the tuple. A DeviceMemoryBase object refers to the - // buffer containing the data of that element. The DeviceMemoryBase objects - // are returned as a vector. - virtual StatusOr> - ShallowCopyTupleFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& shape) = 0; - // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the // ShapedBuffer is array-shaped this method does nothing. Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer); - // Returns all buffer pointers that the tuple `source` refers to. Unlike - // ShallowCopyTupleFromDevice, this function gather buffer pointers in nested - // tuples as well. Also, the returned DeviceMemoryBase objects are - // deduplicated. - StatusOr> - GatherBufferPointersFromTuple( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, const Shape& shape); - // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory // region for a host-to-device transfer. virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; - typedef std::unique_ptr (*TransferManagerCreationFunction)(); + // Allocate a ShapedBuffer which can hold data with the given on-host + // shape. The on-device shape may be different as indicated by + // HostShapeToDeviceShape. + StatusOr> AllocateShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal); + StatusOr> AllocateScopedShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal); ///// // The TransferManager class also serves as a point to register objects for @@ -152,6 +125,7 @@ class TransferManager { // assumed to be a singleton, so no ownership is transferred. // // Precondition: a platform kind must not be registered more than once. + typedef std::unique_ptr (*TransferManagerCreationFunction)(); static void RegisterTransferManager( perftools::gputools::Platform::Id platform_id, TransferManagerCreationFunction transfer_manager); @@ -162,6 +136,17 @@ class TransferManager { const perftools::gputools::Platform* platform); protected: + // Transfer a memory block of the given size from 'source' buffer to the + // Infeed interface of the device using the given executor. + // + // size is the size to transfer from source in bytes. + // + // source is the source data that must be in the target-dependent layout that + // the Infeed HLO used in the computation expects. + virtual Status TransferBufferToInfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source) = 0; + // Transfer a memory block of the given size from the device source into the // 'destination' buffer. // @@ -180,10 +165,9 @@ class TransferManager { const void* source, perftools::gputools::DeviceMemoryBase* destination); // Writes the given device-memory pointers in 'elements' to the given region - // to construct a tuple in the platform-specific tuple representation. This - // can handle nested tuples as well. 
In the nested case, the element - // DeviceMemoryBase points to another array of pointers on the device. - virtual Status WriteTuplePointersToDevice( + // to construct a tuple index table in the platform-specific tuple + // representation. + virtual Status WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index bcb85b04ee..d64bf0aa5b 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -56,9 +56,13 @@ class CopyOpTest : public HloTestBase { tensorflow::gtl::ArraySlice permutation); }; -XLA_TEST_F(CopyOpTest, CopyR0Bool) { TestCopyOp(*Literal::CreateR0(true)); } +XLA_TEST_F(CopyOpTest, CopyR0Bool) { + TestCopyOp(*Literal::CreateR0(true)); +} -XLA_TEST_F(CopyOpTest, CopyR1S0U32) { TestCopyOp(*Literal::CreateR1({})); } +XLA_TEST_F(CopyOpTest, CopyR1S0U32) { + TestCopyOp(*Literal::CreateR1({})); +} XLA_TEST_F(CopyOpTest, CopyR1S3U32) { TestCopyOp(*Literal::CreateR1({1, 2, 3})); @@ -85,7 +89,6 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { // Copy literal to device to use as parameter. auto literal = Literal::CreateR0(42.0); Shape shape = literal->shape(); - auto constant_device_base = TransferToDevice(*literal); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")); @@ -98,7 +101,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { module->AddEntryComputation(std::move(computation)); std::unique_ptr result = - ExecuteAndTransfer(std::move(module), {constant_device_base}); + ExecuteAndTransfer(std::move(module), {literal.get()}); LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 8baaf39e3c..59be32a8ff 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -559,20 +559,20 @@ void BM_DynamicSlice(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. - auto shape_size_fn = [client](const Shape& shape) { - return client->backend().transfer_manager()->GetByteSizeRequirement(shape); - }; - auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0, - shape_size_fn) + auto buffer = client->backend() + .transfer_manager() + ->AllocateScopedShapedBuffer( + start_indices_shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *start_indices_literal, - buffer->mutable_buffer({}))); + executors[device_ordinal], *start_indices_literal, *buffer)); std::unique_ptr executable = - client->Compile(computation, {&buffer->shape()}, ExecutableBuildOptions()) + client + ->Compile(computation, {&buffer->on_host_shape()}, + ExecutableBuildOptions()) .ConsumeValueOrDie(); // Run some warm-up executions. 
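The BM_DynamicSlice hunk above replaces ScopedShapedBuffer::Allocate plus a TransferLiteralToDevice call on a bare mutable_buffer({}) with TransferManager::AllocateScopedShapedBuffer followed by a transfer that takes the whole ShapedBuffer. A minimal sketch of that pattern, assuming the caller already holds a TransferManager*, a StreamExecutor*, and a DeviceMemoryAllocator*; the helper name CopyLiteralToDevice is illustrative and not part of this change:

#include <memory>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

// Hypothetical helper: allocate a device buffer shaped like `literal` and
// copy the literal into it, mirroring the benchmark above.
StatusOr<std::unique_ptr<ScopedShapedBuffer>> CopyLiteralToDevice(
    TransferManager* transfer_manager,
    perftools::gputools::StreamExecutor* executor,
    DeviceMemoryAllocator* allocator, const Literal& literal) {
  // Allocation is keyed off the on-host shape; the transfer manager derives
  // the on-device shape via HostShapeToDeviceShape.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<ScopedShapedBuffer> buffer,
                      transfer_manager->AllocateScopedShapedBuffer(
                          literal.shape(), allocator,
                          /*device_ordinal=*/executor->device_ordinal()));
  // TransferLiteralToDevice now takes the whole ShapedBuffer rather than a
  // bare DeviceMemoryBase.
  TF_RETURN_IF_ERROR(
      transfer_manager->TransferLiteralToDevice(executor, literal, *buffer));
  return std::move(buffer);
}

}  // namespace xla
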
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 2686afccc2..a292eab1d1 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -816,7 +816,8 @@ void BM_ParallelFusion(int num_iters) { std::unique_ptr executable = client ->Compile(computation, - {&buffer0->shape(), &buffer1->shape(), &buffer2->shape()}, + {&buffer0->on_host_shape(), &buffer1->on_host_shape(), + &buffer2->on_host_shape()}, ExecutableBuildOptions()) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index f9458f5b74..a27e0f2c10 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -111,28 +111,16 @@ std::unique_ptr HloTestBase::CreateNewModule() { return debug_options; } -StatusOr HloTestBase::Execute( +StatusOr> HloTestBase::Execute( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape) { - return test_runner_.Execute(std::move(module), arguments, result_shape); -} - -se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { - return test_runner_.TransferToDevice(literal).ValueOrDie(); -} - -std::unique_ptr HloTestBase::TransferFromDevice( - const Shape& shape, se::DeviceMemoryBase device_base) { - return test_runner_.TransferFromDevice(shape, device_base).ValueOrDie(); + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments); } std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { - return test_runner_.ExecuteAndTransfer(std::move(module), arguments) - .ValueOrDie(); + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } StatusOr> HloTestBase::MakeReferenceModule( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 2c5ce04402..4aea9fc9fd 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -93,27 +93,14 @@ class HloTestBase : public ::testing::Test { // DebugOptions, e.g. when creating a module from a string or a file. static DebugOptions GetDebugOptionsForTest(); - // Executes the given module and returns a global data handle. - StatusOr Execute( + // Executes the given module and return the result as a Literal. + StatusOr> Execute( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape); - - // Transfers the given literal to the device and returns the data handle. - perftools::gputools::DeviceMemoryBase TransferToDevice( - const Literal& literal); + tensorflow::gtl::ArraySlice arguments); - // Transfers the array referred to by the given handle from the device and - // returns as a Literal. - std::unique_ptr TransferFromDevice( - const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); - - // Executes the given module and return the result as a Literal. std::unique_ptr ExecuteAndTransfer( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments); + tensorflow::gtl::ArraySlice arguments); // Executes the given hlo module on two backends and compares results. 
// diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index ad71d40197..e3298e98c6 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -138,13 +138,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { // Create x as a col-major array. auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); - EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(x_array->on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); - EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(y_array->on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); std::unique_ptr result_colmaj = @@ -179,7 +179,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, *ShapedBufferToLiteral(*result_colmaj), @@ -191,7 +191,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, *ShapedBufferToLiteral(*result_rowmaj), @@ -213,8 +213,8 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, @@ -241,8 +241,8 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, @@ -320,8 +320,8 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_buffer.get(), y_buffer.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + 
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, @@ -906,20 +906,18 @@ void BM_LocalClientOverhead(int num_iters) { builder.Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); - auto shape_size_fn = [client](const Shape& shape) { - return client->backend().transfer_manager()->GetByteSizeRequirement(shape); - }; - auto buffer = ScopedShapedBuffer::Allocate( - shape, &allocator, /*device_ordinal=*/0, shape_size_fn) - .ConsumeValueOrDie(); + auto buffer = + transfer_manager + ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *literal, buffer->mutable_buffer({}))); + executors[device_ordinal], *literal, *buffer)); const int kWarmups = 2; - auto executable_status = client->Compile(computation, {&buffer->shape()}, - ExecutableBuildOptions()); + auto executable_status = client->Compile( + computation, {&buffer->on_host_shape()}, ExecutableBuildOptions()); ASSERT_IS_OK(executable_status); std::unique_ptr executable = executable_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 062a9246e4..96b976d25d 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -188,7 +188,7 @@ LocalClientTestBase::ExecuteLocally( const ExecutableRunOptions& run_options) { std::vector argument_layouts(arguments.size()); for (int i = 0; i < arguments.size(); ++i) { - argument_layouts[i] = &arguments[i]->shape(); + argument_layouts[i] = &arguments[i]->on_host_shape(); } TF_ASSIGN_OR_RETURN( std::unique_ptr executable, diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 89fa6ed9f7..62d24a11fd 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -99,14 +99,13 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input; - input.PopulateWithValue(2.5f, {size, size}); - auto p1 = TransferToDevice(input); - auto p0 = TransferToDevice(*Literal::CreateR0(-9.0f)); + Literal arg1; + arg1.PopulateWithValue(2.5f, {size, size}); Literal expect; expect.PopulateWithValue(size * 1.5f * 3.5f, {size, size}); - auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + auto actual = ExecuteAndTransfer( + std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); } @@ -163,11 +162,9 @@ class MultiOutputFusionTest : public HloTestBase { Literal input0, input1; input0.PopulateWithValue(2.5f, {size}); input1.PopulateWithValue(1, {size}); - auto p0 = TransferToDevice(input0); - auto p1 = TransferToDevice(input1); Literal expect = *Literal::CreateR1({size * 1.5f * 3.5f}); - auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); } }; diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index f2a6474948..ed556fafb1 100644 --- 
a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -46,9 +46,10 @@ class TransferManagerTest : public LocalClientTestBase { ~TransferManagerTest() override = default; std::unique_ptr AllocateDeviceBuffer(const Shape& shape) { - return ScopedShapedBuffer::Allocate( - shape, GetOrCreateAllocator(local_client_->platform()), - /*device_ordinal=*/0, shape_size_fn_) + return transfer_manager_ + ->AllocateScopedShapedBuffer( + shape, GetOrCreateAllocator(local_client_->platform()), + /*device_ordinal=*/0) .ValueOrDie(); } @@ -211,5 +212,39 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { LiteralTestUtil::ExpectEqual(*literal, *result); } +XLA_TEST_F(TransferManagerTest, TransferComplexValue) { + std::unique_ptr literal = Literal::CreateR1( + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { + std::unique_ptr literal = Literal::MakeTuple( + {Literal::CreateR1( + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) + .get(), + Literal::CreateR1({1, 2, 3, 4, 5, 6}).get(), + Literal::CreateR0(complex64(0.3f, -0.4f)).get()}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + } // namespace } // namespace xla -- cgit v1.2.3
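Since the reworked ShapedBuffer is now a ShapeTree of device addresses keyed by the on-device shape, wrapping an existing allocation only requires the two shapes plus set_buffer. A sketch under the assumption that the on-host and on-device shapes coincide; WrapArrayAllocation is a hypothetical helper, not an API added by this patch:

#include <memory>

#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// Hypothetical example: wrap an already-allocated device region holding an
// F32[2,2] array in a (non-owning) ShapedBuffer. Both shapes are required;
// here they are assumed to be identical.
std::unique_ptr<ShapedBuffer> WrapArrayAllocation(
    const perftools::gputools::Platform* platform, int device_ordinal,
    const perftools::gputools::DeviceMemoryBase& dmem) {
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
  auto buffer = MakeUnique<ShapedBuffer>(
      /*on_host_shape=*/shape, /*on_device_shape=*/shape, platform,
      device_ordinal);
  // For an array shape the only entry in the buffer tree is the root, {}.
  buffer->set_buffer(dmem, /*index=*/{});
  CHECK_EQ(buffer->root_buffer().opaque(), dmem.opaque());
  return buffer;
}

}  // namespace xla
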
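ScopedShapedBuffer::MakeScoped still takes ownership by nulling out the source buffers, and release() hands the addresses back as an unscoped ShapedBuffer without deallocating. A small sketch of that handoff; AdoptAndRelease is a made-up name used only for illustration:

#include <memory>

#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Hypothetical example: take ownership of the device memory referenced by an
// unscoped ShapedBuffer, then hand it back without deallocating anything.
StatusOr<std::unique_ptr<ShapedBuffer>> AdoptAndRelease(
    ShapedBuffer* unscoped, DeviceMemoryAllocator* allocator) {
  // After MakeScoped, `unscoped` holds only null DeviceMemoryBases; the
  // scoped buffer would deallocate through `allocator` when destroyed.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<ScopedShapedBuffer> scoped,
                      ScopedShapedBuffer::MakeScoped(unscoped, allocator));
  // release() transfers the addresses back out, so the destructor of `scoped`
  // will not free them.
  return scoped->release();
}

}  // namespace xla
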
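AllocateScopedShapedBuffer sizes one allocation per element of the on-device shape, including the tuple pointer arrays, but the index tables themselves are written separately. A hedged sketch of preparing a tuple-shaped buffer this way; AllocateTupleBuffer and the element shapes are illustrative:

#include <memory>

#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

// Hypothetical setup for a tuple-shaped buffer: allocate every sub-buffer
// (including the tuple pointer arrays), then fill in the index tables so the
// device-side representation is self-describing.
StatusOr<std::unique_ptr<ScopedShapedBuffer>> AllocateTupleBuffer(
    TransferManager* transfer_manager,
    perftools::gputools::StreamExecutor* executor,
    DeviceMemoryAllocator* allocator) {
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 2}), ShapeUtil::MakeShape(S32, {3})});
  TF_ASSIGN_OR_RETURN(std::unique_ptr<ScopedShapedBuffer> buffer,
                      transfer_manager->AllocateScopedShapedBuffer(
                          tuple_shape, allocator,
                          /*device_ordinal=*/executor->device_ordinal()));
  // WriteTupleIndexTables walks the on-device shape and writes one index
  // table per tuple sub-shape; it is a no-op for array shapes.
  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(executor, *buffer));
  return std::move(buffer);
}

}  // namespace xla
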
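TransferArrayToDevice and TransferArrayFromDevice cover the common case of an array at a known device address, constructing the temporary ShapedBuffer internally. A sketch of a round trip, assuming `dest` is an existing allocation of at least GetByteSizeRequirement(literal->shape()) bytes; RoundTripArray is a hypothetical helper:

#include <memory>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

// Hypothetical round trip: write an array literal to an existing device
// allocation and read it back. `dest` must be large enough for the array, or
// both calls fail with FailedPrecondition.
StatusOr<std::unique_ptr<Literal>> RoundTripArray(
    TransferManager* transfer_manager,
    perftools::gputools::StreamExecutor* executor,
    const perftools::gputools::DeviceMemoryBase& dest) {
  std::unique_ptr<Literal> literal =
      Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
  // No ShapedBuffer needed: the transfer manager wraps `dest` internally.
  TF_RETURN_IF_ERROR(
      transfer_manager->TransferArrayToDevice(executor, *literal, dest));
  return transfer_manager->TransferArrayFromDevice(executor, literal->shape(),
                                                   dest);
}

}  // namespace xla
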
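On the test side, Execute and ExecuteAndTransfer now take Literal pointers and return Literals, so tests no longer stage DeviceMemoryBase handles by hand. A sketch of the new calling convention in an HloTestBase-derived test; the fixture, test name, and identity computation are illustrative only:

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

// Hypothetical test showing the new Literal-based argument convention.
class RoundTripParamTest : public HloTestBase {};

XLA_TEST_F(RoundTripParamTest, ParameterPassedAsLiteral) {
  // Build a module whose entry computation returns its single parameter.
  HloComputation::Builder builder("return_param");
  const Shape shape = ShapeUtil::MakeShape(F32, {});
  builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  // Arguments are plain Literal pointers; the runner stages the transfer and
  // hands the result back as a Literal.
  std::unique_ptr<Literal> arg = Literal::CreateR0<float>(42.0f);
  std::unique_ptr<Literal> result =
      ExecuteAndTransfer(std::move(module), {arg.get()});
  LiteralTestUtil::ExpectR0Equal<float>(42.0f, *result);
}

}  // namespace
}  // namespace xla
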