aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-12-21 21:49:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-21 21:52:18 -0800
commitf1a1bd87f86f29d1bdeff5b60521eda1cbd863ad (patch)
treecdefd55fbb3620cc38d4c2108fd51f087d50b72b /tensorflow/python/grappler
parentb2aa6950db67ab980012c05d496401200ad60320 (diff)
Made hard colocation constraints (i.e constraints that must be met for the
model to be executable) available from python PiperOrigin-RevId: 179892785
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/item.i136
-rw-r--r--tensorflow/python/grappler/item.py11
-rw-r--r--tensorflow/python/grappler/item_test.py19
3 files changed, 166 insertions, 0 deletions
diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i
index eb396ef1ad..d0fc1a04f2 100644
--- a/tensorflow/python/grappler/item.i
+++ b/tensorflow/python/grappler/item.i
@@ -42,6 +42,8 @@ struct GItem {
#include <unordered_set>
#include <map>
#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -163,6 +165,139 @@ static PyObject* TF_GetOpProperties(GItem item) {
return props;
}
+class ColocationGroups {
+public:
+ void Group(const string& x, const string& y) {
+ Rep* x_root = Find(x);
+ Rep* y_root = Find(y);
+
+ // x and y are already in the same set
+ if (x_root == y_root) {
+ return;
+ }
+ // x and y are not in same set, so we merge them
+ // Use the occasion to strengthen what we know about the handle by merging the
+ // information about the 2 subsets.
+ if (x_root->rank < y_root->rank) {
+ x_root->parent = y_root;
+ } else if (x_root->rank > y_root->rank) {
+ y_root->parent = x_root;
+ } else {
+ // Arbitrarily make one root the new parent
+ y_root->parent = x_root;
+ x_root->rank = x_root->rank + 1;
+ }
+ }
+
+ void ExtractGroups(std::vector<std::vector<string>>* groups) {
+ groups->reserve(nodes_.size());
+ std::unordered_map<const Rep*, int> group_ids;
+ for (const auto& rep : nodes_) {
+ Rep* r = Find(rep.first);
+ auto it = group_ids.find(r);
+ std::vector<string>* g;
+ if (it == group_ids.end()) {
+ int id = group_ids.size();
+ group_ids[r] = id;
+ groups->resize(id+1);
+ g = &groups->back();
+ } else {
+ int id = it->second;
+ g = &((*groups)[id]);
+ }
+ g->push_back(rep.first);
+ }
+ }
+
+private:
+ struct Rep {
+ // Parent in the tree used to encode the set.
+ Rep* parent;
+ // Rank in the tree, used to figure out how to compress the path to the root
+ // of the tree.
+ int rank;
+ // The node.
+ string value;
+ };
+
+ Rep* Find(const string& n) {
+ auto it = nodes_.find(n);
+ if (it == nodes_.end()) {
+ // This is the first time we process this handle, create an entry for it.
+ Rep* node = new Rep;
+ node->parent = node;
+ node->rank = 0;
+ node->value = n;
+ nodes_[n] = node;
+ return node;
+ }
+ // Return the representative for the set, which is the root of the tree. Apply
+ // path compression to speedup future queries.
+ Rep* node = it->second;
+ Rep* root = node->parent;
+ while (root != root->parent) {
+ root = root->parent;
+ }
+ while (node->parent != root) {
+ Rep* next = node->parent;
+ node->parent = root;
+ node = next;
+ }
+ return root;
+ }
+
+ std::unordered_map<string, Rep*> nodes_;
+};
+
+static PyObject* TF_GetColocationGroups(GItem item) {
+ if (item.is_none()) {
+ Py_RETURN_NONE;
+ }
+ ColocationGroups groupings;
+ tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
+ for (const auto& node : item->graph.node()) {
+ const tensorflow::OpDef* op_def;
+ tensorflow::Status s = registry->LookUpOpDef(node.op(), &op_def);
+ if (!s.ok()) {
+ continue;
+ }
+ tensorflow::NameRangeMap inputs;
+ tensorflow::NameRangeMap outputs;
+ s = tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs);
+ if (!s.ok()) {
+ continue;
+ }
+ int i = 0;
+ for (const auto& arg : op_def->input_arg()) {
+ if (!arg.is_ref()) {
+ continue;
+ }
+ const auto& range = inputs[arg.name()];
+ for (int i = range.first; i < range.second; ++i) {
+ string input = tensorflow::grappler::NodeName(node.input(i));
+ groupings.Group(node.name(), input);
+ }
+ }
+ }
+
+ std::vector<std::vector<string>> groups;
+ groupings.ExtractGroups(&groups);
+
+ PyGILState_STATE gstate = PyGILState_Ensure();
+ PyObject* result = PyList_New(groups.size());
+ for (int i = 0; i < groups.size(); ++i) {
+ const std::vector<string>& group = groups[i];
+ PyObject* g = PyTuple_New(group.size());
+ for (int j = 0; j < group.size(); ++j) {
+ const string& node_name = group[j];
+ PyTuple_SetItem(g, j, PyString_FromString(node_name.c_str()));
+ }
+ PyList_SetItem(result, i, g);
+ }
+ PyGILState_Release(gstate);
+ return result;
+}
+
%}
@@ -173,3 +308,4 @@ static GItem TF_NewItem(
static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topologically,
TF_Status* status);
static PyObject* TF_GetOpProperties(GItem item);
+static PyObject* TF_GetColocationGroups(GItem item);
diff --git a/tensorflow/python/grappler/item.py b/tensorflow/python/grappler/item.py
index c6e66d3c27..4a083849bd 100644
--- a/tensorflow/python/grappler/item.py
+++ b/tensorflow/python/grappler/item.py
@@ -66,6 +66,17 @@ class Item(object):
properties[key] = prop
return properties
+ def GetColocationGroups(self):
+ """Return a list of hard colocation constraints.
+
+ All the nodes in a colocation tuple must be placed on the same device for
+ the model to work.
+
+ Returns:
+ A list of colocation tuples.
+ """
+ return tf_item.TF_GetColocationGroups(self.tf_item)
+
@property
def metagraph(self):
return self._metagraph
diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py
index 71c68d25cd..7f0e141a7d 100644
--- a/tensorflow/python/grappler/item_test.py
+++ b/tensorflow/python/grappler/item_test.py
@@ -26,6 +26,9 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.grappler import item
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -104,6 +107,22 @@ class ItemTest(test.TestCase):
newest_tf_item = grappler_item.tf_item
self.assertEqual(new_tf_item, newest_tf_item)
+ def testColocationContraints(self):
+ with ops.Graph().as_default() as g:
+ c = constant_op.constant([10])
+ v = variables.Variable([3], dtype=dtypes.int32)
+ i = gen_array_ops._ref_identity(v)
+ a = state_ops.assign(i, c)
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ train_op.append(a)
+ mg = meta_graph.create_meta_graph_def(graph=g)
+ grappler_item = item.Item(mg)
+ groups = grappler_item.GetColocationGroups()
+ self.assertEqual(len(groups), 1)
+ self.assertEqual(
+ sorted(groups[0]),
+ ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign'])
+
if __name__ == '__main__':
test.main()