diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu_transfer_manager.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu_transfer_manager.cc | 116 |
1 files changed, 102 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc index 262ba83f3d..80e918970c 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc @@ -27,9 +27,11 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace se = ::perftools::gputools; @@ -38,7 +40,7 @@ namespace xla { namespace { -class CpuInfeedBuffer : public cpu::runtime::InfeedBuffer { +class CpuInfeedBuffer : public cpu::runtime::XfeedBuffer { public: explicit CpuInfeedBuffer(int32 length) : length_(length), @@ -58,6 +60,23 @@ class CpuInfeedBuffer : public cpu::runtime::InfeedBuffer { se::DeviceMemoryBase device_memory_; }; +class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer { + public: + CpuOutfeedBuffer(void* destination, int32 length) + : destination_(destination), length_(length) {} + + void WaitForNotification() { return done_.WaitForNotification(); } + + int32 length() override { return length_; } + void* data() override { return destination_; } + void Done() override { done_.Notify(); } + + private: + void* destination_; + int32 length_; + tensorflow::Notification done_; +}; + } // namespace CpuTransferManager::CpuTransferManager() @@ -83,10 +102,10 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, // For a tuple, we transfer each of its elements to the device and // enqueue the resulting destination device addresses with the // infeed manager. - std::vector<cpu::runtime::InfeedBuffer*> buffers; + std::vector<cpu::runtime::XfeedBuffer*> buffers; buffers.reserve(literal.tuple_literals_size()); auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { - for (cpu::runtime::InfeedBuffer* b : buffers) { + for (cpu::runtime::XfeedBuffer* b : buffers) { b->Done(); } }); @@ -95,15 +114,14 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Shape& tuple_element_shape = tuple_element.shape(); int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); TF_ASSIGN_OR_RETURN( - cpu::runtime::InfeedBuffer * buffer, + cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, tuple_element_size, tuple_element.InternalData())); buffers.push_back(buffer); } - cpu::runtime::InfeedManager* infeed_manager = - cpu::runtime::GetInfeedManager(); - infeed_manager->EnqueueBuffers(buffers); + cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + xfeed_manager->infeed()->EnqueueBuffers(buffers); cleanup.release(); return Status::OK(); @@ -112,17 +130,16 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source) { - TF_ASSIGN_OR_RETURN(cpu::runtime::InfeedBuffer * buffer, + TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, size, source)); - cpu::runtime::InfeedManager* infeed_manager = - cpu::runtime::GetInfeedManager(); - infeed_manager->EnqueueBuffers({buffer}); + cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + xfeed_manager->infeed()->EnqueueBuffers({buffer}); return Status::OK(); } -StatusOr<cpu::runtime::InfeedBuffer*> +StatusOr<cpu::runtime::XfeedBuffer*> CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int64 size, const void* source) { @@ -130,8 +147,9 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); } - if (size == 0) { - return InvalidArgument("Infeed shape needs 0 bytes"); + if (size <= 0) { + return InvalidArgument("Infeed shape must have positive size; got %lld", + size); } int32 size_32 = static_cast<int32>(size); @@ -147,6 +165,76 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, return queued_buffer; } +Status CpuTransferManager::TransferLiteralFromOutfeed( + se::StreamExecutor* executor, const Shape& literal_shape, + Literal* literal) { + if (!ShapeUtil::IsTuple(literal_shape)) { + int64 size = GetByteSizeRequirement(literal_shape); + // Note: OSS build didn't like implicit conversion from + // literal_shape.dimensions() to the array slice on 2017-07-10. + tensorflow::gtl::ArraySlice<int64> dimensions( + tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()), + literal_shape.dimensions().size()); + auto empty = + Literal::CreateFromDimensions(literal_shape.element_type(), dimensions); + literal->Swap(empty.get()); + return TransferBufferFromOutfeed(executor, size, + literal->MutableInternalData()); + } + + if (ShapeUtil::IsNestedTuple(literal_shape)) { + return Unimplemented( + "Nested tuple outfeeds are not yet implemented on CPU."); + } + + std::vector<std::unique_ptr<Literal>> elements; + for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { + const Shape& tuple_element_shape = + ShapeUtil::GetTupleElementShape(literal_shape, i); + // Note: OSS build didn't like implicit conversion from + // literal_shape.dimensions() to the array slice on 2017-07-10. + tensorflow::gtl::ArraySlice<int64> dimensions( + tensorflow::bit_cast<const int64*>( + tuple_element_shape.dimensions().data()), + tuple_element_shape.dimensions().size()); + auto empty = Literal::CreateFromDimensions( + tuple_element_shape.element_type(), dimensions); + TF_RETURN_IF_ERROR(TransferBufferFromOutfeed( + executor, GetByteSizeRequirement(tuple_element_shape), + empty->MutableInternalData())); + elements.push_back(std::move(empty)); + } + auto result = Literal::MakeTupleOwned(std::move(elements)); + literal->Swap(result.get()); + TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); + return Status::OK(); +} + +Status CpuTransferManager::TransferBufferFromOutfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + void* destination) { + if (size > std::numeric_limits<int32>::max()) { + return InvalidArgument("Outfeed shape is too large: needs %lld bytes", + size); + } + + if (size <= 0) { + return InvalidArgument("Outfeed shape must have positive size; got %lld", + size); + } + + int32 size_32 = static_cast<int32>(size); + cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + CpuOutfeedBuffer buffer(destination, size_32); + VLOG(2) << "Enqueueing outfeed buffer (for the device to populate) of length " + << size_32 << "B"; + xfeed_manager->outfeed()->EnqueueBuffers({&buffer}); + VLOG(2) << "Waiting for buffer to be notified as populated."; + buffer.WaitForNotification(); + VLOG(2) << "Buffer is populated, returning from outfeed buffer request."; + return Status::OK(); +} + } // namespace xla static std::unique_ptr<xla::TransferManager> CreateCpuTransferManager() { |