diff options
Diffstat (limited to 'tensorflow/compiler/jit/xla_device.cc')
-rw-r--r-- | tensorflow/compiler/jit/xla_device.cc | 70 |
1 files changed, 55 insertions, 15 deletions
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index ed007d603e..c55eba2f79 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -130,7 +130,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, + bool transfer_as_literal, bool use_multiple_streams, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" @@ -151,22 +151,24 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), strings::StrCat("device: ", device_name, " device")); - device->reset(new XlaDevice( - options, attrs, device_ordinal, DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal, shape_representation_fn, - padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn)); + device->reset( + new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), + platform.ValueOrDie(), transfer_as_literal, + use_multiple_streams, shape_representation_fn, + padded_shape_fn ? 
padded_shape_fn : DefaultPaddedShapeFn)); return Status::OK(); } XlaDevice::Metadata::Metadata( int device_ordinal, se::Platform* platform, const DeviceType& device_type, XlaCompiler::ShapeRepresentationFn shape_representation_fn, - PaddedShapeFn padded_shape_fn) + PaddedShapeFn padded_shape_fn, bool use_multiple_streams) : device_ordinal_(device_ordinal), device_type_(device_type), platform_(platform), shape_representation_fn_(std::move(shape_representation_fn)), - padded_shape_fn_(std::move(padded_shape_fn)) {} + padded_shape_fn_(std::move(padded_shape_fn)), + use_multiple_streams_(use_multiple_streams) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -200,16 +202,18 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { XlaDevice::XlaDevice( const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal, + se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, const PaddedShapeFn& padded_shape_fn) : LocalDevice(options, attrs), xla_metadata_(device_ordinal, platform, jit_device_name, - shape_representation_fn, padded_shape_fn), + shape_representation_fn, padded_shape_fn, + use_multiple_streams), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), platform_(platform), + use_multiple_streams_(use_multiple_streams), transfer_as_literal_(transfer_as_literal), shape_representation_fn_(shape_representation_fn) { VLOG(1) << "Created XLA device " << jit_device_name; @@ -253,6 +257,30 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() { return stream_.get(); } +xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() { + if (!use_multiple_streams_) { + return GetStream(); + } + if (!device_to_host_stream_) { + xla::Backend* backend = client()->mutable_backend(); + 
TF_ASSIGN_OR_RETURN(device_to_host_stream_, + backend->BorrowStream(device_ordinal_)); + } + return device_to_host_stream_.get(); +} + +xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() { + if (!use_multiple_streams_) { + return GetStream(); + } + if (!host_to_device_stream_) { + xla::Backend* backend = client()->mutable_backend(); + TF_ASSIGN_OR_RETURN(host_to_device_stream_, + backend->BorrowStream(device_ordinal_)); + } + return host_to_device_stream_.get(); +} + Status XlaDevice::CreateAndSetGpuDeviceInfo() { if (gpu_device_info_ == nullptr) { TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); @@ -263,8 +291,9 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() { // gpu_device_info_->default_context. gpu_device_info_ = MakeUnique<GpuDeviceInfo>(); gpu_device_info_->stream = stream; - gpu_device_info_->default_context = new XlaDeviceContext( - stream, client(), transfer_as_literal_, shape_representation_fn_); + gpu_device_info_->default_context = + new XlaDeviceContext(stream, stream, stream, client(), + transfer_as_literal_, shape_representation_fn_); set_tensorflow_gpu_device_info(gpu_device_info_.get()); } @@ -276,10 +305,16 @@ Status XlaDevice::FillContextMap(const Graph* graph, VLOG(1) << "XlaDevice::FillContextMap"; device_context_map->resize(graph->num_node_ids()); TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); + TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream, + GetDeviceToHostStream()); + TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream, + GetHostToDeviceStream()); + // Call GetAllocator for the side-effect of ensuring the allocator is created. 
GetAllocator({}); - auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_, - shape_representation_fn_); + auto ctx = new XlaDeviceContext( + stream, host_to_device_stream, device_to_host_stream, client(), + transfer_as_literal_, shape_representation_fn_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -326,8 +361,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream, client(), transfer_as_literal_, - shape_representation_fn_); + TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream, + GetDeviceToHostStream()); + TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream, + GetHostToDeviceStream()); + XlaTransferManager manager(stream, host_to_device_stream, + device_to_host_stream, client(), + transfer_as_literal_, shape_representation_fn_); manager.CopyCPUTensorToDevice(&parsed, this, &copy, [&n, &status](const Status& s) { status = s; |