about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Alexandre Passos <apassos@google.com> 2017-06-14 11:38:44 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-06-14 11:44:27 -0700
commit: ea29274ba39389f5958b21d2dc2990a305303fc3 (patch)
tree: fe5a80c4014e18a6843193d58ee1c1bee47d41d8
parent: 7c2cf8362201adfd40ced8b5e5b731591ac4a232 (diff)
Register _Arg for resources on GPU
PiperOrigin-RevId: 159002331
-rw-r--r-- tensorflow/core/common_runtime/function.cc 11
-rw-r--r-- tensorflow/core/kernels/function_ops.cc 6
2 files changed, 12 insertions, 5 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 4970c2d252..320e214117 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -326,13 +326,14 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
const FunctionBody* fbody = GetFunctionBody(handle);
CHECK_NOTNULL(fbody);
- // TODO(zhifengc): For now, we assume int32 is always on host memory
- // and other types are always on device memory. We should do type
- // inference over function body to derive the correct input/output
- // memory types.
+ // TODO(zhifengc): For now, we assume int32 and resources are always on host
+ // memory and other types are always on device memory. We should do type
+ // inference over function body to derive the correct input/output memory
+ // types.
MemoryTypeVector input_memory_types;
for (const auto& t : fbody->arg_types) {
- input_memory_types.push_back(t == DT_INT32 ? HOST_MEMORY : DEVICE_MEMORY);
+ input_memory_types.push_back(
+ (t == DT_INT32 || t == DT_RESOURCE) ? HOST_MEMORY : DEVICE_MEMORY);
}
MemoryTypeVector output_memory_types;
for (const auto& t : fbody->ret_types) {
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 8c3137ece9..58c4ed37c4 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -122,6 +122,12 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
ArgOp);
#undef REGISTER
+REGISTER_KERNEL_BUILDER(Name("_Arg")
+ .Device(DEVICE_GPU)
+ .HostMemory("output")
+ .TypeConstraint<ResourceHandle>("T"),
+ ArgOp);
+
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
Name("_Retval").Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);