diff options
author | Igor Ganichev <iga@google.com> | 2018-08-27 15:29:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 15:34:18 -0700 |
commit | fc492c08d64f05b1d68beedf11a3b55bf8066a8b (patch) | |
tree | e893aaf05e305f5389141735d132e6d50bd9e5ea /tensorflow/compiler/jit | |
parent | df6c8721f8be706291a8151d725de4435942f7e2 (diff) |
Support returning resource handles from function in XLA
There are a couple of reasons to do this:
- resource handle are regular tensors part of a public API that
can potentially be returned from a function.
- When tfe.defun is executed under GradientTape, it generates a
function returning resource handles in certain cases.
This CL adds support for returning resource handles from an XLA
compiled function. These resource handles must have been passed as
arguments to the function. In other words, we don't yet support
returning resources created inside the function. tfe.defun never
makes functions that create resources.
PiperOrigin-RevId: 210442856
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r-- | tensorflow/compiler/jit/create_xla_launch_op.cc | 7 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_launch_util.cc | 47 |
2 files changed, 32 insertions, 22 deletions
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index a7f8a5613c..56b034a30b 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -209,8 +209,13 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, // device memory. // XlaLaunch kernel keeps all outputs (including constants, which it copies), - // in device memory + // in device memory except for resources. MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + for (int i = 0; i < fbody->ret_types.size(); ++i) { + if (fbody->ret_types[i] == DT_RESOURCE) { + output_memory_types[i] = HOST_MEMORY; + } + } // Create the kernel. NameAttrList function; diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 2ffce9298d..affeab4a8c 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -271,31 +271,36 @@ Status XlaComputationLaunchContext::PopulateOutputs( } } else { const TensorShape& shape = kernel->outputs[i].shape; - VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); - - se::DeviceMemoryBase buffer = output.buffer({output_num}); - if (allocate_xla_tensors_) { - Tensor* output_tensor; - TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); - XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - if (xla_tensor) { - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); - if (use_multiple_streams_) { - xla_tensor->SetDefinedOn(stream, definition_event); + const DataType& type = kernel->outputs[i].type; + VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " + << DataTypeString(type); + if (type == DT_RESOURCE) { + ctx->set_output(i, ctx->input(kernel->outputs[i].input_index)); + } else { + se::DeviceMemoryBase buffer = output.buffer({output_num}); + if (allocate_xla_tensors_) { + Tensor* output_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); + if (xla_tensor) { + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (use_multiple_streams_) { + xla_tensor->SetDefinedOn(stream, definition_event); + } + } else { + // xla_tensor wasn't valid, which must mean this is a zero-element + // tensor. + CHECK_EQ(output_tensor->TotalBytes(), 0); } } else { - // xla_tensor wasn't valid, which must mean this is a zero-element - // tensor. - CHECK_EQ(output_tensor->TotalBytes(), 0); + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + ctx->expected_output_dtype(i), shape, buffer, allocator); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); + ctx->set_output(i, output_tensor); } - } else { - Tensor output_tensor = XlaTensorBuffer::MakeTensor( - ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); - ctx->set_output(i, output_tensor); + ++output_num; } - ++output_num; } if (VLOG_IS_ON(3)) { |