aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_launch_util.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-08-27 15:29:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 15:34:18 -0700
commitfc492c08d64f05b1d68beedf11a3b55bf8066a8b (patch)
treee893aaf05e305f5389141735d132e6d50bd9e5ea /tensorflow/compiler/jit/xla_launch_util.cc
parentdf6c8721f8be706291a8151d725de4435942f7e2 (diff)
Support returning resource handles from function in XLA
There are a couple of reasons to do this: - resource handle are regular tensors part of a public API that can potentially be returned from a function. - When tfe.defun is executed under GradientTape, it generates a function returning resource handles in certain cases. This CL adds support for returning resource handles from an XLA compiled function. These resource handles must have been passed as arguments to the function. In other words, we don't yet support returning resources created inside the function. tfe.defun never makes functions that create resources. PiperOrigin-RevId: 210442856
Diffstat (limited to 'tensorflow/compiler/jit/xla_launch_util.cc')
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc47
1 files changed, 26 insertions, 21 deletions
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 2ffce9298d..affeab4a8c 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -271,31 +271,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
}
} else {
const TensorShape& shape = kernel->outputs[i].shape;
- VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
-
- se::DeviceMemoryBase buffer = output.buffer({output_num});
- if (allocate_xla_tensors_) {
- Tensor* output_tensor;
- TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
- XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
- if (xla_tensor) {
- xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
- ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
- if (use_multiple_streams_) {
- xla_tensor->SetDefinedOn(stream, definition_event);
+ const DataType& type = kernel->outputs[i].type;
+ VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
+ << DataTypeString(type);
+ if (type == DT_RESOURCE) {
+ ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
+ } else {
+ se::DeviceMemoryBase buffer = output.buffer({output_num});
+ if (allocate_xla_tensors_) {
+ Tensor* output_tensor;
+ TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
+ if (xla_tensor) {
+ xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
+ ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+ if (use_multiple_streams_) {
+ xla_tensor->SetDefinedOn(stream, definition_event);
+ }
+ } else {
+ // xla_tensor wasn't valid, which must mean this is a zero-element
+ // tensor.
+ CHECK_EQ(output_tensor->TotalBytes(), 0);
}
} else {
- // xla_tensor wasn't valid, which must mean this is a zero-element
- // tensor.
- CHECK_EQ(output_tensor->TotalBytes(), 0);
+ Tensor output_tensor = XlaTensorBuffer::MakeTensor(
+ ctx->expected_output_dtype(i), shape, buffer, allocator);
+ output.set_buffer(xla::OwningDeviceMemory(), {output_num});
+ ctx->set_output(i, output_tensor);
}
- } else {
- Tensor output_tensor = XlaTensorBuffer::MakeTensor(
- ctx->expected_output_dtype(i), shape, buffer, allocator);
- output.set_buffer(xla::OwningDeviceMemory(), {output_num});
- ctx->set_output(i, output_tensor);
+ ++output_num;
}
- ++output_num;
}
if (VLOG_IS_ON(3)) {