aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/kernels/xla_launch_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/kernels/xla_launch_op.cc')
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc11
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 251a07304e..c5d0e4f8fb 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -51,7 +51,11 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id_ = se::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
- platform_id_ = se::cuda::kCudaPlatformId;
+ platform_id_ = ctx->device()
+ ->tensorflow_gpu_device_info()
+ ->stream->parent()
+ ->platform()
+ ->id();
} else {
platform_id_ = nullptr;
}
@@ -115,6 +119,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
const XlaDevice::Metadata* metadata = nullptr;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
bool allocate_xla_tensors = s.ok();
+ bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
// Get the platform_id_ for XLA_* devices.
if (platform_id_ == nullptr) {
@@ -180,8 +185,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
- XlaComputationLaunchContext launch_context(client, xla_allocator,
- allocate_xla_tensors);
+ XlaComputationLaunchContext launch_context(
+ client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.