author     Peter Hawkins <phawkins@google.com>    2018-08-09 16:12:33 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-08-09 16:17:24 -0700
commit     b306f5f9458feddbdb89b7db557cb74dc9408d07
tree       c816564723de457711e9bbabe98617d3f2e121ad
parent     243f6e636c93d63884c574ae9b61d397726189ed
[TF:XLA] Add a real implementation of XlaDevice::Sync() so Session::Run() will correctly wait for all computations to complete on an XLA device before termination.
[TF:XLA] Change the XlaTensor definition event to be a shared pointer to a stream_executor::Event. This allows many tensors to share the same definition event.
PiperOrigin-RevId: 208128264
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc | 3
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc | 3
-rw-r--r--  tensorflow/compiler/jit/xla_device.cc | 41
-rw-r--r--  tensorflow/compiler/jit/xla_device.h | 14
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.cc | 89
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.h | 31
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc | 62
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.h | 6
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.cc | 7
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executor.h | 2
-rw-r--r--  tensorflow/stream_executor/host/host_gpu_executor.h | 2
12 files changed, 167 insertions, 99 deletions
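As the commit message notes, the definition event changes from a value owned by a single XlaTensor (previously gtl::optional<se::Event>) to a std::shared_ptr<se::Event>. The following is a minimal standalone sketch, using a stand-in Event type rather than the real StreamExecutor API, of what shared ownership buys: one event can be recorded once and kept alive by every tensor that refers to it.

#include <memory>
#include <vector>

struct FakeEvent {};  // stand-in for se::Event; illustration only

struct FakeXlaTensor {
  // Before: each tensor owned its own event (gtl::optional<se::Event>).
  // After: many tensors can share one event (std::shared_ptr<se::Event>).
  std::shared_ptr<FakeEvent> definition_event;
};

int main() {
  auto event = std::make_shared<FakeEvent>();  // recorded once on the compute stream
  std::vector<FakeXlaTensor> outputs(4);
  for (FakeXlaTensor& tensor : outputs) {
    tensor.definition_event = event;  // shared, not re-created per output
  }
  // The event is destroyed only after the last tensor referencing it goes away.
  return 0;
}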
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index b313d48011..669beb71ca 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -209,7 +209,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
- launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie());
+ OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
+ ctx, kernel, run_result.ConsumeValueOrDie()));
VLOG(1) << "Done";
}
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index d288d37bc7..b508021cdf 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -77,7 +77,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
executable->Run(launch_context.arguments(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
- launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie());
+ TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
+ ctx, result, run_result.ConsumeValueOrDie()));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 4ddeaebd3e..2a2691a6a4 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -216,6 +217,8 @@ XlaDevice::XlaDevice(
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name << " " << this;
+ thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device",
+ /*num_threads=*/1));
}
XlaDevice::~XlaDevice() {
@@ -262,10 +265,12 @@ Status XlaDevice::EnsureDeviceContextOk() {
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
const string& name,
- xla::StreamPool::Ptr* stream,
+ std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed) {
if (!(*stream) || !(*stream)->ok()) {
- TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_));
+ xla::StreamPool::Ptr ptr;
+ TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
+ *stream = std::shared_ptr<se::Stream>(std::move(ptr));
VLOG(1) << "XlaDevice " << this << " new " << name << " "
<< (*stream)->DebugStreamPointers();
*stream_was_changed = true;
@@ -281,8 +286,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
&need_new_device_context));
- se::Stream* host_to_device_stream = stream_.get();
- se::Stream* device_to_host_stream = stream_.get();
+ std::shared_ptr<se::Stream> host_to_device_stream = stream_;
+ std::shared_ptr<se::Stream> device_to_host_stream = stream_;
if (use_multiple_streams_) {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
&host_to_device_stream_,
@@ -290,8 +295,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
&device_to_host_stream_,
&need_new_device_context));
- host_to_device_stream = host_to_device_stream_.get();
- device_to_host_stream = device_to_host_stream_.get();
+ host_to_device_stream = host_to_device_stream_;
+ device_to_host_stream = device_to_host_stream_;
}
if (!need_new_device_context) {
@@ -304,9 +309,13 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
if (device_context_) {
device_context_->Unref();
}
+ // The XlaDeviceContext keeps a reference count to the streams, and the
+ // XlaDeviceContext remains live for the duration of a Executor run. This
+ // ensures that the streams remain live for the duration of a run, even if
+ // an error is encountered and the streams are replaced with new ones.
device_context_ = new XlaDeviceContext(
- stream_.get(), host_to_device_stream, device_to_host_stream, client(),
- transfer_as_literal_, shape_representation_fn_);
+ stream_, host_to_device_stream, device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_, thread_pool_.get());
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
<< device_context_;
@@ -371,6 +380,22 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
op_kernel->ComputeAsync(context, done);
}
+Status XlaDevice::Sync() {
+ VLOG(1) << "XlaDevice::Sync";
+ std::shared_ptr<se::Stream> stream;
+ {
+ mutex_lock lock(mu_);
+ stream = stream_;
+ }
+ if (!stream) return Status::OK();
+
+ if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) {
+ return errors::Internal("XlaDevice::Sync() failed.");
+ }
+ VLOG(1) << "XlaDevice::Sync completed";
+ return Status::OK();
+}
+
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) {
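The new XlaDevice::Sync() above snapshots the shared_ptr to the compute stream while holding mu_ and only blocks after releasing the lock. Below is a standalone sketch of that snapshot-then-wait pattern with stand-in stream and executor types (not the real se::Stream API); the shared_ptr keeps the snapshotted stream alive even if an error path replaces stream_ while Sync() is waiting.

#include <memory>
#include <mutex>

// Stand-ins; illustration only.
struct FakeExecutor {
  bool SynchronizeAllActivity() { return true; }  // block until queued work is done
};
struct FakeStream {
  FakeExecutor executor;
  FakeExecutor* parent() { return &executor; }
  bool ok() const { return true; }
};

class FakeDevice {
 public:
  bool Sync() {
    std::shared_ptr<FakeStream> stream;
    {
      // Snapshot the current stream under the lock, but do not block while
      // holding it: another thread may replace stream_ after an error.
      std::lock_guard<std::mutex> lock(mu_);
      stream = stream_;
    }
    if (!stream) return true;  // nothing has been enqueued on this device yet
    return stream->parent()->SynchronizeAllActivity() && stream->ok();
  }

 private:
  std::mutex mu_;
  std::shared_ptr<FakeStream> stream_;
};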
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index d8906419b0..dbf35f349f 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
@@ -124,7 +123,7 @@ class XlaDevice : public LocalDevice {
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
- Status Sync() override { return Status::OK(); }
+ Status Sync() override;
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override
@@ -153,7 +152,7 @@ class XlaDevice : public LocalDevice {
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
- xla::StreamPool::Ptr* stream,
+ std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
@@ -174,17 +173,17 @@ class XlaDevice : public LocalDevice {
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
- xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
- xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
- xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> device_to_host_stream_ GUARDED_BY(mu_);
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
const bool transfer_as_literal_;
@@ -198,6 +197,9 @@ class XlaDevice : public LocalDevice {
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
+
+ // Thread pool used for running closures
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
};
// Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 0100bf51ed..0a0c089241 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_context.h"
+#include <memory>
+
+#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -48,17 +51,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : stream_(compute_stream),
- host_to_device_stream_(host_to_device_stream),
- device_to_host_stream_(device_to_host_stream),
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool)
+ : stream_(std::move(compute_stream)),
+ host_to_device_stream_(std::move(host_to_device_stream)),
+ device_to_host_stream_(std::move(device_to_host_stream)),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
- shape_representation_fn_(std::move(shape_representation_fn)) {
+ shape_representation_fn_(std::move(shape_representation_fn)),
+ thread_pool_(thread_pool) {
CHECK(host_to_device_stream_ != nullptr);
CHECK(device_to_host_stream_ != nullptr);
CHECK(stream_ != nullptr);
@@ -88,15 +94,15 @@ Status XlaTransferManager::TransferLiteralToDevice(
if (UseMultipleStreams()) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
- host_to_device_stream_->ThenWaitFor(stream_);
+ host_to_device_stream_->ThenWaitFor(stream_.get());
}
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
- host_to_device_stream_, *literal, shaped_buffer));
+ host_to_device_stream_.get(), *literal, shaped_buffer));
if (UseMultipleStreams()) {
- se::Event event(stream_->parent());
- TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
- host_to_device_stream_->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
+ auto event = std::make_shared<se::Event>(stream_->parent());
+ TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
+ host_to_device_stream_->ThenRecordEvent(event.get());
+ xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event));
}
// Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes.
@@ -116,7 +122,7 @@ void XlaTransferManager::TransferLiteralFromDevice(
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
- device_to_host_stream_, shaped_buffer, literal,
+ device_to_host_stream_.get(), shaped_buffer, literal,
[=, &shaped_buffer, &literal](xla::Status status) {
ref.Unref();
done([&]() -> Status {
@@ -179,8 +185,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
if (status.ok()) {
xla_tensor->set_host_tensor(*cpu_tensor);
- host_to_device_stream_->ThenDoHostCallback(
- [done]() { done(Status::OK()); });
+ host_to_device_stream_->ThenDoHostCallback([this, done]() {
+ // We must not call the done closure directly from DoHostCallback
+ // to avoid a deadlock. If done() is the callback that ends an
+ // Executor's run, the Executor may call XlaDevice::Sync() inside the
+ // callback. This deadlocks, because XlaDevice::Sync() waits for all
+ // stream activity to complete.
+ thread_pool_->Schedule([done]() { done(Status::OK()); });
+ });
return;
}
} else {
@@ -192,7 +204,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s",
- host_to_device_stream_, block_status.error_message().c_str());
+ host_to_device_stream_.get(), block_status.error_message().c_str());
}
}
xla_tensor->set_host_tensor(*cpu_tensor);
@@ -225,9 +237,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
if (se::Event* event =
- xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
+ xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) {
device_to_host_stream_->ThenWaitFor(event);
- xla_tensor->SetDefinedOn(device_to_host_stream_);
+ xla_tensor->SetDefinedOn(device_to_host_stream_.get());
}
Status status;
@@ -240,7 +252,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Status block_status = device_to_host_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_,
+ "Failed to complete data transfer on stream %p: %s", stream_.get(),
block_status.error_message().c_str());
}
}
@@ -278,14 +290,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
if (stream_ != device_to_device_stream) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
- device_to_device_stream->ThenWaitFor(stream_);
+ device_to_device_stream->ThenWaitFor(stream_.get());
}
}
if (se::Event* event =
- xla_src->GetDefinitionEvent(device_to_device_stream)) {
+ xla_src->GetDefinitionEvent(device_to_device_stream.get())) {
device_to_device_stream->ThenWaitFor(event);
- xla_src->SetDefinedOn(device_to_device_stream);
+ xla_src->SetDefinedOn(device_to_device_stream.get());
}
auto from_iter = xla_src->shaped_buffer().buffers().begin();
@@ -297,28 +309,37 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
}
if (UseMultipleStreams()) {
- se::Event event(stream_->parent());
- CHECK(event.Init());
- device_to_device_stream->ThenRecordEvent(&event);
- xla_dst->SetDefinedOn(device_to_device_stream, std::move(event));
+ auto event = std::make_shared<se::Event>(stream_->parent());
+ TF_RET_CHECK(event->Init()) << "Event failed to initialize";
+ device_to_device_stream->ThenRecordEvent(event.get());
+ xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event));
}
return Status::OK();
}();
if (!status.ok()) {
return done(status);
} else {
- stream_->ThenDoHostCallback([=]() { done(Status::OK()); });
+ stream_->ThenDoHostCallback([this, done]() {
+ // We must not call the done closure directly from DoHostCallback to avoid
+ // a deadlock. If done() is the callback that ends an Executor's run, the
+ // Executor may call XlaDevice::Sync() inside the callback. This
+ // deadlocks, because XlaDevice::Sync() waits for all stream activity to
+ // complete.
+ thread_pool_->Schedule([done]() { done(Status::OK()); });
+ });
}
}
XlaDeviceContext::XlaDeviceContext(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : manager_(compute_stream, host_to_device_stream, device_to_host_stream,
- client, transfer_as_literal,
- std::move(shape_representation_fn)) {}
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool)
+ : manager_(std::move(compute_stream), std::move(host_to_device_stream),
+ std::move(device_to_host_stream), client, transfer_as_literal,
+ std::move(shape_representation_fn), thread_pool) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
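The comment added to CopyCPUTensorToDevice and CopyDeviceTensorToDevice explains why done() is no longer invoked directly from the stream's host callback. Below is a standalone sketch of that hand-off with stand-in stream and thread-pool types; the real code uses se::Stream::ThenDoHostCallback and the single-threaded thread::ThreadPool created in the XlaDevice constructor.

#include <functional>
#include <thread>

// Stand-ins; illustration only.
struct FakeStream {
  // Models ThenDoHostCallback: the stream counts as busy until cb returns.
  void ThenDoHostCallback(std::function<void()> cb) { cb(); }
};
struct FakeThreadPool {
  // Models thread::ThreadPool::Schedule: run the closure on another thread.
  void Schedule(std::function<void()> work) {
    std::thread(std::move(work)).detach();
  }
};

void NotifyTransferDone(FakeStream* stream, FakeThreadPool* pool,
                        std::function<void()> done) {
  stream->ThenDoHostCallback([pool, done]() {
    // Do NOT call done() here: if done() ends an Executor run, it may call
    // XlaDevice::Sync(), which waits for all stream activity -- including
    // this very host callback -- and the two would deadlock.
    pool->Schedule(done);
  });
}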
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 912f8d779e..2e7445340c 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -47,10 +47,12 @@ class XlaDeviceAllocator : public Allocator {
class XlaTransferManager {
public:
explicit XlaTransferManager(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
- bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream,
+ xla::LocalClient* client, bool transfer_as_literal,
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
@@ -61,7 +63,7 @@ class XlaTransferManager {
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
const StatusCallback& done);
- se::Stream* stream() const { return stream_; }
+ se::Stream* stream() const { return stream_.get(); }
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
@@ -73,13 +75,13 @@ class XlaTransferManager {
// The main compute stream of the device, used to synchronize the transfer
// streams if they are set.
- se::Stream* stream_;
+ std::shared_ptr<se::Stream> stream_;
// The stream to use for transferring data from host to device. Can be
// idential to stream_, but must not be nullptr.
- se::Stream* host_to_device_stream_;
+ std::shared_ptr<se::Stream> host_to_device_stream_;
// The stream to use for transferring data from device to host. Can be
// idential to stream_, but must not be nullptr.
- se::Stream* device_to_host_stream_;
+ std::shared_ptr<se::Stream> device_to_host_stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@@ -87,6 +89,9 @@ class XlaTransferManager {
// True if we must use XLA's TransferManager for correct device transfers.
const bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+
+ // Thread pool used for running closures
+ thread::ThreadPool* thread_pool_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
@@ -95,10 +100,12 @@ class XlaTransferManager {
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
- bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream,
+ xla::LocalClient* client, bool transfer_as_literal,
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 6134b8c694..4efbb2d5d7 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include <memory>
+
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -182,7 +184,7 @@ void XlaComputationLaunchContext::PopulateInputs(
}
}
-void XlaComputationLaunchContext::PopulateOutputs(
+Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output) {
se::Stream* stream =
@@ -211,6 +213,15 @@ void XlaComputationLaunchContext::PopulateOutputs(
output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
}
+ std::shared_ptr<se::Event> definition_event;
+ if (use_multiple_streams_) {
+ definition_event = std::make_shared<se::Event>(stream->parent());
+ if (!definition_event->Init()) {
+ return errors::Internal("Failed to initialize tensor definition event.");
+ }
+ stream->ThenRecordEvent(definition_event.get());
+ }
+
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
@@ -228,12 +239,13 @@ void XlaComputationLaunchContext::PopulateOutputs(
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
- OP_REQUIRES_OK(
- ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
- OP_REQUIRES(ctx, device != nullptr,
- errors::Internal("DeviceBase was not a Device."));
+ if (device == nullptr) {
+ return errors::Internal("DeviceBase was not a Device.");
+ }
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
@@ -263,16 +275,13 @@ void XlaComputationLaunchContext::PopulateOutputs(
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
Tensor* output_tensor;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
+ TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) {
- se::Event event(stream->parent());
- CHECK(event.Init());
- stream->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(stream, std::move(event));
+ xla_tensor->SetDefinedOn(stream, definition_event);
}
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
@@ -298,41 +307,39 @@ void XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
- OP_REQUIRES(ctx,
- write.input_index >= 0 && write.input_index < ctx->num_inputs(),
- errors::Internal("Invalid input index for variable write."));
+ if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
+ return errors::Internal("Invalid input index for variable write.");
+ }
se::DeviceMemoryBase buffer = output.buffer({output_num});
Var* variable = nullptr;
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
- OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
- ctx, HandleFromInput(ctx, write.input_index),
- &variable, [this, ctx, &write](Var** ptr) {
- *ptr = new Var(write.type);
- return Status::OK();
- }));
+ TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
+ ctx, HandleFromInput(ctx, write.input_index), &variable,
+ [&write](Var** ptr) {
+ *ptr = new Var(write.type);
+ return Status::OK();
+ }));
core::ScopedUnref s(variable);
mutex_lock ml(*variable->mu());
- OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
- errors::Internal("Mismatched type in variable write"));
+ if (variable->tensor()->dtype() != write.type) {
+ return errors::Internal("Mismatched type in variable write");
+ }
if (allocate_xla_tensors_) {
Tensor output_tensor;
- OP_REQUIRES_OK(
- ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor));
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_temp(write.type, write.shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
if (use_multiple_streams_) {
- se::Event event(stream->parent());
- CHECK(event.Init());
- stream->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(stream, std::move(event));
+ xla_tensor->SetDefinedOn(stream, definition_event);
}
*variable->tensor() = output_tensor;
} else {
@@ -343,6 +350,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
}
++output_num;
}
+ return Status::OK();
}
} // namespace tensorflow
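Two things change in PopulateOutputs above: it now reports failures through a returned Status (wrapped with OP_REQUIRES_OK or TF_RETURN_IF_ERROR at the call sites), and when multiple streams are in use it creates and records a single definition event on the compute stream and hands the same shared_ptr to every output and updated variable. A condensed standalone sketch of that flow, with stand-in Event, Stream, and Status types rather than TensorFlow's:

#include <memory>
#include <string>
#include <vector>

// Stand-ins; illustration only.
struct Status {
  static Status OK() { return {}; }
  bool ok() const { return message.empty(); }
  std::string message;
};
struct FakeEvent {
  bool Init() { return true; }
};
struct FakeStream {
  void ThenRecordEvent(FakeEvent*) {}
};
struct FakeTensor {
  std::shared_ptr<FakeEvent> definition_event;
};

Status PopulateOutputsSketch(FakeStream* stream, bool use_multiple_streams,
                             std::vector<FakeTensor>& outputs) {
  std::shared_ptr<FakeEvent> definition_event;
  if (use_multiple_streams) {
    definition_event = std::make_shared<FakeEvent>();
    if (!definition_event->Init()) {
      // Errors now propagate to the caller instead of aborting via OP_REQUIRES.
      return Status{"Failed to initialize tensor definition event."};
    }
    stream->ThenRecordEvent(definition_event.get());  // recorded exactly once
  }
  for (FakeTensor& out : outputs) {
    // Every output (and variable write) shares the one event, instead of
    // creating and recording a fresh event per output as before.
    out.definition_event = definition_event;
  }
  return Status::OK();
}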
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 1ea3fa4cf2..4232f514b3 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -93,9 +93,9 @@ class XlaComputationLaunchContext {
const std::map<int, OptionalTensor>& variables);
// Given the XLA output in `output`, populate all outputs of `ctx`.
- void PopulateOutputs(OpKernelContext* ctx,
- const XlaCompiler::CompilationResult* kernel,
- xla::ScopedShapedBuffer output);
+ Status PopulateOutputs(OpKernelContext* ctx,
+ const XlaCompiler::CompilationResult* kernel,
+ xla::ScopedShapedBuffer output);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index d777dfa5a3..92ba7de1b7 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -75,7 +75,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
mutex_lock lock(mu_);
- if (!definition_event_.has_value()) {
+ if (!definition_event_) {
return nullptr;
}
@@ -87,10 +87,11 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
return nullptr;
}
- return &*definition_event_;
+ return definition_event_.get();
}
-void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
+void XlaTensor::SetDefinedOn(se::Stream* stream,
+ std::shared_ptr<se::Event> event) {
mutex_lock lock(mu_);
definition_event_ = std::move(event);
streams_defined_on_ = {stream};
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index f7e401c731..8d36d0fa0a 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
+#include <memory>
+
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/allocator.h"
@@ -94,7 +96,7 @@ class XlaTensor {
// Assert that the tensor's content is defined on 'stream' by the time 'event'
// triggers.
- void SetDefinedOn(se::Stream* stream, se::Event event);
+ void SetDefinedOn(se::Stream* stream, std::shared_ptr<se::Event> event);
// Assert that the tensor's content is defined on 'stream'. This version does
// not provide an event, and must be called *after* SetDefinedOn(Stream,
@@ -116,7 +118,7 @@ class XlaTensor {
// An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined.
- gtl::optional<se::Event> definition_event_;
+ std::shared_ptr<se::Event> definition_event_;
// A list of all streams for which the tensor's content is defined for any
// newly enqueued command.
gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
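With the shared event in place, XlaTensor's API above is used in a wait-then-mark pattern: a consumer stream asks for the definition event, enqueues a wait on it, and then records that the tensor is defined on that stream so later uses skip the wait. A standalone sketch with stand-in types follows; the real calls are XlaTensor::GetDefinitionEvent, se::Stream::ThenWaitFor, and XlaTensor::SetDefinedOn (locking omitted here for brevity).

#include <algorithm>
#include <memory>
#include <vector>

// Stand-ins; illustration only.
struct FakeEvent {};
struct FakeStream {
  void ThenWaitFor(FakeEvent*) {}  // enqueue a wait on the event
};

class FakeXlaTensor {
 public:
  // Returns the definition event unless `stream` has already waited on it.
  FakeEvent* GetDefinitionEvent(FakeStream* stream) {
    if (!definition_event_) return nullptr;
    bool already_defined =
        std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
                  stream) != streams_defined_on_.end();
    return already_defined ? nullptr : definition_event_.get();
  }
  void SetDefinedOn(FakeStream* stream) { streams_defined_on_.push_back(stream); }
  void SetDefinedOn(FakeStream* stream, std::shared_ptr<FakeEvent> event) {
    definition_event_ = std::move(event);
    streams_defined_on_ = {stream};
  }

 private:
  std::shared_ptr<FakeEvent> definition_event_;
  std::vector<FakeStream*> streams_defined_on_;
};

// Consumer side, as in CopyDeviceTensorToCPU: wait at most once per stream.
void WaitForDefinition(FakeXlaTensor* tensor, FakeStream* stream) {
  if (FakeEvent* event = tensor->GetDefinitionEvent(stream)) {
    stream->ThenWaitFor(event);
    tensor->SetDefinedOn(stream);
  }
}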
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index 9b109022fb..db6b910b32 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
}
// No "synchronize all activity" implemented for this platform at the moment.
- bool SynchronizeAllActivity() override { return false; }
+ bool SynchronizeAllActivity() override { return true; }
bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
return false;
}
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index 858396ef96..7ba1f18101 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -88,7 +88,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
uint64 size) override;
// No "synchronize all activity" implemented for this platform at the moment.
- bool SynchronizeAllActivity() override { return false; }
+ bool SynchronizeAllActivity() override { return true; }
bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override;
bool SynchronousMemSet(DeviceMemoryBase *location, int value,