aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/resource_op_kernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/resource_op_kernel.h')
-rw-r--r--tensorflow/core/framework/resource_op_kernel.h25
1 files changed, 18 insertions, 7 deletions
diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h
index 813ec6eed5..0a8da8b3bf 100644
--- a/tensorflow/core/framework/resource_op_kernel.h
+++ b/tensorflow/core/framework/resource_op_kernel.h
@@ -43,9 +43,15 @@ template <typename T>
class ResourceOpKernel : public OpKernel {
public:
explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context,
- context->allocate_persistent(DT_STRING, TensorShape({2}),
- &handle_, nullptr));
+ has_resource_type_ = (context->output_type(0) == DT_RESOURCE);
+ if (!has_resource_type_) {
+ // The resource variant of the op may be placed on non-CPU devices, but
+ // this allocation is always on the host. Fortunately we don't need it in
+ // the resource case.
+ OP_REQUIRES_OK(context,
+ context->allocate_persistent(DT_STRING, TensorShape({2}),
+ &handle_, nullptr));
+ }
}
// The resource is deleted from the resource manager only when it is private
@@ -89,12 +95,14 @@ class ResourceOpKernel : public OpKernel {
return;
}
- auto h = handle_.AccessTensor(context)->template flat<string>();
- h(0) = cinfo_.container();
- h(1) = cinfo_.name();
+ if (!has_resource_type_) {
+ auto h = handle_.AccessTensor(context)->template flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ }
resource_ = resource;
}
- if (context->expected_output_dtype(0) == DT_RESOURCE) {
+ if (has_resource_type_) {
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
context, 0, cinfo_.container(), cinfo_.name(),
MakeTypeIndex<T>()));
@@ -122,6 +130,9 @@ class ResourceOpKernel : public OpKernel {
virtual Status VerifyResource(T* resource) { return Status::OK(); }
PersistentTensor handle_ GUARDED_BY(mu_);
+
+ // Is the output of the operator of type DT_RESOURCE?
+ bool has_resource_type_;
};
} // namespace tensorflow