author    Peter Hawkins <phawkins@google.com>          2018-08-10 06:21:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-10 06:25:33 -0700
commit 56e4ea405d13125a3dcb6459019a83d12330bf84 (patch)
tree   f285895f73399e775479895af835b65d529d100f
parent 4d9266f02cc1553e4cee9def5fee79158dd76ecd (diff)
Automated rollback of commit b306f5f9458feddbdb89b7db557cb74dc9408d07
PiperOrigin-RevId: 208200028
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc        |  3
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc     |  3
-rw-r--r--  tensorflow/compiler/jit/xla_device.cc                   | 41
-rw-r--r--  tensorflow/compiler/jit/xla_device.h                    | 14
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.cc           | 89
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.h            | 31
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc              | 62
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.h               |  6
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.cc                   |  7
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.h                    |  6
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executor.h  |  2
-rw-r--r--  tensorflow/stream_executor/host/host_gpu_executor.h     |  2
12 files changed, 99 insertions(+), 167 deletions(-)
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 669beb71ca..b313d48011 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -209,8 +209,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
- OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
- ctx, kernel, run_result.ConsumeValueOrDie()));
+ launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie());
VLOG(1) << "Done";
}
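
This call site drops the OP_REQUIRES_OK wrapper because PopulateOutputs() becomes void and reports failures on the OpKernelContext itself (see the xla_launch_util.cc hunks below). For context, OP_REQUIRES_OK is the standard way a void Compute() body propagates a failing Status: it records the error on the context and returns early. A minimal sketch, assuming the usual TensorFlow headers; DoWork() is a hypothetical helper:

#include "tensorflow/core/framework/op_kernel.h"

class ExampleOp : public tensorflow::OpKernel {
 public:
  using tensorflow::OpKernel::OpKernel;

  void Compute(tensorflow::OpKernelContext* ctx) override {
    // On failure, OP_REQUIRES_OK records the Status on ctx and returns from
    // Compute() immediately; on success, execution simply continues.
    OP_REQUIRES_OK(ctx, DoWork(ctx));
  }

 private:
  // Hypothetical Status-returning helper, standing in for PopulateOutputs().
  static tensorflow::Status DoWork(tensorflow::OpKernelContext* ctx) {
    return tensorflow::Status::OK();
  }
};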
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index b508021cdf..d288d37bc7 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -77,8 +77,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
executable->Run(launch_context.arguments(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
- TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
- ctx, result, run_result.ConsumeValueOrDie()));
+ launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie());
return Status::OK();
}
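
As above, the Status-returning wrapper disappears because PopulateOutputs() now reports through the context. The surrounding run_result is an xla::StatusOr<xla::ScopedShapedBuffer>: check status() first, then move the value out. A minimal sketch of that idiom; RunExecutable() is a hypothetical stand-in for executable->Run(...):

// Hypothetical producer standing in for executable->Run(...).
xla::StatusOr<xla::ScopedShapedBuffer> RunExecutable();

Status ConsumeRunResult() {
  xla::StatusOr<xla::ScopedShapedBuffer> run_result = RunExecutable();
  TF_RETURN_IF_ERROR(run_result.status());  // propagate execution errors
  // Safe only after the status check: moves the buffer out of the StatusOr.
  xla::ScopedShapedBuffer output = run_result.ConsumeValueOrDie();
  // ... hand `output` to PopulateOutputs() in the real code ...
  return Status::OK();
}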
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 2a2691a6a4..4ddeaebd3e 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -217,8 +216,6 @@ XlaDevice::XlaDevice(
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name << " " << this;
- thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device",
- /*num_threads=*/1));
}
XlaDevice::~XlaDevice() {
@@ -265,12 +262,10 @@ Status XlaDevice::EnsureDeviceContextOk() {
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
const string& name,
- std::shared_ptr<se::Stream>* stream,
+ xla::StreamPool::Ptr* stream,
bool* stream_was_changed) {
if (!(*stream) || !(*stream)->ok()) {
- xla::StreamPool::Ptr ptr;
- TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
- *stream = std::shared_ptr<se::Stream>(std::move(ptr));
+ TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_));
VLOG(1) << "XlaDevice " << this << " new " << name << " "
<< (*stream)->DebugStreamPointers();
*stream_was_changed = true;
@@ -286,8 +281,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
&need_new_device_context));
- std::shared_ptr<se::Stream> host_to_device_stream = stream_;
- std::shared_ptr<se::Stream> device_to_host_stream = stream_;
+ se::Stream* host_to_device_stream = stream_.get();
+ se::Stream* device_to_host_stream = stream_.get();
if (use_multiple_streams_) {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
&host_to_device_stream_,
@@ -295,8 +290,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
&device_to_host_stream_,
&need_new_device_context));
- host_to_device_stream = host_to_device_stream_;
- device_to_host_stream = device_to_host_stream_;
+ host_to_device_stream = host_to_device_stream_.get();
+ device_to_host_stream = device_to_host_stream_.get();
}
if (!need_new_device_context) {
@@ -309,13 +304,9 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
if (device_context_) {
device_context_->Unref();
}
- // The XlaDeviceContext keeps a reference count to the streams, and the
- // XlaDeviceContext remains live for the duration of an Executor run. This
- // ensures that the streams remain live for the duration of a run, even if
- // an error is encountered and the streams are replaced with new ones.
device_context_ = new XlaDeviceContext(
- stream_, host_to_device_stream, device_to_host_stream, client(),
- transfer_as_literal_, shape_representation_fn_, thread_pool_.get());
+ stream_.get(), host_to_device_stream, device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_);
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
<< device_context_;
@@ -380,22 +371,6 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
op_kernel->ComputeAsync(context, done);
}
-Status XlaDevice::Sync() {
- VLOG(1) << "XlaDevice::Sync";
- std::shared_ptr<se::Stream> stream;
- {
- mutex_lock lock(mu_);
- stream = stream_;
- }
- if (!stream) return Status::OK();
-
- if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) {
- return errors::Internal("XlaDevice::Sync() failed.");
- }
- VLOG(1) << "XlaDevice::Sync completed";
- return Status::OK();
-}
-
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) {
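
The substance of this file's change: XlaDevice reverts to holding xla::StreamPool::Ptr members rather than sharing stream ownership through std::shared_ptr. StreamPool::Ptr behaves like a unique_ptr whose deleter hands the borrowed stream back to the backend's pool, so replacing a broken stream_ in EnsureStreamOkLocked() automatically recycles the old one. A sketch of that ownership model with illustrative names (SimplePool and Resource are not TF types):

#include <memory>
#include <vector>

struct Resource {};  // stands in for se::Stream

class SimplePool {
 public:
  struct Deleter {
    SimplePool* pool;
    void operator()(Resource* r) const { pool->Return(r); }
  };
  // Analogous to xla::StreamPool::Ptr: unique ownership, pool-returning deleter.
  using Ptr = std::unique_ptr<Resource, Deleter>;

  Ptr Borrow() {
    Resource* r;
    if (free_.empty()) {
      r = new Resource;  // grow the pool on demand
    } else {
      r = free_.back();
      free_.pop_back();
    }
    return Ptr(r, Deleter{this});
  }

 private:
  void Return(Resource* r) { free_.push_back(r); }  // recycle, don't destroy
  std::vector<Resource*> free_;  // cleanup at pool destruction omitted
};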
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index dbf35f349f..d8906419b0 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
@@ -123,7 +124,7 @@ class XlaDevice : public LocalDevice {
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
- Status Sync() override;
+ Status Sync() override { return Status::OK(); }
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override
@@ -152,7 +153,7 @@ class XlaDevice : public LocalDevice {
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
- std::shared_ptr<se::Stream>* stream,
+ xla::StreamPool::Ptr* stream,
bool* stream_was_changed)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
@@ -173,17 +174,17 @@ class XlaDevice : public LocalDevice {
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
- std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
+ xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
- std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
+ xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
- std::shared_ptr<se::Stream> device_to_host_stream_ GUARDED_BY(mu_);
+ xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
// Must we use XLA's transfer manager for correct host<->device transfers? If
// false, we can use ThenMemcpy() instead.
const bool transfer_as_literal_;
@@ -197,9 +198,6 @@ class XlaDevice : public LocalDevice {
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
-
- // Thread pool used for running closures
- std::unique_ptr<thread::ThreadPool> thread_pool_;
};
// Builds OpKernel registrations on 'device' for the JIT operators
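
The GUARDED_BY and EXCLUSIVE_LOCKS_REQUIRED annotations on these members are Clang thread-safety attributes: the analyzer checks that mu_ is held wherever the annotated state is read or written. An illustrative sketch using TF's mutex wrappers (Counter is not a TF class):

#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

class Counter {
 public:
  void Increment() {
    tensorflow::mutex_lock lock(mu_);
    ++value_;  // ok: mu_ is held, satisfying the GUARDED_BY contract
  }

 private:
  // Callers must already hold mu_; the analyzer enforces this at compile time.
  int ValueLocked() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { return value_; }

  mutable tensorflow::mutex mu_;
  int value_ GUARDED_BY(mu_) = 0;
};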
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 0a0c089241..0100bf51ed 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -15,9 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_context.h"
-#include <memory>
-
-#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -51,20 +48,17 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(
- std::shared_ptr<se::Stream> compute_stream,
- std::shared_ptr<se::Stream> host_to_device_stream,
- std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn,
- thread::ThreadPool* thread_pool)
- : stream_(std::move(compute_stream)),
- host_to_device_stream_(std::move(host_to_device_stream)),
- device_to_host_stream_(std::move(device_to_host_stream)),
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+ : stream_(compute_stream),
+ host_to_device_stream_(host_to_device_stream),
+ device_to_host_stream_(device_to_host_stream),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
- shape_representation_fn_(std::move(shape_representation_fn)),
- thread_pool_(thread_pool) {
+ shape_representation_fn_(std::move(shape_representation_fn)) {
CHECK(host_to_device_stream_ != nullptr);
CHECK(device_to_host_stream_ != nullptr);
CHECK(stream_ != nullptr);
@@ -94,15 +88,15 @@ Status XlaTransferManager::TransferLiteralToDevice(
if (UseMultipleStreams()) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
- host_to_device_stream_->ThenWaitFor(stream_.get());
+ host_to_device_stream_->ThenWaitFor(stream_);
}
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
- host_to_device_stream_.get(), *literal, shaped_buffer));
+ host_to_device_stream_, *literal, shaped_buffer));
if (UseMultipleStreams()) {
- auto event = std::make_shared<se::Event>(stream_->parent());
- TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
- host_to_device_stream_->ThenRecordEvent(event.get());
- xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event));
+ se::Event event(stream_->parent());
+ TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
+ host_to_device_stream_->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
}
// Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes.
@@ -122,7 +116,7 @@ void XlaTransferManager::TransferLiteralFromDevice(
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
- device_to_host_stream_.get(), shaped_buffer, literal,
+ device_to_host_stream_, shaped_buffer, literal,
[=, &shaped_buffer, &literal](xla::Status status) {
ref.Unref();
done([&]() -> Status {
@@ -185,14 +179,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
if (status.ok()) {
xla_tensor->set_host_tensor(*cpu_tensor);
- host_to_device_stream_->ThenDoHostCallback([this, done]() {
- // We must not call the done closure directly from DoHostCallback
- // to avoid a deadlock. If done() is the callback that ends an
- // Executor's run, the Executor may call XlaDevice::Sync() inside the
- // callback. This deadlocks, because XlaDevice::Sync() waits for all
- // stream activity to complete.
- thread_pool_->Schedule([done]() { done(Status::OK()); });
- });
+ host_to_device_stream_->ThenDoHostCallback(
+ [done]() { done(Status::OK()); });
return;
}
} else {
@@ -204,7 +192,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s",
- host_to_device_stream_.get(), block_status.error_message().c_str());
+ host_to_device_stream_, block_status.error_message().c_str());
}
}
xla_tensor->set_host_tensor(*cpu_tensor);
@@ -237,9 +225,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
if (se::Event* event =
- xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) {
+ xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
device_to_host_stream_->ThenWaitFor(event);
- xla_tensor->SetDefinedOn(device_to_host_stream_.get());
+ xla_tensor->SetDefinedOn(device_to_host_stream_);
}
Status status;
@@ -252,7 +240,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Status block_status = device_to_host_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_.get(),
+ "Failed to complete data transfer on stream %p: %s", stream_,
block_status.error_message().c_str());
}
}
@@ -290,14 +278,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
if (stream_ != device_to_device_stream) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
- device_to_device_stream->ThenWaitFor(stream_.get());
+ device_to_device_stream->ThenWaitFor(stream_);
}
}
if (se::Event* event =
- xla_src->GetDefinitionEvent(device_to_device_stream.get())) {
+ xla_src->GetDefinitionEvent(device_to_device_stream)) {
device_to_device_stream->ThenWaitFor(event);
- xla_src->SetDefinedOn(device_to_device_stream.get());
+ xla_src->SetDefinedOn(device_to_device_stream);
}
auto from_iter = xla_src->shaped_buffer().buffers().begin();
@@ -309,37 +297,28 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
}
if (UseMultipleStreams()) {
- auto event = std::make_shared<se::Event>(stream_->parent());
- TF_RET_CHECK(event->Init()) << "Event failed to initialize";
- device_to_device_stream->ThenRecordEvent(event.get());
- xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event));
+ se::Event event(stream_->parent());
+ CHECK(event.Init());
+ device_to_device_stream->ThenRecordEvent(&event);
+ xla_dst->SetDefinedOn(device_to_device_stream, std::move(event));
}
return Status::OK();
}();
if (!status.ok()) {
return done(status);
} else {
- stream_->ThenDoHostCallback([this, done]() {
- // We must not call the done closure directly from DoHostCallback to avoid
- // a deadlock. If done() is the callback that ends an Executor's run, the
- // Executor may call XlaDevice::Sync() inside the callback. This
- // deadlocks, because XlaDevice::Sync() waits for all stream activity to
- // complete.
- thread_pool_->Schedule([done]() { done(Status::OK()); });
- });
+ stream_->ThenDoHostCallback([=]() { done(Status::OK()); });
}
}
XlaDeviceContext::XlaDeviceContext(
- std::shared_ptr<se::Stream> compute_stream,
- std::shared_ptr<se::Stream> host_to_device_stream,
- std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn,
- thread::ThreadPool* thread_pool)
- : manager_(std::move(compute_stream), std::move(host_to_device_stream),
- std::move(device_to_host_stream), client, transfer_as_literal,
- std::move(shape_representation_fn), thread_pool) {}
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+ : manager_(compute_stream, host_to_device_stream, device_to_host_stream,
+ client, transfer_as_literal,
+ std::move(shape_representation_fn)) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
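
Note the comments deleted from this file: before this rollback, done() was deferred to a thread pool because invoking it inline from ThenDoHostCallback could deadlock against XlaDevice::Sync(), which waited on all stream activity. With Sync() now a no-op, the callbacks invoke done() directly. A minimal sketch of the idiom, assuming a live se::Stream* and a tensorflow::StatusCallback:

// The closure runs on a host thread once everything previously enqueued on
// `stream` has completed; it must not block on the stream itself.
void NotifyWhenDone(se::Stream* stream, tensorflow::StatusCallback done) {
  stream->ThenDoHostCallback([done]() { done(tensorflow::Status::OK()); });
}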
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 2e7445340c..912f8d779e 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -47,12 +47,10 @@ class XlaDeviceAllocator : public Allocator {
class XlaTransferManager {
public:
explicit XlaTransferManager(
- std::shared_ptr<se::Stream> compute_stream,
- std::shared_ptr<se::Stream> host_to_device_stream,
- std::shared_ptr<se::Stream> device_to_host_stream,
- xla::LocalClient* client, bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn,
- thread::ThreadPool* thread_pool);
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
+ bool transfer_as_literal,
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
@@ -63,7 +61,7 @@ class XlaTransferManager {
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
const StatusCallback& done);
- se::Stream* stream() const { return stream_.get(); }
+ se::Stream* stream() const { return stream_; }
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
@@ -75,13 +73,13 @@ class XlaTransferManager {
// The main compute stream of the device, used to synchronize the transfer
// streams if they are set.
- std::shared_ptr<se::Stream> stream_;
+ se::Stream* stream_;
// The stream to use for transferring data from host to device. Can be
// identical to stream_, but must not be nullptr.
- std::shared_ptr<se::Stream> host_to_device_stream_;
+ se::Stream* host_to_device_stream_;
// The stream to use for transferring data from device to host. Can be
// identical to stream_, but must not be nullptr.
- std::shared_ptr<se::Stream> device_to_host_stream_;
+ se::Stream* device_to_host_stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@@ -89,9 +87,6 @@ class XlaTransferManager {
// True if we must use XLA's TransferManager for correct device transfers.
const bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
-
- // Thread pool used for running closures
- thread::ThreadPool* thread_pool_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
@@ -100,12 +95,10 @@ class XlaTransferManager {
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(
- std::shared_ptr<se::Stream> compute_stream,
- std::shared_ptr<se::Stream> host_to_device_stream,
- std::shared_ptr<se::Stream> device_to_host_stream,
- xla::LocalClient* client, bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn,
- thread::ThreadPool* thread_pool);
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
+ bool transfer_as_literal,
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 4efbb2d5d7..6134b8c694 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -15,8 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_launch_util.h"
-#include <memory>
-
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -184,7 +182,7 @@ void XlaComputationLaunchContext::PopulateInputs(
}
}
-Status XlaComputationLaunchContext::PopulateOutputs(
+void XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output) {
se::Stream* stream =
@@ -213,15 +211,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
}
- std::shared_ptr<se::Event> definition_event;
- if (use_multiple_streams_) {
- definition_event = std::make_shared<se::Event>(stream->parent());
- if (!definition_event->Init()) {
- return errors::Internal("Failed to initialize tensor definition event.");
- }
- stream->ThenRecordEvent(definition_event.get());
- }
-
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
@@ -239,13 +228,12 @@ Status XlaComputationLaunchContext::PopulateOutputs(
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
- TF_RETURN_IF_ERROR(
- ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
- if (device == nullptr) {
- return errors::Internal("DeviceBase was not a Device.");
- }
+ OP_REQUIRES(ctx, device != nullptr,
+ errors::Internal("DeviceBase was not a Device."));
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
@@ -275,13 +263,16 @@ Status XlaComputationLaunchContext::PopulateOutputs(
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
Tensor* output_tensor;
- TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) {
- xla_tensor->SetDefinedOn(stream, definition_event);
+ se::Event event(stream->parent());
+ CHECK(event.Init());
+ stream->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(stream, std::move(event));
}
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
@@ -307,39 +298,41 @@ Status XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
- if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
- return errors::Internal("Invalid input index for variable write.");
- }
+ OP_REQUIRES(ctx,
+ write.input_index >= 0 && write.input_index < ctx->num_inputs(),
+ errors::Internal("Invalid input index for variable write."));
se::DeviceMemoryBase buffer = output.buffer({output_num});
Var* variable = nullptr;
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
- TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
- ctx, HandleFromInput(ctx, write.input_index), &variable,
- [&write](Var** ptr) {
- *ptr = new Var(write.type);
- return Status::OK();
- }));
+ OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
+ ctx, HandleFromInput(ctx, write.input_index),
+ &variable, [this, ctx, &write](Var** ptr) {
+ *ptr = new Var(write.type);
+ return Status::OK();
+ }));
core::ScopedUnref s(variable);
mutex_lock ml(*variable->mu());
- if (variable->tensor()->dtype() != write.type) {
- return errors::Internal("Mismatched type in variable write");
- }
+ OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
+ errors::Internal("Mismatched type in variable write"));
if (allocate_xla_tensors_) {
Tensor output_tensor;
- TF_RETURN_IF_ERROR(
- ctx->allocate_temp(write.type, write.shape, &output_tensor));
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
if (use_multiple_streams_) {
- xla_tensor->SetDefinedOn(stream, definition_event);
+ se::Event event(stream->parent());
+ CHECK(event.Init());
+ stream->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(stream, std::move(event));
}
*variable->tensor() = output_tensor;
} else {
@@ -350,7 +343,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
}
++output_num;
}
- return Status::OK();
}
} // namespace tensorflow
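
The resource-update hunk above converts each early return into an OP_REQUIRES* check; the underlying idiom (lookup-or-create, take the ref and the lock, validate dtype, publish) is unchanged. A hedged sketch distilled from it, where input_index, dtype, and new_value are placeholders:

void WriteVariable(OpKernelContext* ctx, int input_index, DataType dtype,
                   const Tensor& new_value) {
  Var* variable = nullptr;
  OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
                          ctx, HandleFromInput(ctx, input_index), &variable,
                          [dtype](Var** ptr) {
                            *ptr = new Var(dtype);  // create on first write
                            return Status::OK();
                          }));
  core::ScopedUnref unref(variable);  // drop the reference from the lookup
  mutex_lock lock(*variable->mu());   // serialize concurrent variable access
  OP_REQUIRES(ctx, variable->tensor()->dtype() == dtype,
              errors::Internal("Mismatched type in variable write"));
  *variable->tensor() = new_value;    // publish the updated tensor
}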
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 4232f514b3..1ea3fa4cf2 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -93,9 +93,9 @@ class XlaComputationLaunchContext {
const std::map<int, OptionalTensor>& variables);
// Given the XLA output in `output`, populate all outputs of `ctx`.
- Status PopulateOutputs(OpKernelContext* ctx,
- const XlaCompiler::CompilationResult* kernel,
- xla::ScopedShapedBuffer output);
+ void PopulateOutputs(OpKernelContext* ctx,
+ const XlaCompiler::CompilationResult* kernel,
+ xla::ScopedShapedBuffer output);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index 92ba7de1b7..d777dfa5a3 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -75,7 +75,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
mutex_lock lock(mu_);
- if (!definition_event_) {
+ if (!definition_event_.has_value()) {
return nullptr;
}
@@ -87,11 +87,10 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
return nullptr;
}
- return definition_event_.get();
+ return &*definition_event_;
}
-void XlaTensor::SetDefinedOn(se::Stream* stream,
- std::shared_ptr<se::Event> event) {
+void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
mutex_lock lock(mu_);
definition_event_ = std::move(event);
streams_defined_on_ = {stream};
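
Replacing std::shared_ptr<se::Event> with gtl::optional<se::Event> gives the tensor sole, by-value ownership of its definition event; callers now move the event in instead of sharing it. The same pattern with standard types (Event and TensorLike are illustrative, std::optional plays the role of gtl::optional):

#include <optional>
#include <utility>

struct Event {};  // stands in for se::Event

class TensorLike {
 public:
  void SetDefinedOn(Event event) {         // takes the event by value
    definition_event_ = std::move(event);  // the tensor now owns it outright
  }
  Event* GetDefinitionEvent() {
    // Expose a raw pointer only while the optional is engaged.
    return definition_event_.has_value() ? &*definition_event_ : nullptr;
  }

 private:
  std::optional<Event> definition_event_;
};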
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 8d36d0fa0a..f7e401c731 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -16,8 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
-#include <memory>
-
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/allocator.h"
@@ -96,7 +94,7 @@ class XlaTensor {
// Assert that the tensor's content is defined on 'stream' by the time 'event'
// triggers.
- void SetDefinedOn(se::Stream* stream, std::shared_ptr<se::Event> event);
+ void SetDefinedOn(se::Stream* stream, se::Event event);
// Assert that the tensor's content is defined on 'stream'. This version does
// not provide an event, and must be called *after* SetDefinedOn(Stream,
@@ -118,7 +116,7 @@ class XlaTensor {
// An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined.
- std::shared_ptr<se::Event> definition_event_;
+ gtl::optional<se::Event> definition_event_;
// A list of all streams for which the tensor's content is defined for any
// newly enqueued command.
gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index db6b910b32..9b109022fb 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
}
// No "synchronize all activity" implemented for this platform at the moment.
- bool SynchronizeAllActivity() override { return true; }
+ bool SynchronizeAllActivity() override { return false; }
bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
return false;
}
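
This flip, mirrored in host_gpu_executor.h below, changes the stub's contract: a backend that cannot actually quiesce all in-flight work now reports failure instead of claiming success, so a caller that depends on full synchronization (as the deleted XlaDevice::Sync() did) fails loudly rather than proceeding on a false positive. The shape of the override:

// Returning false signals that the device cannot be drained here; callers
// must treat this as an error, not as proof that all work completed.
bool SynchronizeAllActivity() override { return false; }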
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index 7ba1f18101..858396ef96 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -88,7 +88,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
uint64 size) override;
// No "synchronize all activity" implemented for this platform at the moment.
- bool SynchronizeAllActivity() override { return true; }
+ bool SynchronizeAllActivity() override { return false; }
bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override;
bool SynchronousMemSet(DeviceMemoryBase *location, int value,