diff options
author | 2017-06-14 11:38:44 -0700 | |
---|---|---|
committer | 2017-06-14 11:44:27 -0700 | |
commit | ea29274ba39389f5958b21d2dc2990a305303fc3 (patch) | |
tree | fe5a80c4014e18a6843193d58ee1c1bee47d41d8 | |
parent | 7c2cf8362201adfd40ced8b5e5b731591ac4a232 (diff) |
Register _Arg for resources on GPU
PiperOrigin-RevId: 159002331
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/kernels/function_ops.cc | 6 |
2 files changed, 12 insertions, 5 deletions
```diff
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 4970c2d252..320e214117 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -326,13 +326,14 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
   const FunctionBody* fbody = GetFunctionBody(handle);
   CHECK_NOTNULL(fbody);
 
-  // TODO(zhifengc): For now, we assume int32 is always on host memory
-  // and other types are always on device memory. We should do type
-  // inference over function body to derive the correct input/output
-  // memory types.
+  // TODO(zhifengc): For now, we assume int32 and resources are always on host
+  // memory and other types are always on device memory. We should do type
+  // inference over function body to derive the correct input/output memory
+  // types.
   MemoryTypeVector input_memory_types;
   for (const auto& t : fbody->arg_types) {
-    input_memory_types.push_back(t == DT_INT32 ? HOST_MEMORY : DEVICE_MEMORY);
+    input_memory_types.push_back(
+        (t == DT_INT32 || t == DT_RESOURCE) ? HOST_MEMORY : DEVICE_MEMORY);
   }
   MemoryTypeVector output_memory_types;
   for (const auto& t : fbody->ret_types) {
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 8c3137ece9..58c4ed37c4 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -122,6 +122,12 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
                                                ArgOp);
 #undef REGISTER
 
+REGISTER_KERNEL_BUILDER(Name("_Arg")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("output")
+                            .TypeConstraint<ResourceHandle>("T"),
+                        ArgOp);
+
 #define REGISTER(type)     \
   REGISTER_KERNEL_BUILDER( \
       Name("_Retval").Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
```