aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-13 18:06:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-13 18:11:30 -0700
commitea125c27974135fbad6bcb75b720499c68d52357 (patch)
tree14c0ce6c0bf44c23cd714400077cd9ee3cdc0860
parent00fda25cf0c51e9e67a93d8ae25f65c363e2a199 (diff)
[XLA] Pass shape/layout information in calls to the CPU runtime routines.
Previously the CPU runtime wouldn't know how the data that was being outfed was laid out by the XLA LayoutAssignment pass, which could result in transposed-value results. This also allows us to validate the contract between the host program and the compiled XLA program with (reified) runtime type checks. PiperOrigin-RevId: 161895093
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.cc71
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.h19
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc47
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/xfeed_manager.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/xfeed_manager.h16
-rw-r--r--tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc48
-rw-r--r--tensorflow/compiler/xla/service/cpu_transfer_manager.cc52
-rw-r--r--tensorflow/compiler/xla/service/cpu_transfer_manager.h4
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc19
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.h13
13 files changed, 242 insertions, 65 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 6c2deb8c9a..7248cb5f4c 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -335,7 +335,11 @@ cc_library(
],
copts = runtime_copts(),
deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 40cdf079e3..07a9832867 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -17,8 +17,10 @@ limitations under the License.
#include <functional>
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
namespace xla {
namespace cpu {
@@ -33,38 +35,79 @@ XfeedManager* GetXfeedManager() {
} // namespace cpu
} // namespace xla
-void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
- xla::int32 buffer_length) {
- VLOG(2) << "AcquireInfeedBufferForDequeue";
+namespace {
+
+tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
+ xla::StatusOr<xla::Shape> shape =
+ xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
+ if (shape.ok()) {
+ return xla::ShapeUtil::HumanStringWithLayout(shape.ValueOrDie());
+ }
+ return "<invalid shape>";
+}
+
+} // namespace
+
+void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
+ const void* shape,
+ xla::int32 shape_length) {
+ if (VLOG_IS_ON(2)) {
+ LOG(INFO) << "AcquireInfeedBufferForDequeue: "
+ << ShapeString(shape, shape_length);
+ }
xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
// Wait until there's a buffer to dequeue.
xla::cpu::runtime::XfeedBuffer* buffer =
xfeed->infeed()->BlockingDequeueBuffer();
- CHECK_EQ(buffer->length(), buffer_length);
+ CHECK_EQ(buffer->length(), buffer_length)
+ << "XLA program infeed request buffer size " << buffer_length
+ << " did not match the runtime's infed buffer length " << buffer->length()
+ << "; program reports desired shape: "
+ << ShapeString(shape, shape_length);
return buffer->data();
}
-void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
- void* buffer_ptr) {
- VLOG(2) << "ReleaseInfeedBufferAfterDequeue";
+void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
+ xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
+ xla::int32 shape_length) {
+ if (VLOG_IS_ON(2)) {
+ LOG(INFO) << "ReleaseInfeedBufferAfterDeque: "
+ << ShapeString(shape_ptr, shape_length);
+ }
xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
- xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr);
+ xla::StatusOr<xla::Shape> shape =
+ xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
+ xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
+ std::move(shape));
}
void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- xla::int32 buffer_length) {
- VLOG(2) << "AcquireOutfeedBufferForPopulation";
+ xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length) {
+ if (VLOG_IS_ON(2)) {
+ LOG(INFO) << "AcquireOutfeedBufferForPopulation: "
+ << ShapeString(shape_ptr, shape_length);
+ }
xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
// Wait until there's a buffer to dequeue.
xla::cpu::runtime::XfeedBuffer* buffer =
xfeed->outfeed()->BlockingDequeueBuffer();
- CHECK_EQ(buffer->length(), buffer_length);
+ CHECK_EQ(buffer->length(), buffer_length)
+ << "XLA program outfeed request buffer size " << buffer_length
+ << " did not match the runtime's outfeed buffer length "
+ << buffer->length() << "; program reports outfed shape: "
+ << ShapeString(shape_ptr, shape_length);
return buffer->data();
}
void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- xla::int32 buffer_length, void* buffer_ptr) {
- VLOG(2) << "ReleaseOutfeedBufferAfterPopulation";
+ xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
+ xla::int32 shape_length) {
+ if (VLOG_IS_ON(2)) {
+ LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: "
+ << ShapeString(shape_ptr, shape_length);
+ }
xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
- xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr);
+ xla::StatusOr<xla::Shape> shape =
+ xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
+ xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, shape);
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index 04126e062e..40a7f548a2 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -68,13 +68,19 @@ XfeedManager* GetXfeedManager();
extern "C" {
+// Note: in the runtime entry points below, the shape pointer and shape_length
+// reflect values that can be deserialized via
+// llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass reified
+// type information from the generated program to the runtime, which helps check
+// the type safety and contract for the emitted-code/runtime communication.
+
// Blocks until the next infeed buffer is ready to be dequeued, then
// returns it. Fails catastrophically if the next enqueued buffer is
// not of the correct length in bytes. Checking the shape rather than
// the length would be more exact, but the length check is chosen as a
// tradeoff between error checking and speed/simplicity.
extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
- xla::int32 buffer_length);
+ xla::int32 buffer_length, const void* shape, xla::int32 shape_length);
// Relinquishes the next infeed buffer that was returned by
// __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call
@@ -89,12 +95,13 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
// implemented we will add support for multiple outstanding buffers
// that can be returned out of order.
extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
- xla::int32 buffer_length, void* buffer_ptr);
+ xla::int32 buffer_length, void* buffer_ptr, const void* shape,
+ xla::int32 shape_length);
// Blocks until the next outfeed buffer is available to be populated, then
// returns it.
extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- xla::int32 buffer_length);
+ xla::int32 buffer_length, const void* shape, xla::int32 shape_length);
// Relinquishes the outfeed buffer after it has been populated.
// buffer_ptr must have been previously returned by
@@ -106,7 +113,9 @@ extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
// acquired, i.e., there may only be one outstanding outfeed buffer in
// use by the runtime.
extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- xla::int32 buffer_length, void* buffer_ptr);
-}
+ xla::int32 buffer_length, void* buffer_ptr, const void* shape,
+ xla::int32 shape_length);
+
+} // extern "C"
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 077d1ee2a1..9db1614cd7 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -403,9 +403,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
llvm::Value* tuple_element_address =
EmitTempBufferPointer(buffer, tuple_element_shape);
- TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed,
- ByteSizeOf(tuple_element_shape),
- tuple_element_address));
+ TF_RETURN_IF_ERROR(EmitXfeedTransfer(
+ XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
tuple_element_addresses.push_back(tuple_element_address);
}
@@ -413,8 +412,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, shape),
tuple_element_addresses, &ir_builder_);
} else {
- TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, ByteSizeOf(shape),
- target_address));
+ TF_RETURN_IF_ERROR(
+ EmitXfeedTransfer(XfeedKind::kInfeed, shape, target_address));
}
emitted_value_[infeed] = target_address;
@@ -422,8 +421,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
return Status::OK();
}
-Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, int64 length,
+Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
llvm::Value* program_buffer_address) {
+ int64 length = ByteSizeOf(shape);
if (length <= 0 || length > std::numeric_limits<int32>::max()) {
return InvalidArgument(
"xfeed (infeed or outfeed) buffer length %lld is outside the valid "
@@ -432,14 +432,19 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, int64 length,
}
int32 length_32 = static_cast<int32>(length);
+ int32 shape_length;
+ TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
+ llvm_ir::EncodeSelfDescribingShapeConstant(
+ shape, &shape_length, &ir_builder_));
+
// The signature of the acquire infeed buffer function is:
//
// (void*)(int32 length);
- llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
llvm::Type* int32_type = ir_builder_.getInt32Ty();
- llvm::FunctionType* acquire_type =
- llvm::FunctionType::get(i8_ptr_type, {int32_type},
- /*isVarArg=*/false);
+ llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
+ llvm::FunctionType* acquire_type = llvm::FunctionType::get(
+ i8_ptr_type, {int32_type, i8_ptr_type, int32_type},
+ /*isVarArg=*/false);
llvm::Function* acquire_func;
if (kind == XfeedKind::kInfeed) {
@@ -455,7 +460,8 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, int64 length,
//
// (void)(int32 length, void* buffer);
llvm::FunctionType* release_type = llvm::FunctionType::get(
- ir_builder_.getVoidTy(), {int32_type, i8_ptr_type},
+ ir_builder_.getVoidTy(),
+ {int32_type, i8_ptr_type, i8_ptr_type, int32_type},
/*isVarArg=*/false);
llvm::Function* release_func;
@@ -468,8 +474,13 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, int64 length,
}
release_func->setCallingConv(llvm::CallingConv::C);
- llvm::Value* acquired_pointer =
- ir_builder_.CreateCall(acquire_func, {ir_builder_.getInt32(length_32)});
+ // Implementation note: this call informs the runtime that it wants a buffer
+ // of size exactly 'length_32', and the runtime is responsible for
+ // check-failing the process if there is a mismatch, versus passing us back a
+ // buffer that we might overrun.
+ llvm::Value* acquired_pointer = ir_builder_.CreateCall(
+ acquire_func, {ir_builder_.getInt32(length_32), shape_ptr,
+ ir_builder_.getInt32(shape_length)});
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
@@ -482,7 +493,8 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, int64 length,
}
ir_builder_.CreateCall(release_func,
- {ir_builder_.getInt32(length_32), acquired_pointer});
+ {ir_builder_.getInt32(length_32), acquired_pointer,
+ shape_ptr, ir_builder_.getInt32(shape_length)});
return Status::OK();
}
@@ -493,8 +505,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
llvm::Value* value = GetEmittedValueFor(operand);
if (!ShapeUtil::IsTuple(operand_shape)) {
- return EmitXfeedTransfer(XfeedKind::kOutfeed, ByteSizeOf(operand_shape),
- value);
+ return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
}
TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));
@@ -505,8 +516,8 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
value, &ir_builder_);
- TF_RETURN_IF_ERROR(EmitXfeedTransfer(
- XfeedKind::kOutfeed, ByteSizeOf(tuple_element_shape), tuple_element));
+ TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
+ tuple_element_shape, tuple_element));
}
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 4e1308384b..1d9c5a6dd7 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -432,7 +432,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// Emit IR to transfer between a {infeed,outfeed} buffer and an in-program
// address.
- Status EmitXfeedTransfer(XfeedKind kind, int64 length,
+ Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
llvm::Value* program_buffer_address);
const HloModuleConfig& hlo_module_config_;
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
index 0aa97b7cce..2160c3cd01 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
+#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -30,7 +31,7 @@ void XfeedQueueManager::Reset() {
tensorflow::mutex_lock l(mu_);
CHECK(current_buffer_ == nullptr);
for (auto buffer : enqueued_buffers_) {
- buffer->Done();
+ buffer->Done(ShapeUtil::MakeNil());
}
enqueued_buffers_.clear();
}
@@ -62,12 +63,13 @@ XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() {
return current_buffer_;
}
-void XfeedQueueManager::ReleaseCurrentBuffer(int32 length, void* data) {
+void XfeedQueueManager::ReleaseCurrentBuffer(int32 length, void* data,
+ StatusOr<Shape> shape) {
tensorflow::mutex_lock l(mu_);
CHECK(current_buffer_ != nullptr);
CHECK_EQ(length, current_buffer_->length());
CHECK_EQ(data, current_buffer_->data());
- current_buffer_->Done();
+ current_buffer_->Done(std::move(shape));
current_buffer_ = nullptr;
}
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
index efeb5eb980..86af789384 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
@@ -22,7 +22,9 @@ limitations under the License.
#include <deque>
+#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
@@ -38,7 +40,11 @@ class XfeedBuffer {
virtual int32 length() = 0;
virtual void* data() = 0;
- virtual void Done() = 0;
+
+ // The 'shape' parameter reflects what shape the embedded program was
+ // expecting / producing with respect to this XfeedBuffer. E.g. this will
+ // contain information about the layout of an outfed buffer.
+ virtual void Done(StatusOr<Shape> shape) = 0;
};
// Reusable component for managing the infeed and outfeed queue state.
@@ -70,7 +76,13 @@ class XfeedQueueManager {
// BlockingDequeuBuffer and not yet released. length and data must
// match the buffer->length() and buffer->data() for the current
// buffer.
- void ReleaseCurrentBuffer(int32 length, void* data);
+ //
+ // 'shape' communicates the shape of the buffer being released. If the program
+ // passed a value that could not be decoded as a shape, 'shape' will be an
+ // error status. In the case of outfeed, this indicates the layout of the
+ // shape that has been outfed. In the case of infeed, this can be used for
+ // sanity checking purposes.
+ void ReleaseCurrentBuffer(int32 length, void* data, StatusOr<Shape> shape);
private:
tensorflow::mutex mu_;
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
index fddb59fc2d..8defd28b01 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -30,25 +32,53 @@ class InfeedManagerTest : public ::testing::Test {};
class TestInfeedBuffer : public cpu::runtime::XfeedBuffer {
public:
- explicit TestInfeedBuffer(int32 length)
- : done_called_(false), length_(length) {}
+ explicit TestInfeedBuffer(int32 length, bool expect_shape_match = true)
+ : shape_(ShapeUtil::MakeShape(U8, {length})),
+ done_called_(false),
+ length_(length),
+ expect_shape_match_(expect_shape_match) {}
~TestInfeedBuffer() override { EXPECT_TRUE(done_called_); }
int32 length() override { return length_; }
void* data() override { return nullptr; }
- void Done() override {
+ void Done(StatusOr<Shape> shape) override {
CHECK(!done_called_);
done_called_ = true;
+ TF_ASSERT_OK(shape.status());
+ EXPECT_EQ(expect_shape_match_, ShapeUtil::Equal(shape_, shape.ValueOrDie()))
+ << "want " << ShapeUtil::HumanString(shape_) << " "
+ << (expect_shape_match_ ? "==" : "!=") << " "
+ << ShapeUtil::HumanString(shape.ValueOrDie());
}
+ const Shape& shape() const { return shape_; }
+
private:
+ Shape shape_;
bool done_called_;
int32 length_;
+ bool expect_shape_match_;
};
+// Performs the acquire/release sequence on the infeed, as the generated CPU
+// code would in the process of executing the infeed operation.
void ProcessNextBuffer(int32 length) {
- void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(length);
- __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer);
+ auto shape = ShapeUtil::MakeShape(U8, {length});
+ string bytes = shape.SerializeAsString();
+ void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
+ length, bytes.data(), bytes.size());
+ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer,
+ bytes.data(), bytes.size());
+}
+
+// Performs the acquire/release sequence on the outfeed, as the generated CPU
+// code would in the process of executing the outfeed operation.
+void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) {
+ string bytes = shape.SerializeAsString();
+ void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
+ length, bytes.data(), bytes.size());
+ __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
+ length, buffer, bytes.data(), bytes.size());
}
TEST_F(InfeedManagerTest, SingleThreadedSequential) {
@@ -98,5 +128,13 @@ TEST_F(InfeedManagerTest, MultiThreaded) {
ProcessNextBuffer(length);
}
+TEST_F(InfeedManagerTest, OutfeedWrongShape) {
+ TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false);
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ xfeed->outfeed()->EnqueueBuffers({b});
+
+ ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33}));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc
index 80e918970c..d8a76443a6 100644
--- a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc
@@ -50,7 +50,7 @@ class CpuInfeedBuffer : public cpu::runtime::XfeedBuffer {
int32 length() override { return length_; }
void* data() override { return buffer_; }
- void Done() override { delete this; }
+ void Done(StatusOr<Shape> /*shape*/) override { delete this; }
se::DeviceMemoryBase* device_memory() { return &device_memory_; }
@@ -65,15 +65,22 @@ class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer {
CpuOutfeedBuffer(void* destination, int32 length)
: destination_(destination), length_(length) {}
- void WaitForNotification() { return done_.WaitForNotification(); }
+ StatusOr<Shape> WaitForNotification() {
+ done_.WaitForNotification();
+ return status_;
+ }
int32 length() override { return length_; }
void* data() override { return destination_; }
- void Done() override { done_.Notify(); }
+ void Done(StatusOr<Shape> shape) override {
+ status_ = std::move(shape);
+ done_.Notify();
+ }
private:
void* destination_;
int32 length_;
+ StatusOr<Shape> status_;
tensorflow::Notification done_;
};
@@ -106,7 +113,7 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
buffers.reserve(literal.tuple_literals_size());
auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() {
for (cpu::runtime::XfeedBuffer* b : buffers) {
- b->Done();
+ b->Done(ShapeUtil::MakeNil());
}
});
@@ -159,7 +166,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
/*source=*/source, queued_buffer->device_memory());
if (!s.ok()) {
- queued_buffer->Done();
+ queued_buffer->Done(ShapeUtil::MakeNil());
return s;
}
return queued_buffer;
@@ -178,8 +185,17 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
auto empty =
Literal::CreateFromDimensions(literal_shape.element_type(), dimensions);
literal->Swap(empty.get());
- return TransferBufferFromOutfeed(executor, size,
- literal->MutableInternalData());
+ TF_ASSIGN_OR_RETURN(Shape received_shape,
+ TransferBufferFromOutfeed(
+ executor, size, literal->MutableInternalData()));
+ TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape()))
+ << "Shape received from outfeed "
+ << ShapeUtil::HumanString(received_shape)
+ << " did not match the shape that was requested for outfeed: "
+ << ShapeUtil::HumanString(literal_shape);
+ TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
+ *literal->mutable_shape() = received_shape;
+ return Status::OK();
}
if (ShapeUtil::IsNestedTuple(literal_shape)) {
@@ -199,9 +215,19 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
tuple_element_shape.dimensions().size());
auto empty = Literal::CreateFromDimensions(
tuple_element_shape.element_type(), dimensions);
- TF_RETURN_IF_ERROR(TransferBufferFromOutfeed(
- executor, GetByteSizeRequirement(tuple_element_shape),
- empty->MutableInternalData()));
+ TF_ASSIGN_OR_RETURN(
+ Shape received_shape,
+ TransferBufferFromOutfeed(executor,
+ GetByteSizeRequirement(tuple_element_shape),
+ empty->MutableInternalData()));
+ TF_RET_CHECK(ShapeUtil::Compatible(received_shape, tuple_element_shape))
+ << "Shape received from outfeed "
+ << ShapeUtil::HumanString(received_shape)
+ << " did not match the shape that was requested for outfeed: "
+ << ShapeUtil::HumanString(tuple_element_shape);
+ TF_RET_CHECK(GetByteSizeRequirement(tuple_element_shape) ==
+ GetByteSizeRequirement(received_shape));
+ *empty->mutable_shape() = received_shape;
elements.push_back(std::move(empty));
}
auto result = Literal::MakeTupleOwned(std::move(elements));
@@ -210,7 +236,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
return Status::OK();
}
-Status CpuTransferManager::TransferBufferFromOutfeed(
+StatusOr<Shape> CpuTransferManager::TransferBufferFromOutfeed(
perftools::gputools::StreamExecutor* executor, int64 size,
void* destination) {
if (size > std::numeric_limits<int32>::max()) {
@@ -230,9 +256,7 @@ Status CpuTransferManager::TransferBufferFromOutfeed(
<< size_32 << "B";
xfeed_manager->outfeed()->EnqueueBuffers({&buffer});
VLOG(2) << "Waiting for buffer to be notified as populated.";
- buffer.WaitForNotification();
- VLOG(2) << "Buffer is populated, returning from outfeed buffer request.";
- return Status::OK();
+ return buffer.WaitForNotification();
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu_transfer_manager.h
index f133a0ea49..30dc2d9062 100644
--- a/tensorflow/compiler/xla/service/cpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.h
@@ -50,7 +50,9 @@ class CpuTransferManager : public GenericTransferManager {
StatusOr<cpu::runtime::XfeedBuffer*> TransferBufferToInfeedInternal(
perftools::gputools::StreamExecutor* executor, int64 size,
const void* source);
- Status TransferBufferFromOutfeed(
+
+ // On success, returns the shape that was transferred from the outfeed.
+ StatusOr<Shape> TransferBufferFromOutfeed(
perftools::gputools::StreamExecutor* executor, int64 size,
void* destination);
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index cc7ff83d5e..69195c45ed 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -176,9 +176,9 @@ Status GenericTransferManager::TransferLiteralFromOutfeed(
Status GenericTransferManager::ResetDevices(
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
- executors) {
+ /*executors*/) {
return Unimplemented(
- "Device reset is not yet supported on CPU and GPU (b/30481585)");
+ "Device reset is not yet supported on this platform (b/30481585)");
}
int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index e348511c62..a8c17a67f1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -136,6 +137,24 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder) {
return result_type;
}
+StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(
+ const Shape& shape, int32* shape_size, llvm::IRBuilder<>* ir_builder) {
+ string encoded_shape = shape.SerializeAsString();
+ if (encoded_shape.size() > std::numeric_limits<int32>::max()) {
+ return InternalError("Encoded shape size exceeded int32 size limit.");
+ }
+ *shape_size = static_cast<int32>(encoded_shape.size());
+ return ir_builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(encoded_shape));
+}
+
+StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
+ int32 size_bytes) {
+ Shape shape;
+ TF_RET_CHECK(shape.ParseFromArray(shape_ptr, size_bytes));
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
+ return shape;
+}
+
namespace {
// Recursively construct a multidimensional LLVM constant which represents the
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 7b09c1f831..d940c3fcbc 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -106,6 +106,19 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
// if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder);
+// Returns a value that represents a pointer to a global string constant that
+// encodes the shape as a serialized protobuf.
+StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(
+ const Shape& shape, int32* shape_size, llvm::IRBuilder<>* ir_builder);
+
+// Inverses the encoding of a Shape protobuf into an LLVM global variable.
+//
+// This is intended to be called from the runtime to decode the llvm::Constants
+// that are created via ConvertShapeToSelfDescribingConstant and subsequently
+// embedded into the program.
+StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
+ int32 size_bytes);
+
// Converts a given literal to an IR Constant. Literals have known constant
// values at IR emission time.
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,