diff options
-rw-r--r--  tensorflow/core/common_runtime/function.cc      | 61
-rw-r--r--  tensorflow/core/framework/memory_types.cc       | 66
-rw-r--r--  tensorflow/core/framework/memory_types_test.cc  | 27
3 files changed, 105 insertions(+), 49 deletions(-)
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 1ccf66ed34..5137180479 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -219,10 +219,34 @@ class PassOn : public OpKernel { } } }; + REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn); -REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_GPU), PassOn); REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn); -REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_GPU), PassOn); + +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_ListToArray").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + PassOn); \ + REGISTER_KERNEL_BUILDER( \ + Name("_ArrayToList").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + PassOn); + +REGISTER_GPU_KERNELS(Eigen::half); +REGISTER_GPU_KERNELS(float); +REGISTER_GPU_KERNELS(double); + +#undef REGISTER_GPU_KERNELS + +REGISTER_KERNEL_BUILDER(Name("_ListToArray") + .Device(DEVICE_GPU) + .HostMemory("output") + .TypeConstraint<int32>("T"), + PassOn); +REGISTER_KERNEL_BUILDER(Name("_ArrayToList") + .Device(DEVICE_GPU) + .HostMemory("input") + .TypeConstraint<int32>("T"), + PassOn); static const FunctionLibraryRuntime::Handle kInvalidHandle = -1; @@ -383,23 +407,22 @@ class SymbolicGradientOp : public AsyncOpKernel { args.push_back(ctx->input(i)); } std::vector<Tensor>* rets = new std::vector<Tensor>; - lib->Run(opts, handle_, args, rets, - [ctx, done, rets](const Status& status) { - if (!status.ok()) { - ctx->SetStatus(status); - } else if (rets->size() != ctx->num_outputs()) { - ctx->SetStatus(errors::InvalidArgument( - "SymGrad expects to return ", ctx->num_outputs(), - " tensor(s), but get ", rets->size(), - " tensor(s) instead.")); - } else { - for (size_t i = 0; i < rets->size(); ++i) { - ctx->set_output(i, (*rets)[i]); - } - } - delete rets; - done(); - }); + lib->Run( + opts, handle_, args, rets, 
[ctx, done, rets](const Status& status) { + if (!status.ok()) { + ctx->SetStatus(status); + } else if (rets->size() != ctx->num_outputs()) { + ctx->SetStatus(errors::InvalidArgument( + "SymGrad expects to return ", ctx->num_outputs(), + " tensor(s), but get ", rets->size(), " tensor(s) instead.")); + } else { + for (size_t i = 0; i < rets->size(); ++i) { + ctx->set_output(i, (*rets)[i]); + } + } + delete rets; + done(); + }); } private: diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index d538228494..6e37a3aba4 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -40,10 +40,6 @@ int GetTotal(const NameRangeMap& name_map) { void MemoryTypesHelper(const NameRangeMap& name_map, std::vector<string>* host_memory_args, MemoryTypeVector* memory_types) { - // Now that we know the size, fill with the default 'DEVICE_MEMORY'. - memory_types->clear(); - memory_types->resize(GetTotal(name_map), DEVICE_MEMORY); - // Update args that have been marked as in "HOST_MEMORY". size_t keep = 0; for (size_t i = 0; i < host_memory_args->size(); ++i) { @@ -65,15 +61,27 @@ MemoryType MTypeFromDType(const DataType dtype) { return (dtype == DT_INT32) ? HOST_MEMORY : DEVICE_MEMORY; } -// Returns true if an arg of op_def's input/output is a type list. -bool HasTypeList(const OpDef& op_def) { - for (const auto& a : op_def.input_arg()) { - if (!a.type_list_attr().empty()) return true; - } - for (const auto& a : op_def.output_arg()) { - if (!a.type_list_attr().empty()) return true; +// Initialize the default memory types for type list arguments from the data +// types. (The default can be overridden by an explicit HostMemory() +// declaration.) 
+Status SetTypeListMTypesFromDTypes( + const NameRangeMap& name_ranges, + const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, + const DataTypeVector& dtypes, MemoryTypeVector* mtypes) { + for (const auto& a : args) { + if (!a.type_list_attr().empty()) { + auto it = name_ranges.find(a.name()); + if (it == name_ranges.end()) { + return errors::InvalidArgument("Name range for argument ", a.name(), + " not found."); + } + + for (int i = it->second.first; i < it->second.second; ++i) { + (*mtypes)[i] = MTypeFromDType(dtypes[i]); + } + } } - return false; + return Status::OK(); } } // namespace @@ -91,20 +99,21 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, Status status = FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */); - if (!status.ok() || HasTypeList(*op_def)) { - // When there is no kernel def for this op or the op's arg is a - // type list, we can only best-effort derive the memory type from - // the data type. For now, we assume int32 is always on host - // memory and other types are always on device memory. We should + DataTypeVector inp_dtypes; + DataTypeVector out_dtypes; + TF_RETURN_IF_ERROR( + InOutTypesForNode(ndef, *op_def, &inp_dtypes, &out_dtypes)); + + inp_mtypes->clear(); + out_mtypes->clear(); + + if (!status.ok()) { + // When there is no kernel def for this op, we can only best-effort derive + // the memory type from the data type. For now, we assume int32 is always + // on host memory and other types are always on device memory. We should // do type inference over function body to derive the correct // input/output memory types. 
- DataTypeVector inp_dtypes; - DataTypeVector out_dtypes; - TF_RETURN_IF_ERROR( - InOutTypesForNode(ndef, *op_def, &inp_dtypes, &out_dtypes)); - inp_mtypes->clear(); for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t)); - out_mtypes->clear(); for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t)); return Status::OK(); } @@ -114,6 +123,16 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, NameRangeMap out_names; TF_RETURN_IF_ERROR(NameRangesForNode(ndef, *op_def, &inp_names, &out_names)); + // Now that we know the size, fill with the default 'DEVICE_MEMORY'. + inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY); + out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY); + + // For type list arguments, mark int32 arguments as host memory. + TF_RETURN_IF_ERROR(SetTypeListMTypesFromDTypes(inp_names, op_def->input_arg(), + inp_dtypes, inp_mtypes)); + TF_RETURN_IF_ERROR(SetTypeListMTypesFromDTypes( + out_names, op_def->output_arg(), out_dtypes, out_mtypes)); + // Fills in host memory types based on the kernel def. 
const auto& from_proto = kdef->host_memory_arg(); std::vector<string> host_memory_args(from_proto.begin(), from_proto.end()); @@ -124,6 +143,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, "HostMemory args '", str_util::Join(host_memory_args, "', '"), "' not found in OpDef: ", SummarizeOpDef(*op_def)); } + return Status::OK(); } diff --git a/tensorflow/core/framework/memory_types_test.cc b/tensorflow/core/framework/memory_types_test.cc index c4d5886ffb..c4cd875bc4 100644 --- a/tensorflow/core/framework/memory_types_test.cc +++ b/tensorflow/core/framework/memory_types_test.cc @@ -35,14 +35,18 @@ REGISTER_OP("HostMemoryTest") .Input("a: float") .Input("b: T") .Input("c: N * string") + .Input("d: Tlist") .Output("o: N * T") + .Output("p: Tlist") .Attr("T: type") - .Attr("N: int"); + .Attr("N: int") + .Attr("Tlist: list(type)"); REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel); REGISTER_KERNEL_BUILDER(Name("HostMemoryTest") .Device(DEVICE_GPU) .HostMemory("a") .HostMemory("c") + .HostMemory("d") .HostMemory("o"), DummyKernel); @@ -52,20 +56,29 @@ TEST(MemoryTypesForNode, Simple) { .Input(FakeInput()) .Input(FakeInput(DT_BOOL)) .Input(FakeInput(3)) + .Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32})) .Finalize(&node_def)); MemoryTypeVector input, output; TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def, &input, &output)); - EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input); - EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output); + EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, + DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, + DEVICE_MEMORY, HOST_MEMORY}), + input); + EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, + HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY}), + output); TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def, &input, &output)); - EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, - 
HOST_MEMORY, HOST_MEMORY}), - input); - EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output); + EXPECT_EQ( + MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, + HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), + input); + EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, + HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY}), + output); + } } // namespace tensorflow