author    Igor Ganichev <iga@google.com>  2018-08-27 15:29:56 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-27 15:34:18 -0700
commit  fc492c08d64f05b1d68beedf11a3b55bf8066a8b (patch)
tree    e893aaf05e305f5389141735d132e6d50bd9e5ea /tensorflow/compiler/jit
parent  df6c8721f8be706291a8151d725de4435942f7e2 (diff)
Support returning resource handles from function in XLA
There are a couple of reasons to do this:

- Resource handles are regular tensors that are part of a public API and can
  potentially be returned from a function.
- When tfe.defun is executed under GradientTape, it generates a function
  returning resource handles in certain cases.

This CL adds support for returning resource handles from an XLA-compiled
function. These resource handles must have been passed as arguments to the
function. In other words, we don't yet support returning resources created
inside the function. tfe.defun never makes functions that create resources.

PiperOrigin-RevId: 210442856
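A minimal sketch of the two cases described above, using the TF 1.x eager API
(tf.contrib.eager) that was current at the time of this commit; the function
names and the variable `v` are hypothetical, and XLA compilation of the defun
is assumed to be enabled separately:

    import tensorflow as tf

    tfe = tf.contrib.eager
    tf.enable_eager_execution()

    v = tfe.Variable(2.0)

    @tfe.defun
    def returns_input_handle(handle, x):
        # Supported by this CL: the DT_RESOURCE handle being returned was
        # passed in as a function argument.
        return handle, x * 2.0

    @tfe.defun
    def creates_and_returns_handle():
        # Not yet supported: the resource is created inside the function.
        # (As noted above, tfe.defun never generates functions like this.)
        w = tfe.Variable(0.0)
        return w.handle

    h, y = returns_input_handle(v.handle, tf.constant(3.0))

Under GradientTape, tfe.defun can generate functions shaped like
returns_input_handle on its own, which is what motivates the change.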
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--  tensorflow/compiler/jit/create_xla_launch_op.cc |  7
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc      | 47
2 files changed, 32 insertions(+), 22 deletions(-)
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index a7f8a5613c..56b034a30b 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -209,8 +209,13 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
- // in device memory
+ // in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
+ for (int i = 0; i < fbody->ret_types.size(); ++i) {
+ if (fbody->ret_types[i] == DT_RESOURCE) {
+ output_memory_types[i] = HOST_MEMORY;
+ }
+ }
// Create the kernel.
NameAttrList function;
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 2ffce9298d..affeab4a8c 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -271,31 +271,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
}
} else {
const TensorShape& shape = kernel->outputs[i].shape;
- VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
-
- se::DeviceMemoryBase buffer = output.buffer({output_num});
- if (allocate_xla_tensors_) {
- Tensor* output_tensor;
- TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
- XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
- if (xla_tensor) {
- xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
- ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
- if (use_multiple_streams_) {
- xla_tensor->SetDefinedOn(stream, definition_event);
+ const DataType& type = kernel->outputs[i].type;
+ VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
+ << DataTypeString(type);
+ if (type == DT_RESOURCE) {
+ ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
+ } else {
+ se::DeviceMemoryBase buffer = output.buffer({output_num});
+ if (allocate_xla_tensors_) {
+ Tensor* output_tensor;
+ TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
+ if (xla_tensor) {
+ xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
+ ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+ if (use_multiple_streams_) {
+ xla_tensor->SetDefinedOn(stream, definition_event);
+ }
+ } else {
+ // xla_tensor wasn't valid, which must mean this is a zero-element
+ // tensor.
+ CHECK_EQ(output_tensor->TotalBytes(), 0);
}
} else {
- // xla_tensor wasn't valid, which must mean this is a zero-element
- // tensor.
- CHECK_EQ(output_tensor->TotalBytes(), 0);
+ Tensor output_tensor = XlaTensorBuffer::MakeTensor(
+ ctx->expected_output_dtype(i), shape, buffer, allocator);
+ output.set_buffer(xla::OwningDeviceMemory(), {output_num});
+ ctx->set_output(i, output_tensor);
}
- } else {
- Tensor output_tensor = XlaTensorBuffer::MakeTensor(
- ctx->expected_output_dtype(i), shape, buffer, allocator);
- output.set_buffer(xla::OwningDeviceMemory(), {output_num});
- ctx->set_output(i, output_tensor);
+ ++output_num;
}
- ++output_num;
}
if (VLOG_IS_ON(3)) {
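Note on the second hunk: for a DT_RESOURCE retval, PopulateOutputs forwards the
corresponding input tensor (ctx->set_output(i, ctx->input(...))) instead of
materializing an XLA output buffer, and output_num is no longer advanced for
such retvals. Combined with the HOST_MEMORY marking in create_xla_launch_op.cc,
the handle stays in host memory and never round-trips through device buffers.
A hypothetical sketch of what this means at the Python level, continuing the
eager-mode setup above:

    @tfe.defun
    def passthrough(handle):
        return handle

    h_out = passthrough(v.handle)
    # h_out carries the same resource handle that was passed in: no device
    # copy is made, and writes through `v` remain visible through h_out.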