diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu_transfer_manager.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu_transfer_manager.cc | 52 |
1 file changed, 38 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc index 80e918970c..d8a76443a6 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc @@ -50,7 +50,7 @@ class CpuInfeedBuffer : public cpu::runtime::XfeedBuffer { int32 length() override { return length_; } void* data() override { return buffer_; } - void Done() override { delete this; } + void Done(StatusOr<Shape> /*shape*/) override { delete this; } se::DeviceMemoryBase* device_memory() { return &device_memory_; } @@ -65,15 +65,22 @@ class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer { CpuOutfeedBuffer(void* destination, int32 length) : destination_(destination), length_(length) {} - void WaitForNotification() { return done_.WaitForNotification(); } + StatusOr<Shape> WaitForNotification() { + done_.WaitForNotification(); + return status_; + } int32 length() override { return length_; } void* data() override { return destination_; } - void Done() override { done_.Notify(); } + void Done(StatusOr<Shape> shape) override { + status_ = std::move(shape); + done_.Notify(); + } private: void* destination_; int32 length_; + StatusOr<Shape> status_; tensorflow::Notification done_; }; @@ -106,7 +113,7 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, buffers.reserve(literal.tuple_literals_size()); auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { for (cpu::runtime::XfeedBuffer* b : buffers) { - b->Done(); + b->Done(ShapeUtil::MakeNil()); } }); @@ -159,7 +166,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, /*source=*/source, queued_buffer->device_memory()); if (!s.ok()) { - queued_buffer->Done(); + queued_buffer->Done(ShapeUtil::MakeNil()); return s; } return queued_buffer; @@ -178,8 +185,17 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( auto empty = 
Literal::CreateFromDimensions(literal_shape.element_type(), dimensions); literal->Swap(empty.get()); - return TransferBufferFromOutfeed(executor, size, - literal->MutableInternalData()); + TF_ASSIGN_OR_RETURN(Shape received_shape, + TransferBufferFromOutfeed( + executor, size, literal->MutableInternalData())); + TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape())) + << "Shape received from outfeed " + << ShapeUtil::HumanString(received_shape) + << " did not match the shape that was requested for outfeed: " + << ShapeUtil::HumanString(literal_shape); + TF_RET_CHECK(size == GetByteSizeRequirement(received_shape)); + *literal->mutable_shape() = received_shape; + return Status::OK(); } if (ShapeUtil::IsNestedTuple(literal_shape)) { @@ -199,9 +215,19 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( tuple_element_shape.dimensions().size()); auto empty = Literal::CreateFromDimensions( tuple_element_shape.element_type(), dimensions); - TF_RETURN_IF_ERROR(TransferBufferFromOutfeed( - executor, GetByteSizeRequirement(tuple_element_shape), - empty->MutableInternalData())); + TF_ASSIGN_OR_RETURN( + Shape received_shape, + TransferBufferFromOutfeed(executor, + GetByteSizeRequirement(tuple_element_shape), + empty->MutableInternalData())); + TF_RET_CHECK(ShapeUtil::Compatible(received_shape, tuple_element_shape)) + << "Shape received from outfeed " + << ShapeUtil::HumanString(received_shape) + << " did not match the shape that was requested for outfeed: " + << ShapeUtil::HumanString(tuple_element_shape); + TF_RET_CHECK(GetByteSizeRequirement(tuple_element_shape) == + GetByteSizeRequirement(received_shape)); + *empty->mutable_shape() = received_shape; elements.push_back(std::move(empty)); } auto result = Literal::MakeTupleOwned(std::move(elements)); @@ -210,7 +236,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( return Status::OK(); } -Status CpuTransferManager::TransferBufferFromOutfeed( +StatusOr<Shape> 
CpuTransferManager::TransferBufferFromOutfeed( perftools::gputools::StreamExecutor* executor, int64 size, void* destination) { if (size > std::numeric_limits<int32>::max()) { @@ -230,9 +256,7 @@ Status CpuTransferManager::TransferBufferFromOutfeed( << size_32 << "B"; xfeed_manager->outfeed()->EnqueueBuffers({&buffer}); VLOG(2) << "Waiting for buffer to be notified as populated."; - buffer.WaitForNotification(); - VLOG(2) << "Buffer is populated, returning from outfeed buffer request."; - return Status::OK(); + return buffer.WaitForNotification(); } } // namespace xla |