diff options
author | 2016-02-23 09:17:19 -0800 | |
---|---|---|
committer | 2016-02-23 09:58:44 -0800 | |
commit | 082263066a23afed631fc04462784773e68d3d71 (patch) | |
tree | 511647dabacf4cdb071c4dfad5300926cb7a1ca4 | |
parent | 35fa8c4ef0a60f854e40491518153260dde0b8da (diff) |
Reimplement StrippedOpListForGraph in C++
Change: 115347996
-rw-r--r-- | tensorflow/core/framework/graph_def_util.cc | 24 | ||||
-rw-r--r-- | tensorflow/core/framework/graph_def_util.h | 17 | ||||
-rw-r--r-- | tensorflow/core/framework/graph_def_util_test.cc | 41 |
3 files changed, 82 insertions, 0 deletions
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index fbe4118223..2eb23bdf9c 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" +#include <set> #include <vector> #include "tensorflow/core/framework/node_def_util.h" @@ -115,4 +116,27 @@ Status RemoveNewDefaultAttrsFromGraphDef( return s; } +Status StrippedOpListForGraph(const GraphDef& graph_def, + const OpRegistryInterface& op_registry, + OpList* stripped_op_list) { + stripped_op_list->clear_op(); + + // Collect the sorted list of op names + std::set<string> used_ops; + for (const auto& node : graph_def.node()) { + used_ops.insert(node.op()); + } + + // Build the stripped op list in sorted order. + Status status; + for (const string& op_name : used_ops) { + const OpDef* op = op_registry.LookUp(op_name, &status); + if (!op) return status; + OpDef* stripped_op = stripped_op_list->add_op(); + stripped_op->CopyFrom(*op); + RemoveDescriptionsFromOpDef(stripped_op); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h index e03636e22f..ac1595ca33 100644 --- a/tensorflow/core/framework/graph_def_util.h +++ b/tensorflow/core/framework/graph_def_util.h @@ -77,11 +77,28 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, // graph_def, consumer_op_registry)); // // Consumer can use 'graph_def', and 'op_attr_removed' summarizes // // what changes had to be made to 'graph_def' for it to work. +// +// TODO(josh11b): Describe how to use this function on the consumer using the +// stripped_op_list field from a producer. Status RemoveNewDefaultAttrsFromGraphDef( GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, const OpRegistryInterface& producer_op_registry, std::set<std::pair<string, string>>* op_attr_removed); +// Collect the ops used by a graph. +// +// This function computes the stripped_op_list field of MetaGraphDef and similar +// protos. The op_registry should contain the ops used to produce graph_def, +// and stripped_op_list can be used as the producer_op_registry argument to +// RemoveNewDefaultAttrsFromGraphDef to improve forwards compatibility +// (using OpListOpRegistry to turn the OpList into an OpRegistryInterface). +// +// Most users will pass OpRegistry::Global() for op_registry to strip against +// the list of ops registered in this process. +Status StrippedOpListForGraph(const GraphDef& graph_def, + const OpRegistryInterface& op_registry, + OpList* stripped_op_list); + } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc index 3caa7f3e43..777782b64b 100644 --- a/tensorflow/core/framework/graph_def_util_test.cc +++ b/tensorflow/core/framework/graph_def_util_test.cc @@ -133,5 +133,46 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) { EXPECT_TRUE(op_attr_removed.empty()); } +TEST(StrippedOpListForGraphTest, StripTest) { + // Make four ops + OpList op_list; + for (const string& op : {"A", "B", "C", "D"}) { + OpDef* op_def = op_list.add_op(); + op_def->set_name(op); + op_def->set_summary("summary"); + op_def->set_description("description"); + op_def->set_is_commutative(op == "B"); + } + + // Make a graph which uses two ops once and twice, respectively. + // The result should be independent of the ordering. + const string graph_ops[4][3] = { + {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}}; + for (int order = 0; order < 4; order++) { + GraphDef graph_def; + for (const string& op : graph_ops[order]) { + string name = strings::StrCat("name", graph_def.node_size()); + NodeDef* node = graph_def.add_node(); + node->set_name(name); + node->set_op(op); + } + + // Strip the op list + OpList stripped_op_list; + TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list), + &stripped_op_list)); + + // We should have exactly two ops: B and C. + ASSERT_EQ(stripped_op_list.op_size(), 2); + for (int i = 0; i < 2; i++) { + const OpDef& op = stripped_op_list.op(i); + EXPECT_EQ(op.name(), i ? "C" : "B"); + EXPECT_EQ(op.summary(), ""); + EXPECT_EQ(op.description(), ""); + EXPECT_EQ(op.is_commutative(), !i); + } + } +} + } // namespace } // namespace tensorflow |