diff options
-rw-r--r-- | tensorflow/core/framework/memory_types.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/framework/memory_types_test.cc | 20 | ||||
-rw-r--r-- | tensorflow/core/framework/types.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/framework/types.h | 5 |
4 files changed, 47 insertions, 7 deletions
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 6a2eed94b9..270118bb67 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -61,7 +61,8 @@ void MemoryTypesHelper(const NameRangeMap& name_map, } MemoryType MTypeFromDType(const DataType dtype) { - return (dtype == DT_INT32) ? HOST_MEMORY : DEVICE_MEMORY; + return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY + : DEVICE_MEMORY; } } // namespace @@ -118,6 +119,20 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, "HostMemory args '", str_util::Join(host_memory_args, "', '"), "' not found in OpDef: ", SummarizeOpDef(*op_def)); } + CHECK_LE(inp_mtypes->size(), inp_dtypes.size()); + CHECK_LE(out_mtypes->size(), out_dtypes.size()); + + // Mark e.g. all resource and string types as host memory. + for (int i = 0; i < inp_mtypes->size(); ++i) { + if (DataTypeAlwaysOnHost(inp_dtypes[i])) { + (*inp_mtypes)[i] = HOST_MEMORY; + } + } + for (int i = 0; i < out_mtypes->size(); ++i) { + if (DataTypeAlwaysOnHost(out_dtypes[i])) { + (*out_mtypes)[i] = HOST_MEMORY; + } + } std::vector<int32> hostmem_attr; if (GetNodeAttr(ndef, "_input_hostmem", &hostmem_attr).ok()) { diff --git a/tensorflow/core/framework/memory_types_test.cc b/tensorflow/core/framework/memory_types_test.cc index 4704da9a11..3126ea8e5f 100644 --- a/tensorflow/core/framework/memory_types_test.cc +++ b/tensorflow/core/framework/memory_types_test.cc @@ -36,11 +36,13 @@ REGISTER_OP("HostMemoryTest") .Input("b: T") .Input("c: N * string") .Input("d: Tlist") + .Input("e: Rlist") .Output("o: N * T") .Output("p: Tlist") .Attr("T: type") .Attr("N: int") - .Attr("Tlist: list(type)"); + .Attr("Tlist: list(type)") + .Attr("Rlist: list(type)"); REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel); REGISTER_KERNEL_BUILDER(Name("HostMemoryTest") .Device(DEVICE_GPU) @@ -57,15 +59,20 @@ TEST(MemoryTypesForNode, Simple) { .Input(FakeInput(DT_BOOL)) .Input(FakeInput(3)) .Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32})) + .Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE})) .Finalize(&node_def)); MemoryTypeVector input, output; TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def, &input, &output)); - EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, - DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, - DEVICE_MEMORY, DEVICE_MEMORY}), - input); + // a:float, b:bool, c:3*string, d:(int32, float, int32), + // e:(resource, string, resource) + EXPECT_EQ( + MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, + HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, + DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), + input); + // o:3*bool, p:(int32, float, int32) EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}), output); @@ -74,7 +81,8 @@ TEST(MemoryTypesForNode, Simple) { &input, &output)); EXPECT_EQ( MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, - HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), + HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, + HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), input); EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}), diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc index b082dfbd03..58354d6f4e 100644 --- a/tensorflow/core/framework/types.cc +++ b/tensorflow/core/framework/types.cc @@ -306,6 +306,18 @@ bool DataTypeCanUseMemcpy(DataType dt) { } } +bool DataTypeAlwaysOnHost(DataType dt) { + // Includes DT_STRING and DT_RESOURCE. + switch (dt) { + case DT_STRING: + case DT_STRING_REF: + case DT_RESOURCE: + return true; + default: + return false; + } +} + bool DataTypeIsFloating(DataType dt) { switch (dt) { case DT_HALF: diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 652985658a..27005c0e93 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -239,6 +239,11 @@ bool DataTypeIsUnsigned(DataType dt); // Returns a 0 on failure int DataTypeSize(DataType dt); +// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE. +// For DT_RESOURCE, the handle always sits on host (even if the underlying +// object has device-allocated resources). +bool DataTypeAlwaysOnHost(DataType dt); + } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_TYPES_H_ |