aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/common_runtime/function.cc61
-rw-r--r--tensorflow/core/framework/memory_types.cc66
-rw-r--r--tensorflow/core/framework/memory_types_test.cc27
3 files changed, 105 insertions, 49 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 1ccf66ed34..5137180479 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -219,10 +219,34 @@ class PassOn : public OpKernel {
}
}
};
+
REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
-REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_GPU), PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);
-REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_GPU), PassOn);
+
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_ListToArray").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ PassOn); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_ArrayToList").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ PassOn);
+
+REGISTER_GPU_KERNELS(Eigen::half);
+REGISTER_GPU_KERNELS(float);
+REGISTER_GPU_KERNELS(double);
+
+#undef REGISTER_GPU_KERNELS
+
+REGISTER_KERNEL_BUILDER(Name("_ListToArray")
+ .Device(DEVICE_GPU)
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ PassOn);
+REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
+ .Device(DEVICE_GPU)
+ .HostMemory("input")
+ .TypeConstraint<int32>("T"),
+ PassOn);
static const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
@@ -383,23 +407,22 @@ class SymbolicGradientOp : public AsyncOpKernel {
args.push_back(ctx->input(i));
}
std::vector<Tensor>* rets = new std::vector<Tensor>;
- lib->Run(opts, handle_, args, rets,
- [ctx, done, rets](const Status& status) {
- if (!status.ok()) {
- ctx->SetStatus(status);
- } else if (rets->size() != ctx->num_outputs()) {
- ctx->SetStatus(errors::InvalidArgument(
- "SymGrad expects to return ", ctx->num_outputs(),
- " tensor(s), but get ", rets->size(),
- " tensor(s) instead."));
- } else {
- for (size_t i = 0; i < rets->size(); ++i) {
- ctx->set_output(i, (*rets)[i]);
- }
- }
- delete rets;
- done();
- });
+ lib->Run(
+ opts, handle_, args, rets, [ctx, done, rets](const Status& status) {
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ } else if (rets->size() != ctx->num_outputs()) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "SymGrad expects to return ", ctx->num_outputs(),
+ " tensor(s), but get ", rets->size(), " tensor(s) instead."));
+ } else {
+ for (size_t i = 0; i < rets->size(); ++i) {
+ ctx->set_output(i, (*rets)[i]);
+ }
+ }
+ delete rets;
+ done();
+ });
}
private:
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
index d538228494..6e37a3aba4 100644
--- a/tensorflow/core/framework/memory_types.cc
+++ b/tensorflow/core/framework/memory_types.cc
@@ -40,10 +40,6 @@ int GetTotal(const NameRangeMap& name_map) {
void MemoryTypesHelper(const NameRangeMap& name_map,
std::vector<string>* host_memory_args,
MemoryTypeVector* memory_types) {
- // Now that we know the size, fill with the default 'DEVICE_MEMORY'.
- memory_types->clear();
- memory_types->resize(GetTotal(name_map), DEVICE_MEMORY);
-
// Update args that have been marked as in "HOST_MEMORY".
size_t keep = 0;
for (size_t i = 0; i < host_memory_args->size(); ++i) {
@@ -65,15 +61,27 @@ MemoryType MTypeFromDType(const DataType dtype) {
return (dtype == DT_INT32) ? HOST_MEMORY : DEVICE_MEMORY;
}
-// Returns true if an arg of op_def's input/output is a type list.
-bool HasTypeList(const OpDef& op_def) {
- for (const auto& a : op_def.input_arg()) {
- if (!a.type_list_attr().empty()) return true;
- }
- for (const auto& a : op_def.output_arg()) {
- if (!a.type_list_attr().empty()) return true;
+// Initialize the default memory types for type list arguments from the data
+// types. (The default can be overridden by an explicit HostMemory()
+// declaration.)
+Status SetTypeListMTypesFromDTypes(
+ const NameRangeMap& name_ranges,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+ const DataTypeVector& dtypes, MemoryTypeVector* mtypes) {
+ for (const auto& a : args) {
+ if (!a.type_list_attr().empty()) {
+ auto it = name_ranges.find(a.name());
+ if (it == name_ranges.end()) {
+ return errors::InvalidArgument("Name range for argument ", a.name(),
+ " not found.");
+ }
+
+ for (int i = it->second.first; i < it->second.second; ++i) {
+ (*mtypes)[i] = MTypeFromDType(dtypes[i]);
+ }
+ }
}
- return false;
+ return Status::OK();
}
} // namespace
@@ -91,20 +99,21 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
Status status =
FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */);
- if (!status.ok() || HasTypeList(*op_def)) {
- // When there is no kernel def for this op or the op's arg is a
- // type list, we can only best-effort derive the memory type from
- // the data type. For now, we assume int32 is always on host
- // memory and other types are always on device memory. We should
+ DataTypeVector inp_dtypes;
+ DataTypeVector out_dtypes;
+ TF_RETURN_IF_ERROR(
+ InOutTypesForNode(ndef, *op_def, &inp_dtypes, &out_dtypes));
+
+ inp_mtypes->clear();
+ out_mtypes->clear();
+
+ if (!status.ok()) {
+ // When there is no kernel def for this op, we can only best-effort derive
+ // the memory type from the data type. For now, we assume int32 is always
+ // on host memory and other types are always on device memory. We should
// do type inference over function body to derive the correct
// input/output memory types.
- DataTypeVector inp_dtypes;
- DataTypeVector out_dtypes;
- TF_RETURN_IF_ERROR(
- InOutTypesForNode(ndef, *op_def, &inp_dtypes, &out_dtypes));
- inp_mtypes->clear();
for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
- out_mtypes->clear();
for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
return Status::OK();
}
@@ -114,6 +123,16 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
NameRangeMap out_names;
TF_RETURN_IF_ERROR(NameRangesForNode(ndef, *op_def, &inp_names, &out_names));
+ // Now that we know the size, fill with the default 'DEVICE_MEMORY'.
+ inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
+ out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
+
+ // For type list arguments, mark int32 arguments as host memory.
+ TF_RETURN_IF_ERROR(SetTypeListMTypesFromDTypes(inp_names, op_def->input_arg(),
+ inp_dtypes, inp_mtypes));
+ TF_RETURN_IF_ERROR(SetTypeListMTypesFromDTypes(
+ out_names, op_def->output_arg(), out_dtypes, out_mtypes));
+
// Fills in host memory types based on the kernel def.
const auto& from_proto = kdef->host_memory_arg();
std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
@@ -124,6 +143,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
"HostMemory args '", str_util::Join(host_memory_args, "', '"),
"' not found in OpDef: ", SummarizeOpDef(*op_def));
}
+
return Status::OK();
}
diff --git a/tensorflow/core/framework/memory_types_test.cc b/tensorflow/core/framework/memory_types_test.cc
index c4d5886ffb..c4cd875bc4 100644
--- a/tensorflow/core/framework/memory_types_test.cc
+++ b/tensorflow/core/framework/memory_types_test.cc
@@ -35,14 +35,18 @@ REGISTER_OP("HostMemoryTest")
.Input("a: float")
.Input("b: T")
.Input("c: N * string")
+ .Input("d: Tlist")
.Output("o: N * T")
+ .Output("p: Tlist")
.Attr("T: type")
- .Attr("N: int");
+ .Attr("N: int")
+ .Attr("Tlist: list(type)");
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
.Device(DEVICE_GPU)
.HostMemory("a")
.HostMemory("c")
+ .HostMemory("d")
.HostMemory("o"),
DummyKernel);
@@ -52,20 +56,29 @@ TEST(MemoryTypesForNode, Simple) {
.Input(FakeInput())
.Input(FakeInput(DT_BOOL))
.Input(FakeInput(3))
+ .Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32}))
.Finalize(&node_def));
MemoryTypeVector input, output;
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
&input, &output));
- EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input);
- EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output);
+ EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
+ DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
+ DEVICE_MEMORY, HOST_MEMORY}),
+ input);
+ EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
+ HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY}),
+ output);
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
&input, &output));
- EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
- HOST_MEMORY, HOST_MEMORY}),
- input);
- EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output);
+ EXPECT_EQ(
+ MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
+ input);
+ EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY}),
+ output);
}
} // namespace tensorflow