Diffstat (limited to 'tensorflow/compiler/xla/service/cpu_transfer_manager.cc')
-rw-r--r--  tensorflow/compiler/xla/service/cpu_transfer_manager.cc  52
1 file changed, 38 insertions(+), 14 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc
index 80e918970c..d8a76443a6 100644
--- a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc
@@ -50,7 +50,7 @@ class CpuInfeedBuffer : public cpu::runtime::XfeedBuffer {
int32 length() override { return length_; }
void* data() override { return buffer_; }
- void Done() override { delete this; }
+ void Done(StatusOr<Shape> /*shape*/) override { delete this; }
se::DeviceMemoryBase* device_memory() { return &device_memory_; }
@@ -65,15 +65,22 @@ class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer {
CpuOutfeedBuffer(void* destination, int32 length)
: destination_(destination), length_(length) {}
- void WaitForNotification() { return done_.WaitForNotification(); }
+ StatusOr<Shape> WaitForNotification() {
+ done_.WaitForNotification();
+ return status_;
+ }
int32 length() override { return length_; }
void* data() override { return destination_; }
- void Done() override { done_.Notify(); }
+ void Done(StatusOr<Shape> shape) override {
+ status_ = std::move(shape);
+ done_.Notify();
+ }
private:
void* destination_;
int32 length_;
+ StatusOr<Shape> status_;
tensorflow::Notification done_;
};
@@ -106,7 +113,7 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
buffers.reserve(literal.tuple_literals_size());
auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() {
for (cpu::runtime::XfeedBuffer* b : buffers) {
- b->Done();
+ b->Done(ShapeUtil::MakeNil());
}
});
@@ -159,7 +166,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
/*source=*/source, queued_buffer->device_memory());
if (!s.ok()) {
- queued_buffer->Done();
+ queued_buffer->Done(ShapeUtil::MakeNil());
return s;
}
return queued_buffer;
@@ -178,8 +185,17 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
auto empty =
Literal::CreateFromDimensions(literal_shape.element_type(), dimensions);
literal->Swap(empty.get());
- return TransferBufferFromOutfeed(executor, size,
- literal->MutableInternalData());
+ TF_ASSIGN_OR_RETURN(Shape received_shape,
+ TransferBufferFromOutfeed(
+ executor, size, literal->MutableInternalData()));
+ TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape()))
+ << "Shape received from outfeed "
+ << ShapeUtil::HumanString(received_shape)
+ << " did not match the shape that was requested for outfeed: "
+ << ShapeUtil::HumanString(literal_shape);
+ TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
+ *literal->mutable_shape() = received_shape;
+ return Status::OK();
}
if (ShapeUtil::IsNestedTuple(literal_shape)) {
@@ -199,9 +215,19 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
tuple_element_shape.dimensions().size());
auto empty = Literal::CreateFromDimensions(
tuple_element_shape.element_type(), dimensions);
- TF_RETURN_IF_ERROR(TransferBufferFromOutfeed(
- executor, GetByteSizeRequirement(tuple_element_shape),
- empty->MutableInternalData()));
+ TF_ASSIGN_OR_RETURN(
+ Shape received_shape,
+ TransferBufferFromOutfeed(executor,
+ GetByteSizeRequirement(tuple_element_shape),
+ empty->MutableInternalData()));
+ TF_RET_CHECK(ShapeUtil::Compatible(received_shape, tuple_element_shape))
+ << "Shape received from outfeed "
+ << ShapeUtil::HumanString(received_shape)
+ << " did not match the shape that was requested for outfeed: "
+ << ShapeUtil::HumanString(tuple_element_shape);
+ TF_RET_CHECK(GetByteSizeRequirement(tuple_element_shape) ==
+ GetByteSizeRequirement(received_shape));
+ *empty->mutable_shape() = received_shape;
elements.push_back(std::move(empty));
}
auto result = Literal::MakeTupleOwned(std::move(elements));
@@ -210,7 +236,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
return Status::OK();
}
-Status CpuTransferManager::TransferBufferFromOutfeed(
+StatusOr<Shape> CpuTransferManager::TransferBufferFromOutfeed(
perftools::gputools::StreamExecutor* executor, int64 size,
void* destination) {
if (size > std::numeric_limits<int32>::max()) {
@@ -230,9 +256,7 @@ Status CpuTransferManager::TransferBufferFromOutfeed(
<< size_32 << "B";
xfeed_manager->outfeed()->EnqueueBuffers({&buffer});
VLOG(2) << "Waiting for buffer to be notified as populated.";
- buffer.WaitForNotification();
- VLOG(2) << "Buffer is populated, returning from outfeed buffer request.";
- return Status::OK();
+ return buffer.WaitForNotification();
}
} // namespace xla
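
For orientation only: a standalone sketch (not part of the XLA sources) of the handshake this diff introduces, in which the consumer's Done() call carries the received shape, or an error, back to the thread blocked in WaitForNotification(). Shape and StatusOrShape below are simplified stand-ins for xla::Shape and StatusOr<Shape>, and std::promise replaces tensorflow::Notification; none of these stand-ins are the real XLA types.

// Standalone illustration of the Done(StatusOr<Shape>) / WaitForNotification()
// contract; simplified stand-ins, not the actual XLA classes.
#include <future>
#include <iostream>
#include <string>
#include <thread>
#include <utility>

struct Shape { std::string human_string; };  // stand-in for xla::Shape
struct StatusOrShape {                       // stand-in for StatusOr<Shape>
  bool ok = false;
  std::string error;
  Shape shape;
};

class OutfeedBuffer {
 public:
  OutfeedBuffer(void* destination, int length)
      : destination_(destination), length_(length) {}

  int length() const { return length_; }
  void* data() const { return destination_; }

  // Called by the device-side consumer once the buffer is populated; the
  // shape (or error) is recorded before the waiting thread is released.
  void Done(StatusOrShape shape) { done_.set_value(std::move(shape)); }

  // Called by the host thread; blocks until Done() has run and returns
  // whatever shape/status the consumer reported.
  StatusOrShape WaitForNotification() { return done_.get_future().get(); }

 private:
  void* destination_;
  int length_;
  std::promise<StatusOrShape> done_;
};

int main() {
  float storage[4] = {};
  OutfeedBuffer buffer(storage, sizeof(storage));

  // Simulated consumer: fills the buffer, then reports the shape it produced.
  std::thread consumer([&buffer] {
    buffer.Done(StatusOrShape{/*ok=*/true, /*error=*/"", Shape{"f32[4]"}});
  });

  StatusOrShape result = buffer.WaitForNotification();
  if (result.ok) {
    std::cout << "received shape " << result.shape.human_string << "\n";
  }
  consumer.join();
  return 0;
}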