aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_device.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_device.cc')
-rw-r--r--tensorflow/compiler/jit/xla_device.cc70
1 files changed, 55 insertions, 15 deletions
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index ed007d603e..c55eba2f79 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -130,7 +130,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const string& jit_device_name, const SessionOptions& options,
const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
- bool transfer_as_literal,
+ bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
@@ -151,22 +151,24 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
strings::StrCat("device: ", device_name, " device"));
- device->reset(new XlaDevice(
- options, attrs, device_ordinal, DeviceType(jit_device_name),
- platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
- padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
+ device->reset(
+ new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
+ platform.ValueOrDie(), transfer_as_literal,
+ use_multiple_streams, shape_representation_fn,
+ padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
return Status::OK();
}
XlaDevice::Metadata::Metadata(
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
- PaddedShapeFn padded_shape_fn)
+ PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform),
shape_representation_fn_(std::move(shape_representation_fn)),
- padded_shape_fn_(std::move(padded_shape_fn)) {}
+ padded_shape_fn_(std::move(padded_shape_fn)),
+ use_multiple_streams_(use_multiple_streams) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
@@ -200,16 +202,18 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
- se::Platform* platform, bool transfer_as_literal,
+ se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn)
: LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name,
- shape_representation_fn, padded_shape_fn),
+ shape_representation_fn, padded_shape_fn,
+ use_multiple_streams),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform),
+ use_multiple_streams_(use_multiple_streams),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name;
@@ -253,6 +257,30 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
return stream_.get();
}
+xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() {
+ if (!use_multiple_streams_) {
+ return GetStream();
+ }
+ if (!device_to_host_stream_) {
+ xla::Backend* backend = client()->mutable_backend();
+ TF_ASSIGN_OR_RETURN(device_to_host_stream_,
+ backend->BorrowStream(device_ordinal_));
+ }
+ return device_to_host_stream_.get();
+}
+
+xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
+ if (!use_multiple_streams_) {
+ return GetStream();
+ }
+ if (!host_to_device_stream_) {
+ xla::Backend* backend = client()->mutable_backend();
+ TF_ASSIGN_OR_RETURN(host_to_device_stream_,
+ backend->BorrowStream(device_ordinal_));
+ }
+ return host_to_device_stream_.get();
+}
+
Status XlaDevice::CreateAndSetGpuDeviceInfo() {
if (gpu_device_info_ == nullptr) {
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
@@ -263,8 +291,9 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
// gpu_device_info_->default_context.
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
gpu_device_info_->stream = stream;
- gpu_device_info_->default_context = new XlaDeviceContext(
- stream, client(), transfer_as_literal_, shape_representation_fn_);
+ gpu_device_info_->default_context =
+ new XlaDeviceContext(stream, stream, stream, client(),
+ transfer_as_literal_, shape_representation_fn_);
set_tensorflow_gpu_device_info(gpu_device_info_.get());
}
@@ -276,10 +305,16 @@ Status XlaDevice::FillContextMap(const Graph* graph,
VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids());
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
+ TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+ GetDeviceToHostStream());
+ TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+ GetHostToDeviceStream());
+
// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocator({});
- auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
- shape_representation_fn_);
+ auto ctx = new XlaDeviceContext(
+ stream, host_to_device_stream, device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_);
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
@@ -326,8 +361,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
- XlaTransferManager manager(stream, client(), transfer_as_literal_,
- shape_representation_fn_);
+ TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+ GetDeviceToHostStream());
+ TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+ GetHostToDeviceStream());
+ XlaTransferManager manager(stream, host_to_device_stream,
+ device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;