aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/retval_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/retval_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/retval_op.cc13
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index 64900e4709..e172c64932 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -48,6 +48,15 @@ class RetvalOp : public XlaOpKernel {
} else {
xla::XlaOp input = ctx->Input(0);
const TensorShape input_shape = ctx->InputShape(0);
+ DataType input_type = ctx->input_type(0);
+ XlaContext& tc = XlaContext::Get(ctx);
+
+ if (input_type == DT_RESOURCE) {
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+ ctx->SetStatus(tc.AddResourceRetval(index_, resource));
+ return;
+ }
auto is_constant = ctx->builder()->IsConstant(input);
if (!is_constant.ok()) {
@@ -55,7 +64,6 @@ class RetvalOp : public XlaOpKernel {
return;
}
- XlaContext& tc = XlaContext::Get(ctx);
if (tc.resolve_compile_time_constants() &&
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
xla::Literal literal;
@@ -104,7 +112,8 @@ class RetvalOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
};
-REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp);
+REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(),
+ RetvalOp);
} // anonymous namespace
} // namespace tensorflow