author     Peter Hawkins <phawkins@google.com>              2018-08-10 06:21:39 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-10 06:25:33 -0700
commit     56e4ea405d13125a3dcb6459019a83d12330bf84
tree       f285895f73399e775479895af835b65d529d100f
parent     4d9266f02cc1553e4cee9def5fee79158dd76ecd
Automated rollback of commit b306f5f9458feddbdb89b7db557cb74dc9408d07
PiperOrigin-RevId: 208200028
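
The rolled-back change had replaced per-transfer definition events with a single
std::shared_ptr<se::Event> shared between output tensors; this rollback restores
the earlier idiom, in which each transfer records its own se::Event on the
producing stream and moves it into the destination XlaTensor, and consumers wait
on that event from their own stream instead of blocking the host. A condensed
sketch of the restored record/wait pattern, pieced together from the
XlaTransferManager hunks below (identifiers are the class members used there;
illustrative only, not a compilable unit on its own):

    // Producer: record an event on the stream that performs the transfer.
    se::Event event(stream_->parent());
    TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
    host_to_device_stream_->ThenRecordEvent(&event);
    // The destination tensor takes ownership (SetDefinedOn takes the event
    // by value); GetDefinitionEvent() hands out a pointer to it later.
    xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));

    // Consumer: enqueue a wait on the definition event before reading the
    // buffer from another stream, rather than synchronizing the whole device.
    if (se::Event* event =
            xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
      device_to_host_stream_->ThenWaitFor(event);
      xla_tensor->SetDefinedOn(device_to_host_stream_);
    }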
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc        |  3
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc     |  3
-rw-r--r--  tensorflow/compiler/jit/xla_device.cc                   | 41
-rw-r--r--  tensorflow/compiler/jit/xla_device.h                    | 14
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.cc           | 89
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.h            | 31
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc              | 62
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.h               |  6
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.cc                   |  7
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.h                    |  6
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executor.h  |  2
-rw-r--r--  tensorflow/stream_executor/host/host_gpu_executor.h    |  2
12 files changed, 99 insertions(+), 167 deletions(-)
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 669beb71ca..b313d48011 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -209,8 +209,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   auto elapsed = env->NowMicros() - start_time;
   VLOG(2) << "Elapsed time: " << elapsed << "us";
 
-  OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
-                          ctx, kernel, run_result.ConsumeValueOrDie()));
+  launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie());
   VLOG(1) << "Done";
 }
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index b508021cdf..d288d37bc7 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -77,8 +77,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
       executable->Run(launch_context.arguments(), run_options);
   TF_RETURN_IF_ERROR(run_result.status());
 
-  TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
-      ctx, result, run_result.ConsumeValueOrDie()));
+  launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie());
   return Status::OK();
 }
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 2a2691a6a4..4ddeaebd3e 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -26,7 +26,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/service/stream_pool.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
@@ -217,8 +216,6 @@ XlaDevice::XlaDevice(
       transfer_as_literal_(transfer_as_literal),
       shape_representation_fn_(shape_representation_fn) {
   VLOG(1) << "Created XLA device " << jit_device_name << " " << this;
-  thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device",
-                                            /*num_threads=*/1));
 }
 
 XlaDevice::~XlaDevice() {
@@ -265,12 +262,10 @@ Status XlaDevice::EnsureDeviceContextOk() {
 
 Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
                                        const string& name,
-                                       std::shared_ptr<se::Stream>* stream,
+                                       xla::StreamPool::Ptr* stream,
                                        bool* stream_was_changed) {
   if (!(*stream) || !(*stream)->ok()) {
-    xla::StreamPool::Ptr ptr;
-    TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
-    *stream = std::shared_ptr<se::Stream>(std::move(ptr));
+    TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_));
     VLOG(1) << "XlaDevice " << this << " new " << name << " "
             << (*stream)->DebugStreamPointers();
     *stream_was_changed = true;
@@ -286,8 +281,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
   TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                           &need_new_device_context));
 
-  std::shared_ptr<se::Stream> host_to_device_stream = stream_;
-  std::shared_ptr<se::Stream> device_to_host_stream = stream_;
+  se::Stream* host_to_device_stream = stream_.get();
+  se::Stream* device_to_host_stream = stream_.get();
   if (use_multiple_streams_) {
     TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
                                             &host_to_device_stream_,
                                             &need_new_device_context));
     TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
                                             &device_to_host_stream_,
                                             &need_new_device_context));
-    host_to_device_stream = host_to_device_stream_;
-    device_to_host_stream = device_to_host_stream_;
+    host_to_device_stream = host_to_device_stream_.get();
+    device_to_host_stream = device_to_host_stream_.get();
   }
 
   if (!need_new_device_context) {
@@ -309,13 +304,9 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
   if (device_context_) {
     device_context_->Unref();
   }
-  // The XlaDeviceContext keeps a reference count to the streams, and the
-  // XlaDeviceContext remains live for the duration of a Executor run. This
-  // ensures that the streams remain live for the duration of a run, even if
-  // an error is encountered and the streams are replaced with new ones.
   device_context_ = new XlaDeviceContext(
-      stream_, host_to_device_stream, device_to_host_stream, client(),
-      transfer_as_literal_, shape_representation_fn_, thread_pool_.get());
+      stream_.get(), host_to_device_stream, device_to_host_stream, client(),
+      transfer_as_literal_, shape_representation_fn_);
   VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
           << device_context_;
 
@@ -380,22 +371,6 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
   op_kernel->ComputeAsync(context, done);
 }
 
-Status XlaDevice::Sync() {
-  VLOG(1) << "XlaDevice::Sync";
-  std::shared_ptr<se::Stream> stream;
-  {
-    mutex_lock lock(mu_);
-    stream = stream_;
-  }
-  if (!stream) return Status::OK();
-
-  if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) {
-    return errors::Internal("XlaDevice::Sync() failed.");
-  }
-  VLOG(1) << "XlaDevice::Sync completed";
-  return Status::OK();
-}
-
 Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                       const AllocatorAttributes alloc_attrs,
                                       Tensor* tensor) {
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index dbf35f349f..d8906419b0 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/local_device.h"
 #include "tensorflow/core/framework/allocator.h"
@@ -123,7 +124,7 @@ class XlaDevice : public LocalDevice {
   void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
   void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                     AsyncOpKernel::DoneCallback done) override;
-  Status Sync() override;
+  Status Sync() override { return Status::OK(); }
 
   Status FillContextMap(const Graph* graph,
                         DeviceContextMap* device_context_map) override
@@ -152,7 +153,7 @@ class XlaDevice : public LocalDevice {
   Allocator* GetAllocatorLocked(AllocatorAttributes attr)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
   Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
-                              std::shared_ptr<se::Stream>* stream,
+                              xla::StreamPool::Ptr* stream,
                               bool* stream_was_changed)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
   xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
@@ -173,17 +174,17 @@ class XlaDevice : public LocalDevice {
   // stream are executed on the device. Operations include data
   // copying back and forth between CPU and the device, and
   // computations enqueued by XLA.
-  std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
+  xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
   // If false, only stream_ is valid and all computation and transfers use
   // stream_. If true, computation is performed by stream_ and transfers are
   // performed by host_to_device/device_to_host_stream.
   const bool use_multiple_streams_;
   // If use_multiple_streams_, host to device transfers are performed using
   // this stream.
-  std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
+  xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
   // If use_multiple_streams_, device to host transfers are performed using
   // this stream.
-  std::shared_ptr<se::Stream> device_to_host_stream_ GUARDED_BY(mu_);
+  xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
   // Must we use XLA's transfer manager for correct host<->device transfers?
   // if false, we can use ThenMemcpy() instead.
   const bool transfer_as_literal_;
@@ -197,9 +198,6 @@ class XlaDevice : public LocalDevice {
   // Holds extra information for GPU and TPU devices, e.g. the device context.
   bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
   std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
-
-  // Thread pool used for running closures
-  std::unique_ptr<thread::ThreadPool> thread_pool_;
 };
 
 // Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 0a0c089241..0100bf51ed 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -15,9 +15,6 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/xla_device_context.h"
 
-#include <memory>
-
-#include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_launch_util.h"
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -51,20 +48,17 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
 void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
 
 XlaTransferManager::XlaTransferManager(
-    std::shared_ptr<se::Stream> compute_stream,
-    std::shared_ptr<se::Stream> host_to_device_stream,
-    std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
+    se::Stream* compute_stream, se::Stream* host_to_device_stream,
+    se::Stream* device_to_host_stream, xla::LocalClient* client,
     bool transfer_as_literal,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    thread::ThreadPool* thread_pool)
-    : stream_(std::move(compute_stream)),
-      host_to_device_stream_(std::move(host_to_device_stream)),
-      device_to_host_stream_(std::move(device_to_host_stream)),
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+    : stream_(compute_stream),
+      host_to_device_stream_(host_to_device_stream),
+      device_to_host_stream_(device_to_host_stream),
       client_(client),
       transfer_manager_(client->backend().transfer_manager()),
       transfer_as_literal_(transfer_as_literal),
-      shape_representation_fn_(std::move(shape_representation_fn)),
-      thread_pool_(thread_pool) {
+      shape_representation_fn_(std::move(shape_representation_fn)) {
   CHECK(host_to_device_stream_ != nullptr);
   CHECK(device_to_host_stream_ != nullptr);
   CHECK(stream_ != nullptr);
@@ -94,15 +88,15 @@ Status XlaTransferManager::TransferLiteralToDevice(
   if (UseMultipleStreams()) {
     // Initially wait for the compute stream so that memory allocations are
     // synchronized.
-    host_to_device_stream_->ThenWaitFor(stream_.get());
+    host_to_device_stream_->ThenWaitFor(stream_);
   }
   TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
-      host_to_device_stream_.get(), *literal, shaped_buffer));
+      host_to_device_stream_, *literal, shaped_buffer));
   if (UseMultipleStreams()) {
-    auto event = std::make_shared<se::Event>(stream_->parent());
-    TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
-    host_to_device_stream_->ThenRecordEvent(event.get());
-    xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event));
+    se::Event event(stream_->parent());
+    TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
+    host_to_device_stream_->ThenRecordEvent(&event);
+    xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
   }
   // Unref the host tensor, and capture the literal shared_ptr too so it goes
   // out of scope when the lambda completes.
@@ -122,7 +116,7 @@ void XlaTransferManager::TransferLiteralFromDevice(
   TensorReference ref(device_tensor);
   transfer_manager_->TransferLiteralFromDevice(
-      device_to_host_stream_.get(), shaped_buffer, literal,
+      device_to_host_stream_, shaped_buffer, literal,
       [=, &shaped_buffer, &literal](xla::Status status) {
         ref.Unref();
         done([&]() -> Status {
@@ -185,14 +179,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
       status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
       if (status.ok()) {
         xla_tensor->set_host_tensor(*cpu_tensor);
-        host_to_device_stream_->ThenDoHostCallback([this, done]() {
-          // We must not call the done closure directly from DoHostCallback
-          // to avoid a deadlock. If done() is the callback that ends an
-          // Executor's run, the Executor may call XlaDevice::Sync() inside the
-          // callback. This deadlocks, because XlaDevice::Sync() waits for all
-          // stream activity to complete.
-          thread_pool_->Schedule([done]() { done(Status::OK()); });
-        });
+        host_to_device_stream_->ThenDoHostCallback(
+            [done]() { done(Status::OK()); });
         return;
       }
     } else {
@@ -204,7 +192,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
      if (!block_status.ok()) {
        status = xla::InternalError(
            "Failed to complete data transfer on stream %p: %s",
-            host_to_device_stream_.get(), block_status.error_message().c_str());
+            host_to_device_stream_, block_status.error_message().c_str());
      }
    }
    xla_tensor->set_host_tensor(*cpu_tensor);
@@ -237,9 +225,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
     XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
     if (se::Event* event =
-            xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) {
+            xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
       device_to_host_stream_->ThenWaitFor(event);
-      xla_tensor->SetDefinedOn(device_to_host_stream_.get());
+      xla_tensor->SetDefinedOn(device_to_host_stream_);
     }
 
     Status status;
@@ -252,7 +240,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
       Status block_status = device_to_host_stream_->BlockHostUntilDone();
       if (!block_status.ok()) {
         status = xla::InternalError(
-            "Failed to complete data transfer on stream %p: %s", stream_.get(),
+            "Failed to complete data transfer on stream %p: %s", stream_,
             block_status.error_message().c_str());
       }
     }
@@ -290,14 +278,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
     if (stream_ != device_to_device_stream) {
       // Initially wait for the compute stream so that memory allocations are
       // synchronized.
-      device_to_device_stream->ThenWaitFor(stream_.get());
+      device_to_device_stream->ThenWaitFor(stream_);
     }
   }
 
   if (se::Event* event =
-          xla_src->GetDefinitionEvent(device_to_device_stream.get())) {
+          xla_src->GetDefinitionEvent(device_to_device_stream)) {
     device_to_device_stream->ThenWaitFor(event);
-    xla_src->SetDefinedOn(device_to_device_stream.get());
+    xla_src->SetDefinedOn(device_to_device_stream);
   }
 
   auto from_iter = xla_src->shaped_buffer().buffers().begin();
@@ -309,37 +297,28 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
     }
 
     if (UseMultipleStreams()) {
-      auto event = std::make_shared<se::Event>(stream_->parent());
-      TF_RET_CHECK(event->Init()) << "Event failed to initialize";
-      device_to_device_stream->ThenRecordEvent(event.get());
-      xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event));
+      se::Event event(stream_->parent());
+      CHECK(event.Init());
+      device_to_device_stream->ThenRecordEvent(&event);
+      xla_dst->SetDefinedOn(device_to_device_stream, std::move(event));
     }
     return Status::OK();
   }();
   if (!status.ok()) {
     return done(status);
   } else {
-    stream_->ThenDoHostCallback([this, done]() {
-      // We must not call the done closure directly from DoHostCallback to avoid
-      // a deadlock. If done() is the callback that ends an Executor's run, the
-      // Executor may call XlaDevice::Sync() inside the callback. This
-      // deadlocks, because XlaDevice::Sync() waits for all stream activity to
-      // complete.
-      thread_pool_->Schedule([done]() { done(Status::OK()); });
-    });
+    stream_->ThenDoHostCallback([=]() { done(Status::OK()); });
   }
 }
 
 XlaDeviceContext::XlaDeviceContext(
-    std::shared_ptr<se::Stream> compute_stream,
-    std::shared_ptr<se::Stream> host_to_device_stream,
-    std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
+    se::Stream* compute_stream, se::Stream* host_to_device_stream,
+    se::Stream* device_to_host_stream, xla::LocalClient* client,
     bool transfer_as_literal,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    thread::ThreadPool* thread_pool)
-    : manager_(std::move(compute_stream), std::move(host_to_device_stream),
-               std::move(device_to_host_stream), client, transfer_as_literal,
-               std::move(shape_representation_fn), thread_pool) {}
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+    : manager_(compute_stream, host_to_device_stream, device_to_host_stream,
+               client, transfer_as_literal,
+               std::move(shape_representation_fn)) {}
 
 void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                              Device* device,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 2e7445340c..912f8d779e 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -47,12 +47,10 @@ class XlaDeviceAllocator : public Allocator {
 class XlaTransferManager {
  public:
   explicit XlaTransferManager(
-      std::shared_ptr<se::Stream> compute_stream,
-      std::shared_ptr<se::Stream> host_to_device_stream,
-      std::shared_ptr<se::Stream> device_to_host_stream,
-      xla::LocalClient* client, bool transfer_as_literal,
-      XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-      thread::ThreadPool* thread_pool);
+      se::Stream* compute_stream, se::Stream* host_to_device_stream,
+      se::Stream* device_to_host_stream, xla::LocalClient* client,
+      bool transfer_as_literal,
+      XlaCompiler::ShapeRepresentationFn shape_representation_fn);
 
   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                              Tensor* device_tensor,
                              StatusCallback done) const;
@@ -63,7 +61,7 @@ class XlaTransferManager {
   void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
                                 const StatusCallback& done);
 
-  se::Stream* stream() const { return stream_.get(); }
+  se::Stream* stream() const { return stream_; }
 
  private:
   Status TransferLiteralToDevice(const Tensor& host_tensor,
@@ -75,13 +73,13 @@ class XlaTransferManager {
 
   // The main compute stream of the device, used to synchronize the transfer
   // streams if they are set.
-  std::shared_ptr<se::Stream> stream_;
+  se::Stream* stream_;
   // The stream to use for transferring data from host to device. Can be
   // idential to stream_, but must not be nullptr.
-  std::shared_ptr<se::Stream> host_to_device_stream_;
+  se::Stream* host_to_device_stream_;
   // The stream to use for transferring data from device to host. Can be
   // idential to stream_, but must not be nullptr.
-  std::shared_ptr<se::Stream> device_to_host_stream_;
+  se::Stream* device_to_host_stream_;
   // For the underlying memory allocator and XLA's TransferManager.
   xla::LocalClient* client_;
   // Transfer manager, for marshalling data to and from the device.
@@ -89,9 +87,6 @@ class XlaTransferManager {
   // True if we must use XLA's TransferManager for correct device transfers.
   const bool transfer_as_literal_;
   XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
-
-  // Thread pool used for running closures
-  thread::ThreadPool* thread_pool_;
 };
 
 // DeviceContext for operators assigned to XlaDevice devices. The
@@ -100,12 +95,10 @@ class XlaTransferManager {
 class XlaDeviceContext : public DeviceContext {
  public:
   explicit XlaDeviceContext(
-      std::shared_ptr<se::Stream> compute_stream,
-      std::shared_ptr<se::Stream> host_to_device_stream,
-      std::shared_ptr<se::Stream> device_to_host_stream,
-      xla::LocalClient* client, bool transfer_as_literal,
-      XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-      thread::ThreadPool* thread_pool);
+      se::Stream* compute_stream, se::Stream* host_to_device_stream,
+      se::Stream* device_to_host_stream, xla::LocalClient* client,
+      bool transfer_as_literal,
+      XlaCompiler::ShapeRepresentationFn shape_representation_fn);
 
   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                              Tensor* device_tensor,
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 4efbb2d5d7..6134b8c694 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -15,8 +15,6 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/xla_launch_util.h"
 
-#include <memory>
-
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -184,7 +182,7 @@ void XlaComputationLaunchContext::PopulateInputs(
   }
 }
 
-Status XlaComputationLaunchContext::PopulateOutputs(
+void XlaComputationLaunchContext::PopulateOutputs(
     OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
     ScopedShapedBuffer output) {
   se::Stream* stream =
@@ -213,15 +211,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
     output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
   }
 
-  std::shared_ptr<se::Event> definition_event;
-  if (use_multiple_streams_) {
-    definition_event = std::make_shared<se::Event>(stream->parent());
-    if (!definition_event->Init()) {
-      return errors::Internal("Failed to initialize tensor definition event.");
-    }
-    stream->ThenRecordEvent(definition_event.get());
-  }
-
   // Copy XLA results to the OpOutputList.
   int output_num = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
@@ -239,13 +228,12 @@ Status XlaComputationLaunchContext::PopulateOutputs(
       // reallocate the device buffer later.
       VLOG(1) << "Constant output tensor on device";
 
-      TF_RETURN_IF_ERROR(
-          ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+      OP_REQUIRES_OK(
+          ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
 
       Device* device = dynamic_cast<Device*>(ctx->device());
-      if (device == nullptr) {
-        return errors::Internal("DeviceBase was not a Device.");
-      }
+      OP_REQUIRES(ctx, device != nullptr,
+                  errors::Internal("DeviceBase was not a Device."));
       ctx->op_device_context()->CopyCPUTensorToDevice(
           &const_tensor, device, output_tensor,
          [&](Status status) { TF_CHECK_OK(status); });
@@ -275,13 +263,16 @@ Status XlaComputationLaunchContext::PopulateOutputs(
       se::DeviceMemoryBase buffer = output.buffer({output_num});
       if (allocate_xla_tensors_) {
         Tensor* output_tensor;
-        TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
+        OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
         XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
         if (xla_tensor) {
           xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
               ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
           if (use_multiple_streams_) {
-            xla_tensor->SetDefinedOn(stream, definition_event);
+            se::Event event(stream->parent());
+            CHECK(event.Init());
+            stream->ThenRecordEvent(&event);
+            xla_tensor->SetDefinedOn(stream, std::move(event));
           }
         } else {
           // xla_tensor wasn't valid, which must mean this is a zero-element
@@ -307,39 +298,41 @@ Status XlaComputationLaunchContext::PopulateOutputs(
   for (int i = 0; i < kernel->resource_updates.size(); ++i) {
     Allocator* allocator = ctx->device()->GetAllocator({});
     const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
-    if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
-      return errors::Internal("Invalid input index for variable write.");
-    }
+    OP_REQUIRES(ctx,
+                write.input_index >= 0 && write.input_index < ctx->num_inputs(),
+                errors::Internal("Invalid input index for variable write."));
     se::DeviceMemoryBase buffer = output.buffer({output_num});
 
     Var* variable = nullptr;
     // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
     // not a Tensor.
-    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
-        ctx, HandleFromInput(ctx, write.input_index), &variable,
-        [&write](Var** ptr) {
-          *ptr = new Var(write.type);
-          return Status::OK();
-        }));
+    OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
+                            ctx, HandleFromInput(ctx, write.input_index),
+                            &variable, [this, ctx, &write](Var** ptr) {
+                              *ptr = new Var(write.type);
+                              return Status::OK();
+                            }));
 
     core::ScopedUnref s(variable);
 
     mutex_lock ml(*variable->mu());
-    if (variable->tensor()->dtype() != write.type) {
-      return errors::Internal("Mismatched type in variable write");
-    }
+    OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
+                errors::Internal("Mismatched type in variable write"));
 
     if (allocate_xla_tensors_) {
       Tensor output_tensor;
-      TF_RETURN_IF_ERROR(
-          ctx->allocate_temp(write.type, write.shape, &output_tensor));
+      OP_REQUIRES_OK(
+          ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor));
       XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
       CHECK(xla_tensor);
       xla_tensor->set_shaped_buffer(
           ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
       if (use_multiple_streams_) {
-        xla_tensor->SetDefinedOn(stream, definition_event);
+        se::Event event(stream->parent());
+        CHECK(event.Init());
+        stream->ThenRecordEvent(&event);
+        xla_tensor->SetDefinedOn(stream, std::move(event));
       }
       *variable->tensor() = output_tensor;
     } else {
@@ -350,7 +343,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
     }
     ++output_num;
   }
-  return Status::OK();
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 4232f514b3..1ea3fa4cf2 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -93,9 +93,9 @@ class XlaComputationLaunchContext {
       const std::map<int, OptionalTensor>& variables);
 
   // Given the XLA output in `output`, populate all outputs of `ctx`.
-  Status PopulateOutputs(OpKernelContext* ctx,
-                         const XlaCompiler::CompilationResult* kernel,
-                         xla::ScopedShapedBuffer output);
+  void PopulateOutputs(OpKernelContext* ctx,
+                       const XlaCompiler::CompilationResult* kernel,
+                       xla::ScopedShapedBuffer output);
 
   // Return the argument list. Only valid after PopulateInputs() has been
   // called.
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index 92ba7de1b7..d777dfa5a3 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -75,7 +75,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
 
 se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
   mutex_lock lock(mu_);
-  if (!definition_event_) {
+  if (!definition_event_.has_value()) {
     return nullptr;
   }
 
@@ -87,11 +87,10 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
     return nullptr;
   }
 
-  return definition_event_.get();
+  return &*definition_event_;
 }
 
-void XlaTensor::SetDefinedOn(se::Stream* stream,
-                             std::shared_ptr<se::Event> event) {
+void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
   mutex_lock lock(mu_);
   definition_event_ = std::move(event);
   streams_defined_on_ = {stream};
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 8d36d0fa0a..f7e401c731 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -16,8 +16,6 @@ limitations under the License.
 
 #ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
 #define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
 
-#include <memory>
-
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/core/framework/allocator.h"
@@ -96,7 +94,7 @@ class XlaTensor {
 
   // Assert that the tensor's content is defined on 'stream' by the time
   // 'event' triggers.
-  void SetDefinedOn(se::Stream* stream, std::shared_ptr<se::Event> event);
+  void SetDefinedOn(se::Stream* stream, se::Event event);
 
   // Assert that the tensor's content is defined on 'stream'. This version does
   // not provide an event, and must be called *after* SetDefinedOn(Stream,
@@ -118,7 +116,7 @@ class XlaTensor {
   // An optional event that is triggered when the tensor's content has been
   // defined. If this event is nullptr, it is assumed that the tensor's content
   // is always defined.
-  std::shared_ptr<se::Event> definition_event_;
+  gtl::optional<se::Event> definition_event_;
   // A list of all streams for which the tensor's content is defined for any
   // newly enqueued command.
   gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index db6b910b32..9b109022fb 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
   }
 
   // No "synchronize all activity" implemented for this platform at the moment.
-  bool SynchronizeAllActivity() override { return true; }
+  bool SynchronizeAllActivity() override { return false; }
   bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
     return false;
   }
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index 7ba1f18101..858396ef96 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -88,7 +88,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
                             uint64 size) override;
 
   // No "synchronize all activity" implemented for this platform at the moment.
-  bool SynchronizeAllActivity() override { return true; }
+  bool SynchronizeAllActivity() override { return false; }
   bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override;
   bool SynchronousMemSet(DeviceMemoryBase *location, int value,