diff options
author | 2017-12-21 21:49:22 -0800 | |
---|---|---|
committer | 2017-12-21 21:52:18 -0800 | |
commit | f1a1bd87f86f29d1bdeff5b60521eda1cbd863ad (patch) | |
tree | cdefd55fbb3620cc38d4c2108fd51f087d50b72b /tensorflow/python/grappler | |
parent | b2aa6950db67ab980012c05d496401200ad60320 (diff) |
Made hard colocation constraints (i.e constraints that must be met for the
model to be executable) available from python
PiperOrigin-RevId: 179892785
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/item.i | 136 | ||||
-rw-r--r-- | tensorflow/python/grappler/item.py | 11 | ||||
-rw-r--r-- | tensorflow/python/grappler/item_test.py | 19 |
3 files changed, 166 insertions, 0 deletions
diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i index eb396ef1ad..d0fc1a04f2 100644 --- a/tensorflow/python/grappler/item.i +++ b/tensorflow/python/grappler/item.i @@ -42,6 +42,8 @@ struct GItem { #include <unordered_set> #include <map> #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" #include "tensorflow/core/grappler/grappler_item_builder.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -163,6 +165,139 @@ static PyObject* TF_GetOpProperties(GItem item) { return props; } +class ColocationGroups { +public: + void Group(const string& x, const string& y) { + Rep* x_root = Find(x); + Rep* y_root = Find(y); + + // x and y are already in the same set + if (x_root == y_root) { + return; + } + // x and y are not in same set, so we merge them + // Use the occasion to strengthen what we know about the handle by merging the + // information about the 2 subsets. + if (x_root->rank < y_root->rank) { + x_root->parent = y_root; + } else if (x_root->rank > y_root->rank) { + y_root->parent = x_root; + } else { + // Arbitrarily make one root the new parent + y_root->parent = x_root; + x_root->rank = x_root->rank + 1; + } + } + + void ExtractGroups(std::vector<std::vector<string>>* groups) { + groups->reserve(nodes_.size()); + std::unordered_map<const Rep*, int> group_ids; + for (const auto& rep : nodes_) { + Rep* r = Find(rep.first); + auto it = group_ids.find(r); + std::vector<string>* g; + if (it == group_ids.end()) { + int id = group_ids.size(); + group_ids[r] = id; + groups->resize(id+1); + g = &groups->back(); + } else { + int id = it->second; + g = &((*groups)[id]); + } + g->push_back(rep.first); + } + } + +private: + struct Rep { + // Parent in the tree used to encode the set. + Rep* parent; + // Rank in the tree, used to figure out how to compress the path to the root + // of the tree. + int rank; + // The node. + string value; + }; + + Rep* Find(const string& n) { + auto it = nodes_.find(n); + if (it == nodes_.end()) { + // This is the first time we process this handle, create an entry for it. + Rep* node = new Rep; + node->parent = node; + node->rank = 0; + node->value = n; + nodes_[n] = node; + return node; + } + // Return the representative for the set, which is the root of the tree. Apply + // path compression to speedup future queries. + Rep* node = it->second; + Rep* root = node->parent; + while (root != root->parent) { + root = root->parent; + } + while (node->parent != root) { + Rep* next = node->parent; + node->parent = root; + node = next; + } + return root; + } + + std::unordered_map<string, Rep*> nodes_; +}; + +static PyObject* TF_GetColocationGroups(GItem item) { + if (item.is_none()) { + Py_RETURN_NONE; + } + ColocationGroups groupings; + tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global(); + for (const auto& node : item->graph.node()) { + const tensorflow::OpDef* op_def; + tensorflow::Status s = registry->LookUpOpDef(node.op(), &op_def); + if (!s.ok()) { + continue; + } + tensorflow::NameRangeMap inputs; + tensorflow::NameRangeMap outputs; + s = tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs); + if (!s.ok()) { + continue; + } + int i = 0; + for (const auto& arg : op_def->input_arg()) { + if (!arg.is_ref()) { + continue; + } + const auto& range = inputs[arg.name()]; + for (int i = range.first; i < range.second; ++i) { + string input = tensorflow::grappler::NodeName(node.input(i)); + groupings.Group(node.name(), input); + } + } + } + + std::vector<std::vector<string>> groups; + groupings.ExtractGroups(&groups); + + PyGILState_STATE gstate = PyGILState_Ensure(); + PyObject* result = PyList_New(groups.size()); + for (int i = 0; i < groups.size(); ++i) { + const std::vector<string>& group = groups[i]; + PyObject* g = PyTuple_New(group.size()); + for (int j = 0; j < group.size(); ++j) { + const string& node_name = group[j]; + PyTuple_SetItem(g, j, PyString_FromString(node_name.c_str())); + } + PyList_SetItem(result, i, g); + } + PyGILState_Release(gstate); + return result; +} + %} @@ -173,3 +308,4 @@ static GItem TF_NewItem( static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topologically, TF_Status* status); static PyObject* TF_GetOpProperties(GItem item); +static PyObject* TF_GetColocationGroups(GItem item); diff --git a/tensorflow/python/grappler/item.py b/tensorflow/python/grappler/item.py index c6e66d3c27..4a083849bd 100644 --- a/tensorflow/python/grappler/item.py +++ b/tensorflow/python/grappler/item.py @@ -66,6 +66,17 @@ class Item(object): properties[key] = prop return properties + def GetColocationGroups(self): + """Return a list of hard colocation constraints. + + All the nodes in a colocation tuple must be placed on the same device for + the model to work. + + Returns: + A list of colocation tuples. + """ + return tf_item.TF_GetColocationGroups(self.tf_item) + @property def metagraph(self): return self._metagraph diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py index 71c68d25cd..7f0e141a7d 100644 --- a/tensorflow/python/grappler/item_test.py +++ b/tensorflow/python/grappler/item_test.py @@ -26,6 +26,9 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.grappler import item from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -104,6 +107,22 @@ class ItemTest(test.TestCase): newest_tf_item = grappler_item.tf_item self.assertEqual(new_tf_item, newest_tf_item) + def testColocationContraints(self): + with ops.Graph().as_default() as g: + c = constant_op.constant([10]) + v = variables.Variable([3], dtype=dtypes.int32) + i = gen_array_ops._ref_identity(v) + a = state_ops.assign(i, c) + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + train_op.append(a) + mg = meta_graph.create_meta_graph_def(graph=g) + grappler_item = item.Item(mg) + groups = grappler_item.GetColocationGroups() + self.assertEqual(len(groups), 1) + self.assertEqual( + sorted(groups[0]), + ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign']) + if __name__ == '__main__': test.main() |