aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/clusters/single_machine_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/clusters/single_machine_test.cc')
-rw-r--r--tensorflow/core/grappler/clusters/single_machine_test.cc115
1 files changed, 115 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc
index 0572aa04be..17db48817e 100644
--- a/tensorflow/core/grappler/clusters/single_machine_test.cc
+++ b/tensorflow/core/grappler/clusters/single_machine_test.cc
@@ -159,6 +159,121 @@ TEST_F(SingleMachineTest, InitializationMemory) {
EXPECT_TRUE(found);
}
+namespace {
+template <class T>
+inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
+ AttrValue attr_value;
+ SetAttrValue(value, &attr_value);
+ auto* attr_map = node->mutable_attr();
+ (*attr_map)[key] = attr_value;
+}
+template <>
+inline void SetNodeAttr(const string& key, const Tensor& tensor,
+ NodeDef* node) {
+ TensorProto tensor_proto;
+ tensor.AsProtoTensorContent(&tensor_proto);
+ SetNodeAttr(key, tensor_proto, node);
+}
+
+} // namespace
+
+TEST_F(SingleMachineTest, PersistentMemory) {
+ // Build a hashtable and its initialization graph.
+ GrapplerItem item;
+ const DataType key_dtype = DT_INT64;
+ const DataType data_dtype = DT_INT64;
+
+ NodeDef* hashtable_node = item.graph.add_node();
+ hashtable_node->set_op("HashTable");
+ hashtable_node->set_name("hash_table");
+ SetNodeAttr("key_dtype", key_dtype, hashtable_node);
+ SetNodeAttr("value_dtype", data_dtype, hashtable_node);
+
+ // Initial hashtable keys and values
+ NodeDef* keys_node = item.graph.add_node();
+ keys_node->set_op("Const");
+ keys_node->set_name("table_keys");
+ SetNodeAttr("dtype", key_dtype, keys_node);
+ Tensor keys(key_dtype, TensorShape{2});
+ keys.vec<int64>()(0) = 123;
+ keys.vec<int64>()(1) = 321;
+ SetNodeAttr("value", keys, keys_node);
+
+ NodeDef* values_node = item.graph.add_node();
+ values_node->set_op("Const");
+ values_node->set_name("table_values");
+ SetNodeAttr("dtype", data_dtype, values_node);
+ Tensor values(data_dtype, TensorShape{2});
+ values.vec<int64>()(0) = 789;
+ values.vec<int64>()(1) = 987;
+ SetNodeAttr("value", values, values_node);
+
+ // InitializeTable node
+ NodeDef* init_table_node = item.graph.add_node();
+ init_table_node->set_op("InitializeTable");
+ init_table_node->set_name("initialize_table");
+ SetNodeAttr("Tkey", key_dtype, init_table_node);
+ SetNodeAttr("Tval", data_dtype, init_table_node);
+ *init_table_node->add_input() = "hash_table";
+ *init_table_node->add_input() = "table_keys";
+ *init_table_node->add_input() = "table_values";
+ item.init_ops.push_back(init_table_node->name());
+
+ // Key to lookup
+ NodeDef* query_node = item.graph.add_node();
+ query_node->set_op("Const");
+ query_node->set_name("query");
+ SetNodeAttr("dtype", key_dtype, query_node);
+ Tensor query(key_dtype, TensorShape({}));
+ query.flat<int64>()(0) = 0;
+ SetNodeAttr("value", query, query_node);
+
+ // Default return value of hashtable lookup
+ NodeDef* default_value_node = item.graph.add_node();
+ default_value_node->set_op("Const");
+ default_value_node->set_name("default_table_value");
+ SetNodeAttr("dtype", data_dtype, default_value_node);
+ Tensor dflt(data_dtype, TensorShape({}));
+ dflt.flat<int64>()(0) = 456;
+ SetNodeAttr("value", dflt, default_value_node);
+
+ // HashTable lookup node
+ NodeDef* lookup_node = item.graph.add_node();
+ lookup_node->set_op("LookupTableFind");
+ lookup_node->set_name("table_lookup");
+ SetNodeAttr("Tin", key_dtype, lookup_node);
+ SetNodeAttr("Tout", data_dtype, lookup_node);
+ *lookup_node->add_input() = "hash_table";
+ *lookup_node->add_input() = "query";
+ *lookup_node->add_input() = "default_table_value";
+ item.fetch.push_back(lookup_node->name());
+
+ // Run the graph
+ TF_CHECK_OK(cluster_->Initialize(item));
+ RunMetadata metadata;
+ TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
+
+ // Check the cost model.
+ bool found_table_init = false;
+ bool found_hashtable = false;
+ for (const auto& node : metadata.cost_graph().node()) {
+ if (node.name() == "hash_table") {
+ found_hashtable = true;
+ // Persistent memory usage should be 0 since it's recorded as part of the
+ // initialize_table op.
+ EXPECT_EQ(0, node.host_persistent_memory_size());
+ EXPECT_EQ(0, node.device_persistent_memory_size());
+ } else if (node.name() == "initialize_table") {
+ found_table_init = true;
+ // Persistent memory should hold 2 keys and 2 values.
+ EXPECT_LE(4 * sizeof(int64), node.host_persistent_memory_size());
+ EXPECT_EQ(0, node.device_persistent_memory_size());
+ }
+ }
+ EXPECT_TRUE(found_table_init);
+ EXPECT_TRUE(found_hashtable);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow