author     Eugene Brevdo <ebrevdo@google.com>              2017-12-12 17:01:02 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org> 2017-12-12 17:05:03 -0800
commit     e115b064f57f5c373f1acdb56b210c541ccf63fb (patch)
tree       b7b2d8a7c4eeed1dab962b8defa3b1285391293c
parent     618d5c5fad4f70456856625322db104b851a399d (diff)
[TF] Mark DT_STRING and DT_RESOURCE types as always sitting on host memory.
This is important when these arguments may appear in op input lists or output lists, where the signature may not be able to declare them as sitting on host. For DT_RESOURCE types, just the handles are marked as sitting on host memory; the actual data may reside on GPU.

PiperOrigin-RevId: 178837213
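At the type level, the change boils down to the one-line edit to MTypeFromDType in the first hunk below: DT_INT32 was already pinned to host memory, and any dtype for which the new DataTypeAlwaysOnHost() predicate returns true is now pinned as well. A minimal standalone C++ sketch of that rule (the enums below are simplified stand-ins for illustration, not TensorFlow's actual definitions):

#include <iostream>

// Simplified stand-ins for TensorFlow's DataType and MemoryType enums.
enum DataType { DT_FLOAT, DT_BOOL, DT_INT32, DT_STRING, DT_STRING_REF, DT_RESOURCE };
enum MemoryType { DEVICE_MEMORY, HOST_MEMORY };

// Mirrors the new predicate: strings and resource handles always live on host.
bool DataTypeAlwaysOnHost(DataType dt) {
  return dt == DT_STRING || dt == DT_STRING_REF || dt == DT_RESOURCE;
}

// Mirrors the updated MTypeFromDType: int32 stays host-pinned as before,
// and always-on-host dtypes are now host-pinned too.
MemoryType MTypeFromDType(DataType dtype) {
  return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY
                                                            : DEVICE_MEMORY;
}

int main() {
  std::cout << (MTypeFromDType(DT_FLOAT) == DEVICE_MEMORY) << "\n";   // prints 1
  std::cout << (MTypeFromDType(DT_RESOURCE) == HOST_MEMORY) << "\n";  // prints 1
}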
-rw-r--r--  tensorflow/core/framework/memory_types.cc       17
-rw-r--r--  tensorflow/core/framework/memory_types_test.cc   20
-rw-r--r--  tensorflow/core/framework/types.cc               12
-rw-r--r--  tensorflow/core/framework/types.h                 5
4 files changed, 47 insertions(+), 7 deletions(-)
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
index 6a2eed94b9..270118bb67 100644
--- a/tensorflow/core/framework/memory_types.cc
+++ b/tensorflow/core/framework/memory_types.cc
@@ -61,7 +61,8 @@ void MemoryTypesHelper(const NameRangeMap& name_map,
}
MemoryType MTypeFromDType(const DataType dtype) {
- return (dtype == DT_INT32) ? HOST_MEMORY : DEVICE_MEMORY;
+ return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY
+ : DEVICE_MEMORY;
}
} // namespace
@@ -118,6 +119,20 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
"HostMemory args '", str_util::Join(host_memory_args, "', '"),
"' not found in OpDef: ", SummarizeOpDef(*op_def));
}
+ CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
+ CHECK_LE(out_mtypes->size(), out_dtypes.size());
+
+ // Mark e.g. all resource and string types as host memory.
+ for (int i = 0; i < inp_mtypes->size(); ++i) {
+ if (DataTypeAlwaysOnHost(inp_dtypes[i])) {
+ (*inp_mtypes)[i] = HOST_MEMORY;
+ }
+ }
+ for (int i = 0; i < out_mtypes->size(); ++i) {
+ if (DataTypeAlwaysOnHost(out_dtypes[i])) {
+ (*out_mtypes)[i] = HOST_MEMORY;
+ }
+ }
std::vector<int32> hostmem_attr;
if (GetNodeAttr(ndef, "_input_hostmem", &hostmem_attr).ok()) {
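The second hunk above adds a post-pass over the memory types already derived from the kernel's HostMemory() constraints: any input or output whose dtype is always on host is forced to HOST_MEMORY, and the CHECK_LEs cover the case where the memory-type vectors describe only a prefix of the dtype vectors. A rough sketch of that pass as a hypothetical standalone helper, reusing the stand-ins from the sketch above:

#include <vector>

// Force HOST_MEMORY for every tensor whose dtype always lives on host.
// mtypes may cover only a prefix of dtypes, so iterate over mtypes->size().
void MarkAlwaysOnHost(const std::vector<DataType>& dtypes,
                      std::vector<MemoryType>* mtypes) {
  for (size_t i = 0; i < mtypes->size(); ++i) {
    if (DataTypeAlwaysOnHost(dtypes[i])) {
      (*mtypes)[i] = HOST_MEMORY;
    }
  }
}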
diff --git a/tensorflow/core/framework/memory_types_test.cc b/tensorflow/core/framework/memory_types_test.cc
index 4704da9a11..3126ea8e5f 100644
--- a/tensorflow/core/framework/memory_types_test.cc
+++ b/tensorflow/core/framework/memory_types_test.cc
@@ -36,11 +36,13 @@ REGISTER_OP("HostMemoryTest")
.Input("b: T")
.Input("c: N * string")
.Input("d: Tlist")
+ .Input("e: Rlist")
.Output("o: N * T")
.Output("p: Tlist")
.Attr("T: type")
.Attr("N: int")
- .Attr("Tlist: list(type)");
+ .Attr("Tlist: list(type)")
+ .Attr("Rlist: list(type)");
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
.Device(DEVICE_GPU)
@@ -57,15 +59,20 @@ TEST(MemoryTypesForNode, Simple) {
.Input(FakeInput(DT_BOOL))
.Input(FakeInput(3))
.Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32}))
+ .Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE}))
.Finalize(&node_def));
MemoryTypeVector input, output;
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
&input, &output));
- EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
- DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
- DEVICE_MEMORY, DEVICE_MEMORY}),
- input);
+ // a:float, b:bool, c:3*string, d:(int32, float, int32),
+ // e:(resource, string, resource)
+ EXPECT_EQ(
+ MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
+ DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
+ input);
+ // o:3*bool, p:(int32, float, int32)
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
output);
@@ -74,7 +81,8 @@ TEST(MemoryTypesForNode, Simple) {
&input, &output));
EXPECT_EQ(
MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
- HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
+ HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
input);
EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
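As a usage note, applying a helper like MarkAlwaysOnHost above to the test op's eleven flattened input dtypes (a:float, b:bool, c:3 x string, d:int32/float/int32, e:resource/string/resource) over an all-DEVICE_MEMORY default reproduces the updated CPU expectation; the GPU expectation additionally reflects the kernel's HostMemory() constraints, which are not visible in this hunk. Hypothetical example:

std::vector<DataType> inp_dtypes = {
    DT_FLOAT, DT_BOOL,                     // a, b
    DT_STRING, DT_STRING, DT_STRING,       // c: N * string, N = 3
    DT_INT32, DT_FLOAT, DT_INT32,          // d: Tlist
    DT_RESOURCE, DT_STRING, DT_RESOURCE};  // e: Rlist
std::vector<MemoryType> inp_mtypes(inp_dtypes.size(), DEVICE_MEMORY);
MarkAlwaysOnHost(inp_dtypes, &inp_mtypes);
// inp_mtypes is now {DEVICE, DEVICE, HOST, HOST, HOST,
//                    DEVICE, DEVICE, DEVICE, HOST, HOST, HOST},
// matching the CPU-side EXPECT_EQ above.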
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
index b082dfbd03..58354d6f4e 100644
--- a/tensorflow/core/framework/types.cc
+++ b/tensorflow/core/framework/types.cc
@@ -306,6 +306,18 @@ bool DataTypeCanUseMemcpy(DataType dt) {
}
}
+bool DataTypeAlwaysOnHost(DataType dt) {
+ // Includes DT_STRING and DT_RESOURCE.
+ switch (dt) {
+ case DT_STRING:
+ case DT_STRING_REF:
+ case DT_RESOURCE:
+ return true;
+ default:
+ return false;
+ }
+}
+
bool DataTypeIsFloating(DataType dt) {
switch (dt) {
case DT_HALF:
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 652985658a..27005c0e93 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -239,6 +239,11 @@ bool DataTypeIsUnsigned(DataType dt);
// Returns a 0 on failure
int DataTypeSize(DataType dt);
+// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE.
+// For DT_RESOURCE, the handle always sits on host (even if the underlying
+// object has device-allocated resources).
+bool DataTypeAlwaysOnHost(DataType dt);
+
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_TYPES_H_