aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2016-02-23 09:17:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-23 09:58:44 -0800
commit082263066a23afed631fc04462784773e68d3d71 (patch)
tree511647dabacf4cdb071c4dfad5300926cb7a1ca4
parent35fa8c4ef0a60f854e40491518153260dde0b8da (diff)
Reimplement StrippedOpListForGraph in C++
Change: 115347996
-rw-r--r--tensorflow/core/framework/graph_def_util.cc24
-rw-r--r--tensorflow/core/framework/graph_def_util.h17
-rw-r--r--tensorflow/core/framework/graph_def_util_test.cc41
3 files changed, 82 insertions, 0 deletions
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc
index fbe4118223..2eb23bdf9c 100644
--- a/tensorflow/core/framework/graph_def_util.cc
+++ b/tensorflow/core/framework/graph_def_util.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
+#include <set>
#include <vector>
#include "tensorflow/core/framework/node_def_util.h"
@@ -115,4 +116,27 @@ Status RemoveNewDefaultAttrsFromGraphDef(
return s;
}
+Status StrippedOpListForGraph(const GraphDef& graph_def,
+ const OpRegistryInterface& op_registry,
+ OpList* stripped_op_list) {
+ stripped_op_list->clear_op();
+
+ // Collect the sorted list of op names
+ std::set<string> used_ops;
+ for (const auto& node : graph_def.node()) {
+ used_ops.insert(node.op());
+ }
+
+ // Build the stripped op list in sorted order.
+ Status status;
+ for (const string& op_name : used_ops) {
+ const OpDef* op = op_registry.LookUp(op_name, &status);
+ if (!op) return status;
+ OpDef* stripped_op = stripped_op_list->add_op();
+ stripped_op->CopyFrom(*op);
+ RemoveDescriptionsFromOpDef(stripped_op);
+ }
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h
index e03636e22f..ac1595ca33 100644
--- a/tensorflow/core/framework/graph_def_util.h
+++ b/tensorflow/core/framework/graph_def_util.h
@@ -77,11 +77,28 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
// graph_def, consumer_op_registry));
// // Consumer can use 'graph_def', and 'op_attr_removed' summarizes
// // what changes had to be made to 'graph_def' for it to work.
+//
+// TODO(josh11b): Describe how to use this function on the consumer using the
+// stripped_op_list field from a producer.
Status RemoveNewDefaultAttrsFromGraphDef(
GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
const OpRegistryInterface& producer_op_registry,
std::set<std::pair<string, string>>* op_attr_removed);
+// Collect the ops used by a graph.
+//
+// This function computes the stripped_op_list field of MetaGraphDef and similar
+// protos. The op_registry should contain the ops used to produce graph_def,
+// and stripped_op_list can be used as the producer_op_registry argument to
+// RemoveNewDefaultAttrsFromGraphDef to improve forwards compatibility
+// (using OpListOpRegistry to turn the OpList into an OpRegistryInterface).
+//
+// Most users will pass OpRegistry::Global() for op_registry to strip against
+// the list of ops registered in this process.
+Status StrippedOpListForGraph(const GraphDef& graph_def,
+ const OpRegistryInterface& op_registry,
+ OpList* stripped_op_list);
+
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc
index 3caa7f3e43..777782b64b 100644
--- a/tensorflow/core/framework/graph_def_util_test.cc
+++ b/tensorflow/core/framework/graph_def_util_test.cc
@@ -133,5 +133,46 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
EXPECT_TRUE(op_attr_removed.empty());
}
+TEST(StrippedOpListForGraphTest, StripTest) {
+ // Make four ops
+ OpList op_list;
+ for (const string& op : {"A", "B", "C", "D"}) {
+ OpDef* op_def = op_list.add_op();
+ op_def->set_name(op);
+ op_def->set_summary("summary");
+ op_def->set_description("description");
+ op_def->set_is_commutative(op == "B");
+ }
+
+ // Make a graph which uses two ops once and twice, respectively.
+ // The result should be independent of the ordering.
+ const string graph_ops[4][3] = {
+ {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}};
+ for (int order = 0; order < 4; order++) {
+ GraphDef graph_def;
+ for (const string& op : graph_ops[order]) {
+ string name = strings::StrCat("name", graph_def.node_size());
+ NodeDef* node = graph_def.add_node();
+ node->set_name(name);
+ node->set_op(op);
+ }
+
+ // Strip the op list
+ OpList stripped_op_list;
+ TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
+ &stripped_op_list));
+
+ // We should have exactly two ops: B and C.
+ ASSERT_EQ(stripped_op_list.op_size(), 2);
+ for (int i = 0; i < 2; i++) {
+ const OpDef& op = stripped_op_list.op(i);
+ EXPECT_EQ(op.name(), i ? "C" : "B");
+ EXPECT_EQ(op.summary(), "");
+ EXPECT_EQ(op.description(), "");
+ EXPECT_EQ(op.is_commutative(), !i);
+ }
+ }
+}
+
} // namespace
} // namespace tensorflow