aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/memory_types.cc17
-rw-r--r--tensorflow/core/framework/memory_types_test.cc20
-rw-r--r--tensorflow/core/framework/types.cc12
-rw-r--r--tensorflow/core/framework/types.h5
4 files changed, 47 insertions, 7 deletions
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
index 6a2eed94b9..270118bb67 100644
--- a/tensorflow/core/framework/memory_types.cc
+++ b/tensorflow/core/framework/memory_types.cc
@@ -61,7 +61,8 @@ void MemoryTypesHelper(const NameRangeMap& name_map,
}
MemoryType MTypeFromDType(const DataType dtype) {
- return (dtype == DT_INT32) ? HOST_MEMORY : DEVICE_MEMORY;
+ return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY
+ : DEVICE_MEMORY;
}
} // namespace
@@ -118,6 +119,20 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
"HostMemory args '", str_util::Join(host_memory_args, "', '"),
"' not found in OpDef: ", SummarizeOpDef(*op_def));
}
+ CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
+ CHECK_LE(out_mtypes->size(), out_dtypes.size());
+
+ // Mark e.g. all resource and string types as host memory.
+ for (int i = 0; i < inp_mtypes->size(); ++i) {
+ if (DataTypeAlwaysOnHost(inp_dtypes[i])) {
+ (*inp_mtypes)[i] = HOST_MEMORY;
+ }
+ }
+ for (int i = 0; i < out_mtypes->size(); ++i) {
+ if (DataTypeAlwaysOnHost(out_dtypes[i])) {
+ (*out_mtypes)[i] = HOST_MEMORY;
+ }
+ }
std::vector<int32> hostmem_attr;
if (GetNodeAttr(ndef, "_input_hostmem", &hostmem_attr).ok()) {
diff --git a/tensorflow/core/framework/memory_types_test.cc b/tensorflow/core/framework/memory_types_test.cc
index 4704da9a11..3126ea8e5f 100644
--- a/tensorflow/core/framework/memory_types_test.cc
+++ b/tensorflow/core/framework/memory_types_test.cc
@@ -36,11 +36,13 @@ REGISTER_OP("HostMemoryTest")
.Input("b: T")
.Input("c: N * string")
.Input("d: Tlist")
+ .Input("e: Rlist")
.Output("o: N * T")
.Output("p: Tlist")
.Attr("T: type")
.Attr("N: int")
- .Attr("Tlist: list(type)");
+ .Attr("Tlist: list(type)")
+ .Attr("Rlist: list(type)");
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
.Device(DEVICE_GPU)
@@ -57,15 +59,20 @@ TEST(MemoryTypesForNode, Simple) {
.Input(FakeInput(DT_BOOL))
.Input(FakeInput(3))
.Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32}))
+ .Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE}))
.Finalize(&node_def));
MemoryTypeVector input, output;
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
&input, &output));
- EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
- DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
- DEVICE_MEMORY, DEVICE_MEMORY}),
- input);
+ // a:float, b:bool, c:3*string, d:(int32, float, int32),
+ // e:(resource, string, resource)
+ EXPECT_EQ(
+ MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
+ DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
+ input);
+ // o:3*bool, p:(int32, float, int32)
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
output);
@@ -74,7 +81,8 @@ TEST(MemoryTypesForNode, Simple) {
&input, &output));
EXPECT_EQ(
MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
- HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
+ HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
input);
EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
index b082dfbd03..58354d6f4e 100644
--- a/tensorflow/core/framework/types.cc
+++ b/tensorflow/core/framework/types.cc
@@ -306,6 +306,18 @@ bool DataTypeCanUseMemcpy(DataType dt) {
}
}
+bool DataTypeAlwaysOnHost(DataType dt) {
+ // Includes DT_STRING and DT_RESOURCE.
+ switch (dt) {
+ case DT_STRING:
+ case DT_STRING_REF:
+ case DT_RESOURCE:
+ return true;
+ default:
+ return false;
+ }
+}
+
bool DataTypeIsFloating(DataType dt) {
switch (dt) {
case DT_HALF:
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 652985658a..27005c0e93 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -239,6 +239,11 @@ bool DataTypeIsUnsigned(DataType dt);
// Returns a 0 on failure
int DataTypeSize(DataType dt);
+// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE.
+// For DT_RESOURCE, the handle always sits on host (even if the underlying
+// object has device-allocated resources).
+bool DataTypeAlwaysOnHost(DataType dt);
+
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_TYPES_H_