diff options
Diffstat (limited to 'tensorflow/core/grappler/clusters/single_machine_test.cc')
-rw-r--r-- | tensorflow/core/grappler/clusters/single_machine_test.cc | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc index 0572aa04be..17db48817e 100644 --- a/tensorflow/core/grappler/clusters/single_machine_test.cc +++ b/tensorflow/core/grappler/clusters/single_machine_test.cc @@ -159,6 +159,121 @@ TEST_F(SingleMachineTest, InitializationMemory) { EXPECT_TRUE(found); } +namespace { +template <class T> +inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) { + AttrValue attr_value; + SetAttrValue(value, &attr_value); + auto* attr_map = node->mutable_attr(); + (*attr_map)[key] = attr_value; +} +template <> +inline void SetNodeAttr(const string& key, const Tensor& tensor, + NodeDef* node) { + TensorProto tensor_proto; + tensor.AsProtoTensorContent(&tensor_proto); + SetNodeAttr(key, tensor_proto, node); +} + +} // namespace + +TEST_F(SingleMachineTest, PersistentMemory) { + // Build a hashtable and its initialization graph. + GrapplerItem item; + const DataType key_dtype = DT_INT64; + const DataType data_dtype = DT_INT64; + + NodeDef* hashtable_node = item.graph.add_node(); + hashtable_node->set_op("HashTable"); + hashtable_node->set_name("hash_table"); + SetNodeAttr("key_dtype", key_dtype, hashtable_node); + SetNodeAttr("value_dtype", data_dtype, hashtable_node); + + // Initial hashtable keys and values + NodeDef* keys_node = item.graph.add_node(); + keys_node->set_op("Const"); + keys_node->set_name("table_keys"); + SetNodeAttr("dtype", key_dtype, keys_node); + Tensor keys(key_dtype, TensorShape{2}); + keys.vec<int64>()(0) = 123; + keys.vec<int64>()(1) = 321; + SetNodeAttr("value", keys, keys_node); + + NodeDef* values_node = item.graph.add_node(); + values_node->set_op("Const"); + values_node->set_name("table_values"); + SetNodeAttr("dtype", data_dtype, values_node); + Tensor values(data_dtype, TensorShape{2}); + values.vec<int64>()(0) = 789; + values.vec<int64>()(1) = 987; + SetNodeAttr("value", values, values_node); + + // InitializeTable node + NodeDef* init_table_node = item.graph.add_node(); + init_table_node->set_op("InitializeTable"); + init_table_node->set_name("initialize_table"); + SetNodeAttr("Tkey", key_dtype, init_table_node); + SetNodeAttr("Tval", data_dtype, init_table_node); + *init_table_node->add_input() = "hash_table"; + *init_table_node->add_input() = "table_keys"; + *init_table_node->add_input() = "table_values"; + item.init_ops.push_back(init_table_node->name()); + + // Key to lookup + NodeDef* query_node = item.graph.add_node(); + query_node->set_op("Const"); + query_node->set_name("query"); + SetNodeAttr("dtype", key_dtype, query_node); + Tensor query(key_dtype, TensorShape({})); + query.flat<int64>()(0) = 0; + SetNodeAttr("value", query, query_node); + + // Default return value of hashtable lookup + NodeDef* default_value_node = item.graph.add_node(); + default_value_node->set_op("Const"); + default_value_node->set_name("default_table_value"); + SetNodeAttr("dtype", data_dtype, default_value_node); + Tensor dflt(data_dtype, TensorShape({})); + dflt.flat<int64>()(0) = 456; + SetNodeAttr("value", dflt, default_value_node); + + // HashTable lookup node + NodeDef* lookup_node = item.graph.add_node(); + lookup_node->set_op("LookupTableFind"); + lookup_node->set_name("table_lookup"); + SetNodeAttr("Tin", key_dtype, lookup_node); + SetNodeAttr("Tout", data_dtype, lookup_node); + *lookup_node->add_input() = "hash_table"; + *lookup_node->add_input() = "query"; + *lookup_node->add_input() = "default_table_value"; + item.fetch.push_back(lookup_node->name()); + + // Run the graph + TF_CHECK_OK(cluster_->Initialize(item)); + RunMetadata metadata; + TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata)); + + // Check the cost model. + bool found_table_init = false; + bool found_hashtable = false; + for (const auto& node : metadata.cost_graph().node()) { + if (node.name() == "hash_table") { + found_hashtable = true; + // Persistent memory usage should be 0 since it's recorded as part of the + // initialize_table op. + EXPECT_EQ(0, node.host_persistent_memory_size()); + EXPECT_EQ(0, node.device_persistent_memory_size()); + } else if (node.name() == "initialize_table") { + found_table_init = true; + // Persistent memory should hold 2 keys and 2 values. + EXPECT_LE(4 * sizeof(int64), node.host_persistent_memory_size()); + EXPECT_EQ(0, node.device_persistent_memory_size()); + } + } + EXPECT_TRUE(found_table_init); + EXPECT_TRUE(found_hashtable); +} + } // namespace } // namespace grappler } // namespace tensorflow |