diff options
author | 2017-07-11 02:10:52 -0700 | |
---|---|---|
committer | 2017-07-11 02:14:44 -0700 | |
commit | 9e89636e6aa2be508fad22089c61659ce87f6e67 (patch) | |
tree | 9d42d981a2de45757bf2af6bfa96168884554d23 | |
parent | 8281e234c1dd741f30f657b66d129089e81f63e8 (diff) |
[XLA:CPU] Support for CPU outfeed and a xfeed (infeed/outfeed) test.
Note: does not yet support nested tuples, for symmetry with the current infeed
limitations.
PiperOrigin-RevId: 161502502
15 files changed, 361 insertions, 115 deletions
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 6760f72e55..0c38b6d3e7 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -631,6 +631,18 @@ string Literal::ToString() const { return literal; } +/* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned( + std::vector<std::unique_ptr<Literal>> elements) { + auto literal = MakeUnique<Literal>(); + std::vector<Shape> shape; + for (auto& tuple_element : elements) { + shape.push_back(tuple_element->shape()); + literal->add_tuple_literals()->Swap(tuple_element.get()); + } + *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape); + return literal; +} + const void* Literal::InternalData() const { return const_cast<const void*>( const_cast<Literal*>(this)->MutableInternalData()); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 8266511614..4b1464a132 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -481,6 +481,16 @@ class Literal { static std::unique_ptr<Literal> MakeTuple( tensorflow::gtl::ArraySlice<const Literal*> elements); + // As above, but intended to be invoked with move semantics; i.e. + // + // std::vector<std::unique_ptr<Literal>> elements = ...; + // auto result = Literal::MakeTupleOwned(std::move(elements)); + // + // This would have been declared as an overload, but there is ambiguity + // in invocation between the above signature and this one. + static std::unique_ptr<Literal> MakeTupleOwned( + std::vector<std::unique_ptr<Literal>> elements); + // Validates that the data payload of the literal matches the literal shape; // if it does not, an appropriate status is returned. 
tensorflow::Status ValidateLiteral() const; diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 53410b09c8..ddf45f03f0 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -326,11 +326,11 @@ cc_library( name = "cpu_runtime", srcs = [ "cpu_runtime.cc", - "infeed_manager.cc", + "xfeed_manager.cc", ], hdrs = [ "cpu_runtime.h", - "infeed_manager.h", + "xfeed_manager.h", ], copts = runtime_copts(), deps = [ @@ -416,9 +416,9 @@ cc_test( ) cc_test( - name = "infeed_manager_test", + name = "xfeed_manager_test", size = "small", - srcs = ["infeed_manager_test.cc"], + srcs = ["xfeed_manager_test.cc"], deps = [ ":cpu_runtime", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 253de20f25..40cdf079e3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -24,8 +24,8 @@ namespace xla { namespace cpu { namespace runtime { -InfeedManager* GetInfeedManager() { - static InfeedManager* manager = new InfeedManager; +XfeedManager* GetXfeedManager() { + static XfeedManager* manager = new XfeedManager; return manager; } @@ -35,17 +35,36 @@ InfeedManager* GetInfeedManager() { void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( xla::int32 buffer_length) { - xla::cpu::runtime::InfeedManager* infeed = - xla::cpu::runtime::GetInfeedManager(); + VLOG(2) << "AcquireInfeedBufferForDequeue"; + xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); // Wait until there's a buffer to dequeue. 
- xla::cpu::runtime::InfeedBuffer* buffer = infeed->BlockingDequeueBuffer(); + xla::cpu::runtime::XfeedBuffer* buffer = + xfeed->infeed()->BlockingDequeueBuffer(); CHECK_EQ(buffer->length(), buffer_length); return buffer->data(); } void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, void* buffer_ptr) { - xla::cpu::runtime::InfeedManager* infeed = - xla::cpu::runtime::GetInfeedManager(); - infeed->ReleaseCurrentBuffer(buffer_length, buffer_ptr); + VLOG(2) << "ReleaseInfeedBufferAfterDequeue"; + xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); + xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr); +} + +void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( + xla::int32 buffer_length) { + VLOG(2) << "AcquireOutfeedBufferForPopulation"; + xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); + // Wait until there's a buffer to dequeue. + xla::cpu::runtime::XfeedBuffer* buffer = + xfeed->outfeed()->BlockingDequeueBuffer(); + CHECK_EQ(buffer->length(), buffer_length); + return buffer->data(); +} + +void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( + xla::int32 buffer_length, void* buffer_ptr) { + VLOG(2) << "ReleaseOutfeedBufferAfterPopulation"; + xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); + xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 8eae210230..04126e062e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -26,7 +26,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ -#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -54,9 +54,13 @@ constexpr char kAcquireInfeedBufferForDequeueSymbolName[] = "__xla_cpu_runtime_AcquireInfeedBufferForDequeue"; constexpr char kReleaseInfeedBufferAfterDequeueSymbolName[] = "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue"; +constexpr char kAcquireOutfeedBufferForPopulationSymbolName[] = + "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation"; +constexpr char kReleaseOutfeedBufferAfterPopulationSymbolName[] = + "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; // Returns the infeed manager used by the CPU runtime. -InfeedManager* GetInfeedManager(); +XfeedManager* GetXfeedManager(); } // namespace runtime } // namespace cpu @@ -86,6 +90,23 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( // that can be returned out of order. extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( xla::int32 buffer_length, void* buffer_ptr); + +// Blocks until the next outfeed buffer is available to be populated, then +// returns it. +extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( + xla::int32 buffer_length); + +// Relinquishes the outfeed buffer after it has been populated. +// buffer_ptr must have been previously returned by +// __xla_cpu_runtime_AcquireOutfeedBufferForPopulation. +// Once this call completes, buffer_ptr may no longer be accessed. +// buffer_length must match the length passed to the call to +// __xla_cpu_runtime_AcquireOutfeedBufferForPopulation that returned +// buffer_ptr. This function must be called before the next buffer is +// acquired, i.e., there may only be one outstanding outfeed buffer in +// use by the runtime. 
+extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( + xla::int32 buffer_length, void* buffer_ptr); } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 590a23562b..077d1ee2a1 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -83,7 +83,8 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation( bool is_entry_computation, std::vector<const HloInstruction*>* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); - VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; + VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix + << "]; ordered? " << (instruction_order != nullptr); num_dynamic_loop_bounds_ = 0; if (!computation->root_instruction()->outer_dimension_partitions().empty()) { num_dynamic_loop_bounds_ = @@ -97,11 +98,10 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation( arch_type_ == llvm::Triple::ArchType::x86_64; profiling_state_ = ProfilingState(is_entry_computation, use_rdtscp, GetProfileCountersArgument()); - if (instruction_order != nullptr) { - TF_RETURN_IF_ERROR(computation->root_instruction()->AcceptOrdered( - this, *instruction_order)); + if (instruction_order == nullptr) { + TF_RETURN_IF_ERROR(computation->Accept(this)); } else { - TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); + TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); } InsertOrDie(&emitted_functions_, computation, compute_function_); @@ -376,6 +376,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { const Shape& shape = infeed->shape(); + // The infeed operation produces data (dequeued from the infeed queue) at this + // address, which has been provided by buffer assignment. 
TF_ASSIGN_OR_RETURN(llvm::Value * target_address, EmitTargetAddressForOp(infeed)); @@ -401,8 +403,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { llvm::Value* tuple_element_address = EmitTempBufferPointer(buffer, tuple_element_shape); - TF_RETURN_IF_ERROR(EmitInfeedTransfer(ByteSizeOf(tuple_element_shape), - tuple_element_address)); + TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, + ByteSizeOf(tuple_element_shape), + tuple_element_address)); tuple_element_addresses.push_back(tuple_element_address); } @@ -410,7 +413,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, shape), tuple_element_addresses, &ir_builder_); } else { - TF_RETURN_IF_ERROR(EmitInfeedTransfer(ByteSizeOf(shape), target_address)); + TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, ByteSizeOf(shape), + target_address)); } emitted_value_[infeed] = target_address; @@ -418,10 +422,13 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { return Status::OK(); } -Status IrEmitter::EmitInfeedTransfer(int64 length, - llvm::Value* target_address) { - if (length > std::numeric_limits<int32>::max()) { - return InvalidArgument("infeed buffer length %lld is too large", length); +Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, int64 length, + llvm::Value* program_buffer_address) { + if (length <= 0 || length > std::numeric_limits<int32>::max()) { + return InvalidArgument( + "xfeed (infeed or outfeed) buffer length %lld is outside the valid " + "size range", + length); } int32 length_32 = static_cast<int32>(length); @@ -434,9 +441,14 @@ Status IrEmitter::EmitInfeedTransfer(int64 length, llvm::FunctionType::get(i8_ptr_type, {int32_type}, /*isVarArg=*/false); - llvm::Function* acquire_func = - llvm::cast<llvm::Function>(module_->getOrInsertFunction( - runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)); + llvm::Function* acquire_func; + if (kind == XfeedKind::kInfeed) { + acquire_func = 
llvm::cast<llvm::Function>(module_->getOrInsertFunction( + runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)); + } else { + acquire_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction( + runtime::kAcquireOutfeedBufferForPopulationSymbolName, acquire_type)); + } acquire_func->setCallingConv(llvm::CallingConv::C); // The signature of the release infeed buffer function is: @@ -446,15 +458,28 @@ Status IrEmitter::EmitInfeedTransfer(int64 length, ir_builder_.getVoidTy(), {int32_type, i8_ptr_type}, /*isVarArg=*/false); - llvm::Function* release_func = - llvm::cast<llvm::Function>(module_->getOrInsertFunction( - runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type)); + llvm::Function* release_func; + if (kind == XfeedKind::kInfeed) { + release_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction( + runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type)); + } else { + release_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction( + runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, release_type)); + } release_func->setCallingConv(llvm::CallingConv::C); llvm::Value* acquired_pointer = ir_builder_.CreateCall(acquire_func, {ir_builder_.getInt32(length_32)}); - ir_builder_.CreateMemCpy(target_address, acquired_pointer, length_32, 1); + if (kind == XfeedKind::kInfeed) { + // Copy to the program buffer address from the acquired buffer. + ir_builder_.CreateMemCpy(program_buffer_address, acquired_pointer, + length_32, 1); + } else { + // Outfeed -- copy from the in-program address to the acquired buffer. + ir_builder_.CreateMemCpy(acquired_pointer, program_buffer_address, + length_32, 1); + } ir_builder_.CreateCall(release_func, {ir_builder_.getInt32(length_32), acquired_pointer}); @@ -463,13 +488,33 @@ Status IrEmitter::EmitInfeedTransfer(int64 length, } Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { - // TODO(b/34359662): Implement outfeed on CPU. 
- return Unimplemented("Outfeed is not supported on CPU (b/34359662)."); + HloInstruction* operand = outfeed->operands()[0]; + const Shape& operand_shape = operand->shape(); + + llvm::Value* value = GetEmittedValueFor(operand); + if (!ShapeUtil::IsTuple(operand_shape)) { + return EmitXfeedTransfer(XfeedKind::kOutfeed, ByteSizeOf(operand_shape), + value); + } + + TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape)); + + for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) { + const Shape& tuple_element_shape = + ShapeUtil::GetTupleElementShape(operand_shape, i); + llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement( + tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape), + value, &ir_builder_); + TF_RETURN_IF_ERROR(EmitXfeedTransfer( + XfeedKind::kOutfeed, ByteSizeOf(tuple_element_shape), tuple_element)); + } + + return Status::OK(); } Status IrEmitter::HandleSort(HloInstruction* sort, HloInstruction* operand) { // TODO(b/26783907): Implement sort on CPU. 
- return Unimplemented("Sort is not supported on GPU (b/26783907)."); + return Unimplemented("Sort is not supported on CPU (b/26783907)."); } Status IrEmitter::HandleTuple( @@ -1856,6 +1901,7 @@ void IrEmitter::ProfilingState::RecordCompleteComputation( } Status IrEmitter::Preprocess(HloInstruction* hlo) { + VLOG(3) << "Visiting: " << hlo->ToString(); if (hlo_to_profile_idx_ && hlo_to_profile_idx_->count(hlo)) { profiling_state_.RecordCycleStart(&ir_builder_, hlo); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index dd3c55f408..4e1308384b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -110,7 +110,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; - Status HandleOutfeed(HloInstruction* infeed) override; + Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSort(HloInstruction* sort, HloInstruction* operand) override; Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, @@ -425,8 +425,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Returns the number of bytes within the shape. int64 ByteSizeOf(const Shape& shape) const; - // Emit IR to transfer an infeed buffer to the target address. - Status EmitInfeedTransfer(int64 length, llvm::Value* target_address); + enum class XfeedKind { + kInfeed, + kOutfeed, + }; + + // Emit IR to transfer between a {infeed,outfeed} buffer and an in-program + // address. 
+ Status EmitXfeedTransfer(XfeedKind kind, int64 length, + llvm::Value* program_buffer_address); const HloModuleConfig& hlo_module_config_; diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc index 2ce27d22c7..0aa97b7cce 100644 --- a/tensorflow/compiler/xla/service/cpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/core/platform/logging.h" @@ -21,26 +21,28 @@ namespace xla { namespace cpu { namespace runtime { -InfeedBuffer::~InfeedBuffer() = default; - -InfeedManager::InfeedManager() : current_buffer_(nullptr) {} +void XfeedManager::Reset() { + infeed()->Reset(); + outfeed()->Reset(); +} -void InfeedManager::Reset() { +void XfeedQueueManager::Reset() { tensorflow::mutex_lock l(mu_); - CHECK(!current_buffer_); - for (auto buffer : enqueued_buffer_) { + CHECK(current_buffer_ == nullptr); + for (auto buffer : enqueued_buffers_) { buffer->Done(); } - enqueued_buffer_.clear(); + enqueued_buffers_.clear(); } -void InfeedManager::EnqueueBuffers(const std::vector<InfeedBuffer*>& buffers) { +void XfeedQueueManager::EnqueueBuffers( + tensorflow::gtl::ArraySlice<XfeedBuffer*> buffers) { tensorflow::mutex_lock l(mu_); - bool was_empty = enqueued_buffer_.empty(); - for (InfeedBuffer* b : buffers) { - enqueued_buffer_.push_back(b); + bool was_empty = enqueued_buffers_.empty(); + for (XfeedBuffer* b : buffers) { + enqueued_buffers_.push_back(b); } - if (was_empty) { + if (was_empty && !buffers.empty()) { // This has the potential to suffer from the notified thread // immediately trying and failing to acquire mu_, but seems // preferable to 
the alternative of notifying outside the lock @@ -49,20 +51,20 @@ void InfeedManager::EnqueueBuffers(const std::vector<InfeedBuffer*>& buffers) { } } -InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { +XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() { tensorflow::mutex_lock l(mu_); - while (enqueued_buffer_.empty()) { + while (enqueued_buffers_.empty()) { cv_.wait(l); } - CHECK(!current_buffer_); - current_buffer_ = enqueued_buffer_.front(); - enqueued_buffer_.pop_front(); + CHECK(current_buffer_ == nullptr); + current_buffer_ = enqueued_buffers_.front(); + enqueued_buffers_.pop_front(); return current_buffer_; } -void InfeedManager::ReleaseCurrentBuffer(int32 length, void* data) { +void XfeedQueueManager::ReleaseCurrentBuffer(int32 length, void* data) { tensorflow::mutex_lock l(mu_); - CHECK(current_buffer_); + CHECK(current_buffer_ != nullptr); CHECK_EQ(length, current_buffer_->length()); CHECK_EQ(data, current_buffer_->data()); current_buffer_->Done(); diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h index e965988453..efeb5eb980 100644 --- a/tensorflow/compiler/xla/service/cpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ -17,13 +17,13 @@ limitations under the License. // is used by the CPU runtime to transfer buffers into an executing // CPU computation, e.g., to feed data into a while loop. 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_XFEED_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_XFEED_MANAGER_H_ #include <deque> -#include <vector> #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -32,41 +32,39 @@ namespace runtime { // Abstract class defining an infeed buffer that is passed to the // runtime by a client. The client manages the storage of the buffer. -class InfeedBuffer { +class XfeedBuffer { public: - virtual ~InfeedBuffer(); + virtual ~XfeedBuffer() = default; virtual int32 length() = 0; virtual void* data() = 0; virtual void Done() = 0; }; -// Client-side class used to enqueue infeed buffers. -class InfeedManager { +// Reusable component for managing the infeed and outfeed queue state. +class XfeedQueueManager { public: - InfeedManager(); + XfeedQueueManager() = default; // Calls the completion callback for any enqueued buffers that have - // not been dequeued by the runtime, and empties the infeed + // not been dequeued by the runtime, and empties the // queue. Reset may not be called while a runtime computation is // processing a dequeued buffer. The only safe way to ensure this // condition is to call Reset when no computation is taking place. void Reset(); - // Adds a set of buffers to the infeed queue - // atomically. buffer->Done will be called when the buffer will no - // longer be accessed by the InfeedManager, either as a result of a - // call to Reset or because the runtime has dequeued and used the - // buffer. - void EnqueueBuffers(const std::vector<InfeedBuffer*>& buffers); + // Adds a sequence of buffers to the queue atomically. 
buffer->Done will be + // called when the buffer will no longer be accessed by the XfeedManager, + // either as a result of a call to Reset or because the runtime has dequeued + // and used the buffer. + void EnqueueBuffers(tensorflow::gtl::ArraySlice<XfeedBuffer*> buffers); - // Blocks until the infeed queue is non-empty, then returns the - // buffer at the head of the queue. Sets the current buffer to be - // the returned buffer. It is an error to call BlockingDequeueBuffer - // if there is an unreleased current buffer, i.e., - // ReleaseCurrentBuffer must be called between calls to + // Blocks until the queue is non-empty, then returns the buffer at the head of + // the queue. Sets the current buffer to be the returned buffer. It is an + // error to call BlockingDequeueBuffer if there is an unreleased current + // buffer, i.e., ReleaseCurrentBuffer must be called between calls to // BlockingDequeueBuffer. - InfeedBuffer* BlockingDequeueBuffer(); + XfeedBuffer* BlockingDequeueBuffer(); // Releases the current buffer, which is the last buffer returned by // BlockingDequeuBuffer and not yet released. length and data must @@ -76,19 +74,37 @@ class InfeedManager { private: tensorflow::mutex mu_; + // Condition variable that is signaled every time a buffer is // enqueued to an empty queue. tensorflow::condition_variable cv_; - // InfeedBuffer* queue contents are not owned, but buffer->Done must + + // XfeedBuffer* queue contents are not owned, but buffer->Done must // be called when the buffer is no longer needed by the runtime. - std::deque<InfeedBuffer*> enqueued_buffer_; + std::deque<XfeedBuffer*> enqueued_buffers_; + // If non-NULL, the buffer that is currently being processed by the // runtime. Not owned. - InfeedBuffer* current_buffer_; + XfeedBuffer* current_buffer_ = nullptr; +}; + +// Client-side class used to enqueue infeed buffers. 
+class XfeedManager { + public: + XfeedManager() = default; + + void Reset(); + + XfeedQueueManager* infeed() { return &infeed_; } + XfeedQueueManager* outfeed() { return &outfeed_; } + + private: + XfeedQueueManager infeed_; + XfeedQueueManager outfeed_; }; } // namespace runtime } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_XFEED_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc index a59fa35fdb..fddb59fc2d 100644 --- a/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include <memory> @@ -28,7 +28,7 @@ namespace { class InfeedManagerTest : public ::testing::Test {}; -class TestInfeedBuffer : public cpu::runtime::InfeedBuffer { +class TestInfeedBuffer : public cpu::runtime::XfeedBuffer { public: explicit TestInfeedBuffer(int32 length) : done_called_(false), length_(length) {} @@ -55,10 +55,10 @@ TEST_F(InfeedManagerTest, SingleThreadedSequential) { TestInfeedBuffer* a = new TestInfeedBuffer(64); TestInfeedBuffer* b = new TestInfeedBuffer(32); - cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); - infeed->EnqueueBuffers({a}); - infeed->EnqueueBuffers({b}); + xfeed->infeed()->EnqueueBuffers({a}); + xfeed->infeed()->EnqueueBuffers({b}); ProcessNextBuffer(a->length()); ProcessNextBuffer(b->length()); } @@ -67,22 +67,22 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { 
TestInfeedBuffer* a = new TestInfeedBuffer(64); TestInfeedBuffer* b = new TestInfeedBuffer(32); - cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); - infeed->EnqueueBuffers({a}); + xfeed->infeed()->EnqueueBuffers({a}); ProcessNextBuffer(a->length()); - infeed->EnqueueBuffers({b}); + xfeed->infeed()->EnqueueBuffers({b}); ProcessNextBuffer(b->length()); } TEST_F(InfeedManagerTest, MultiThreaded) { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2); - cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); const int32 length = 64; - pool.Schedule([infeed]() { + pool.Schedule([xfeed]() { // Spin for 100 milliseconds int64 start_micros = tensorflow::Env::Default()->NowMicros(); while (true) { @@ -92,7 +92,7 @@ TEST_F(InfeedManagerTest, MultiThreaded) { } } TestInfeedBuffer* a = new TestInfeedBuffer(length); - infeed->EnqueueBuffers({a}); + xfeed->infeed()->EnqueueBuffers({a}); }); ProcessNextBuffer(length); diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc index 262ba83f3d..80e918970c 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc @@ -27,9 +27,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace se = ::perftools::gputools; @@ -38,7 +40,7 @@ namespace xla { namespace { -class CpuInfeedBuffer : public cpu::runtime::InfeedBuffer { +class CpuInfeedBuffer : public cpu::runtime::XfeedBuffer { public: explicit CpuInfeedBuffer(int32 length) : length_(length), @@ -58,6 +60,23 @@ class CpuInfeedBuffer : public cpu::runtime::InfeedBuffer { se::DeviceMemoryBase device_memory_; }; +class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer { + public: + CpuOutfeedBuffer(void* destination, int32 length) + : destination_(destination), length_(length) {} + + void WaitForNotification() { return done_.WaitForNotification(); } + + int32 length() override { return length_; } + void* data() override { return destination_; } + void Done() override { done_.Notify(); } + + private: + void* destination_; + int32 length_; + tensorflow::Notification done_; +}; + } // namespace CpuTransferManager::CpuTransferManager() @@ -83,10 +102,10 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, // For a tuple, we transfer each of its elements to the device and // enqueue the resulting destination device addresses with the // infeed manager. 
- std::vector<cpu::runtime::InfeedBuffer*> buffers; + std::vector<cpu::runtime::XfeedBuffer*> buffers; buffers.reserve(literal.tuple_literals_size()); auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { - for (cpu::runtime::InfeedBuffer* b : buffers) { + for (cpu::runtime::XfeedBuffer* b : buffers) { b->Done(); } }); @@ -95,15 +114,14 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Shape& tuple_element_shape = tuple_element.shape(); int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); TF_ASSIGN_OR_RETURN( - cpu::runtime::InfeedBuffer * buffer, + cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, tuple_element_size, tuple_element.InternalData())); buffers.push_back(buffer); } - cpu::runtime::InfeedManager* infeed_manager = - cpu::runtime::GetInfeedManager(); - infeed_manager->EnqueueBuffers(buffers); + cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + xfeed_manager->infeed()->EnqueueBuffers(buffers); cleanup.release(); return Status::OK(); @@ -112,17 +130,16 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source) { - TF_ASSIGN_OR_RETURN(cpu::runtime::InfeedBuffer * buffer, + TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, size, source)); - cpu::runtime::InfeedManager* infeed_manager = - cpu::runtime::GetInfeedManager(); - infeed_manager->EnqueueBuffers({buffer}); + cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + xfeed_manager->infeed()->EnqueueBuffers({buffer}); return Status::OK(); } -StatusOr<cpu::runtime::InfeedBuffer*> +StatusOr<cpu::runtime::XfeedBuffer*> CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int64 size, const void* source) { @@ -130,8 +147,9 @@ 
CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); } - if (size == 0) { - return InvalidArgument("Infeed shape needs 0 bytes"); + if (size <= 0) { + return InvalidArgument("Infeed shape must have positive size; got %lld", + size); } int32 size_32 = static_cast<int32>(size); @@ -147,6 +165,76 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, return queued_buffer; } +Status CpuTransferManager::TransferLiteralFromOutfeed( + se::StreamExecutor* executor, const Shape& literal_shape, + Literal* literal) { + if (!ShapeUtil::IsTuple(literal_shape)) { + int64 size = GetByteSizeRequirement(literal_shape); + // Note: OSS build didn't like implicit conversion from + // literal_shape.dimensions() to the array slice on 2017-07-10. + tensorflow::gtl::ArraySlice<int64> dimensions( + tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()), + literal_shape.dimensions().size()); + auto empty = + Literal::CreateFromDimensions(literal_shape.element_type(), dimensions); + literal->Swap(empty.get()); + return TransferBufferFromOutfeed(executor, size, + literal->MutableInternalData()); + } + + if (ShapeUtil::IsNestedTuple(literal_shape)) { + return Unimplemented( + "Nested tuple outfeeds are not yet implemented on CPU."); + } + + std::vector<std::unique_ptr<Literal>> elements; + for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { + const Shape& tuple_element_shape = + ShapeUtil::GetTupleElementShape(literal_shape, i); + // Note: OSS build didn't like implicit conversion from + // literal_shape.dimensions() to the array slice on 2017-07-10. 
+ tensorflow::gtl::ArraySlice<int64> dimensions( + tensorflow::bit_cast<const int64*>( + tuple_element_shape.dimensions().data()), + tuple_element_shape.dimensions().size()); + auto empty = Literal::CreateFromDimensions( + tuple_element_shape.element_type(), dimensions); + TF_RETURN_IF_ERROR(TransferBufferFromOutfeed( + executor, GetByteSizeRequirement(tuple_element_shape), + empty->MutableInternalData())); + elements.push_back(std::move(empty)); + } + auto result = Literal::MakeTupleOwned(std::move(elements)); + literal->Swap(result.get()); + TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); + return Status::OK(); +} + +Status CpuTransferManager::TransferBufferFromOutfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + void* destination) { + if (size > std::numeric_limits<int32>::max()) { + return InvalidArgument("Outfeed shape is too large: needs %lld bytes", + size); + } + + if (size <= 0) { + return InvalidArgument("Outfeed shape must have positive size; got %lld", + size); + } + + int32 size_32 = static_cast<int32>(size); + cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + CpuOutfeedBuffer buffer(destination, size_32); + VLOG(2) << "Enqueueing outfeed buffer (for the device to populate) of length " + << size_32 << "B"; + xfeed_manager->outfeed()->EnqueueBuffers({&buffer}); + VLOG(2) << "Waiting for buffer to be notified as populated."; + buffer.WaitForNotification(); + VLOG(2) << "Buffer is populated, returning from outfeed buffer request."; + return Status::OK(); +} + } // namespace xla static std::unique_ptr<xla::TransferManager> CreateCpuTransferManager() { diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu_transfer_manager.h index 96ffb94d71..f133a0ea49 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.h @@ -18,7 +18,7 @@ limitations under the License. 
#include <vector> -#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" @@ -40,13 +40,19 @@ class CpuTransferManager : public GenericTransferManager { const Literal& literal) override; Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, int64 size, const void* source) override; + Status TransferLiteralFromOutfeed( + perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, + Literal* literal) override; private: // Transfers infeed data to device. InfeedBuffer->Done() must be // called to clean up the memory allocated for InfeedBuffer. - StatusOr<cpu::runtime::InfeedBuffer*> TransferBufferToInfeedInternal( + StatusOr<cpu::runtime::XfeedBuffer*> TransferBufferToInfeedInternal( perftools::gputools::StreamExecutor* executor, int64 size, const void* source); + Status TransferBufferFromOutfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + void* destination); TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager); }; diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 476b2b8d6f..cc7ff83d5e 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -170,7 +170,8 @@ Status GenericTransferManager::TransferBufferToInfeed( Status GenericTransferManager::TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) { - return Unimplemented("Outfeed is not supported on CPU/GPU (b/30467474)"); + return Unimplemented( + "Outfeed is not supported on this platform (b/30467474)"); } Status GenericTransferManager::ResetDevices( diff --git 
a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 6a5533c469..119cf7dde5 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -211,7 +211,8 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( Status HloComputation::RemoveInstruction(HloInstruction* instruction) { VLOG(2) << "Removing instruction " << instruction->name() << " from computation " << name(); - TF_RET_CHECK(IsRemovable(instruction)); + TF_RET_CHECK(IsRemovable(instruction)) + << "cannot remove instruction: " << instruction->ToString(); TF_RET_CHECK(root_instruction() != instruction) << "cannot remove root instruction " << instruction->name(); TF_RET_CHECK(instruction->user_count() == 0) @@ -593,6 +594,12 @@ std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const { unreachable_roots.push_back(instruction.get()); } } + VLOG(3) << "Unreachable roots:" + << tensorflow::str_util::Join( + unreachable_roots, "\n\t", + [](string* out, const HloInstruction* hlo) { + tensorflow::strings::StrAppend(out, hlo->ToString()); + }); return unreachable_roots; } @@ -601,6 +608,7 @@ Status HloComputation::Accept(DfsHloVisitor* visitor) const { // visited root, which would invalidate iterators if the unreachable roots // weren't computed ahead of time. for (HloInstruction* root : CollectUnreachableRoots()) { + VLOG(3) << "Traversing unreachable root: " << root->ToString(); // Call FinishVisit only at the end. 
TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false)); } @@ -627,9 +635,15 @@ Status HloComputation::AcceptWithOperandOrder( Status HloComputation::AcceptOrdered( DfsHloVisitor* visitor, const std::vector<const HloInstruction*>& order) const { + VLOG(3) << "Accepting visitor with order."; + for (HloInstruction* root : CollectUnreachableRoots()) { + TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) + << root->ToString(); + } TF_RET_CHECK(order.size() == instruction_count()); std::unordered_set<const HloInstruction*> visited; for (const HloInstruction* instruction : order) { + VLOG(3) << "Visiting ordered: " << instruction->ToString(); TF_RET_CHECK(instruction_iterators_.count(instruction) == 1) << "Instruction " << instruction->name() << " is not in computation " << name(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c8bb50173c..73bae5cb31 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -907,13 +907,17 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( return CreateBatchNormTraining(shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), feature_index()); + case HloOpcode::kInfeed: + CHECK_EQ(new_operands.size(), 0); + return CreateInfeed(shape, infeed_config()); + case HloOpcode::kOutfeed: + CHECK_EQ(new_operands.size(), 1); + return CreateOutfeed(shape, new_operands[0], outfeed_config()); case HloOpcode::kBatchNormGrad: case HloOpcode::kRecv: case HloOpcode::kSend: case HloOpcode::kUpdate: case HloOpcode::kIndex: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } |