aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_device.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_device.cc')
-rw-r--r--tensorflow/compiler/jit/xla_device.cc41
1 files changed, 33 insertions, 8 deletions
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 4ddeaebd3e..2a2691a6a4 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -216,6 +217,8 @@ XlaDevice::XlaDevice(
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name << " " << this;
+ thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device",
+ /*num_threads=*/1));
}
XlaDevice::~XlaDevice() {
@@ -262,10 +265,12 @@ Status XlaDevice::EnsureDeviceContextOk() {
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
const string& name,
- xla::StreamPool::Ptr* stream,
+ std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed) {
if (!(*stream) || !(*stream)->ok()) {
- TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_));
+ xla::StreamPool::Ptr ptr;
+ TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
+ *stream = std::shared_ptr<se::Stream>(std::move(ptr));
VLOG(1) << "XlaDevice " << this << " new " << name << " "
<< (*stream)->DebugStreamPointers();
*stream_was_changed = true;
@@ -281,8 +286,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
&need_new_device_context));
- se::Stream* host_to_device_stream = stream_.get();
- se::Stream* device_to_host_stream = stream_.get();
+ std::shared_ptr<se::Stream> host_to_device_stream = stream_;
+ std::shared_ptr<se::Stream> device_to_host_stream = stream_;
if (use_multiple_streams_) {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
&host_to_device_stream_,
@@ -290,8 +295,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
&device_to_host_stream_,
&need_new_device_context));
- host_to_device_stream = host_to_device_stream_.get();
- device_to_host_stream = device_to_host_stream_.get();
+ host_to_device_stream = host_to_device_stream_;
+ device_to_host_stream = device_to_host_stream_;
}
if (!need_new_device_context) {
@@ -304,9 +309,13 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
if (device_context_) {
device_context_->Unref();
}
+ // The XlaDeviceContext keeps a reference count to the streams, and the
+ // XlaDeviceContext remains live for the duration of an Executor run. This
+ // ensures that the streams remain live for the duration of a run, even if
+ // an error is encountered and the streams are replaced with new ones.
device_context_ = new XlaDeviceContext(
- stream_.get(), host_to_device_stream, device_to_host_stream, client(),
- transfer_as_literal_, shape_representation_fn_);
+ stream_, host_to_device_stream, device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_, thread_pool_.get());
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
<< device_context_;
@@ -371,6 +380,22 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
op_kernel->ComputeAsync(context, done);
}
+Status XlaDevice::Sync() {
+ VLOG(1) << "XlaDevice::Sync";
+ std::shared_ptr<se::Stream> stream;
+ {
+ mutex_lock lock(mu_);
+ stream = stream_;
+ }
+ if (!stream) return Status::OK();
+
+ if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) {
+ return errors::Internal("XlaDevice::Sync() failed.");
+ }
+ VLOG(1) << "XlaDevice::Sync completed";
+ return Status::OK();
+}
+
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) {