diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-11-23 08:18:19 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-23 08:25:11 -0800 |
commit | 542c0d6fd433d50f6558bfcc1ed2cab3cc52a4e5 (patch) | |
tree | 18cfbfc2326598caf9451749a62be683ec1d40e6 | |
parent | b0dba1afec56f39854be6cd40fe38c0d554afd13 (diff) |
Fix RemoveNewDefaultAttrsFromGraphDef() for graphs with functions.
Change: 140034297
-rw-r--r-- | tensorflow/core/framework/graph_def_util.cc | 108 | ||||
-rw-r--r-- | tensorflow/core/framework/graph_def_util_test.cc | 53 |
2 files changed, 122 insertions, 39 deletions
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index e6af336c7f..58fb8cf611 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -66,56 +66,86 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, return Status::OK(); } +static Status RemoveNewDefaultAttrsFromNodeDef( + NodeDef* node_def, const OpRegistryInterface& consumer_op_registry, + const OpRegistryInterface& producer_op_registry, + std::set<std::pair<string, string>>* op_attr_removed) { + const OpDef* producer_op_def; + const OpDef* consumer_op_def; + TF_RETURN_IF_ERROR( + producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def)); + TF_RETURN_IF_ERROR( + consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def)); + + std::vector<string> to_remove; + for (const auto& attr : node_def->attr()) { + // If the attr is not in consumer_op_def and doesn't start with '_'... + if (!StringPiece(attr.first).starts_with("_") && + FindAttr(attr.first, *consumer_op_def) == nullptr) { + const OpDef::AttrDef* producer_attr_def = + FindAttr(attr.first, *producer_op_def); + if (producer_attr_def == nullptr) { + return errors::InvalidArgument( + "Attr '", attr.first, "' missing in producer's OpDef: ", + SummarizeOpDef(*producer_op_def), " but found in node: ", + SummarizeNodeDef(*node_def)); + } + // ...and it has the same value as the default in producer, + if (producer_attr_def->has_default_value() && + AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) { + // then we will remove it below. + to_remove.emplace_back(attr.first); + } + } + } + // We separate identifying which attrs should be removed from + // actually removing them to avoid invalidating the loop iterators + // above. + for (const string& attr_name : to_remove) { + node_def->mutable_attr()->erase(attr_name); + if (op_attr_removed != nullptr) { + op_attr_removed->insert(std::make_pair(node_def->op(), attr_name)); + } + } + + return Status::OK(); +} + +static bool IsFunction(const GraphDef& graph_def, const string& op_name) { + for (const auto& func_def : graph_def.library().function()) { + if (op_name == func_def.signature().name()) return true; + } + return false; +} + Status RemoveNewDefaultAttrsFromGraphDef( GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, const OpRegistryInterface& producer_op_registry, std::set<std::pair<string, string>>* op_attr_removed) { - Status s; - std::vector<string> to_remove; + // TODO(joshL): Make IsFunction() faster by collecting the names of + // all functions as a preprocessing step. for (int n = 0; n < graph_def->node_size(); ++n) { NodeDef* node_def = graph_def->mutable_node(n); - const OpDef* producer_op_def; - const OpDef* consumer_op_def; - - TF_RETURN_IF_ERROR( - producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def)); - TF_RETURN_IF_ERROR( - consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def)); - - for (const auto& attr : node_def->attr()) { - // If the attr is not in consumer_op_def and doesn't start with '_'... - if (!StringPiece(attr.first).starts_with("_") && - FindAttr(attr.first, *consumer_op_def) == nullptr) { - const OpDef::AttrDef* producer_attr_def = - FindAttr(attr.first, *producer_op_def); - if (producer_attr_def == nullptr) { - return errors::InvalidArgument( - "Attr '", attr.first, "' missing in producer's OpDef: ", - SummarizeOpDef(*producer_op_def), " but found in node: ", - SummarizeNodeDef(*node_def)); - } - // ...and it has the same value as the default in producer, - if (producer_attr_def->has_default_value() && - AreAttrValuesEqual(producer_attr_def->default_value(), - attr.second)) { - // then we will remove it below. - to_remove.emplace_back(attr.first); - } - } + if (!IsFunction(*graph_def, node_def->op())) { + TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef( + node_def, consumer_op_registry, producer_op_registry, + op_attr_removed)); } - // We separate identifying which attrs should be removed from - // actually removing them to avoid invalidating the loop iterators - // above. - for (const string& attr_name : to_remove) { - node_def->mutable_attr()->erase(attr_name); - if (op_attr_removed != nullptr) { - op_attr_removed->insert(std::make_pair(node_def->op(), attr_name)); + } + for (int f = 0; f < graph_def->library().function_size(); ++f) { + FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f); + for (int n = 0; n < func_def->node_def_size(); ++n) { + NodeDef* node_def = func_def->mutable_node_def(n); + if (!IsFunction(*graph_def, node_def->op())) { + // TODO(josh11b): Better handling of attrs with placeholder values. + TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef( + node_def, consumer_op_registry, producer_op_registry, + op_attr_removed)); } } - to_remove.clear(); } - return s; + return Status::OK(); } void OpsUsedByGraph(const GraphDef& graph_def, diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc index 98f6e9b89b..dacfab93e9 100644 --- a/tensorflow/core/framework/graph_def_util_test.cc +++ b/tensorflow/core/framework/graph_def_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" @@ -171,6 +172,58 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) { EXPECT_EQ(op_attr_removed.size(), 0); } +TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) { + OpList consumer_op_list; + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op())); + TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"), + consumer_op_list.add_op())); + OpListOpRegistry consumer_registry(&consumer_op_list); + + OpList producer_op_list; + TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"), + producer_op_list.add_op())); + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"), + producer_op_list.add_op())); + OpListOpRegistry producer_registry(&producer_op_list); + + GraphDef produced_graph_def; + *produced_graph_def.mutable_library()->add_function() = + FunctionDefHelper::Create( + "my_func", {}, {}, {}, + {{{"x"}, "UsesDefault", {}, {{"a", 17}}}, + {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}}, + {}); + OpList function_op_list; + *function_op_list.add_op() = + produced_graph_def.library().function(0).signature(); + OpListOpRegistry function_registry(&function_op_list); + TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry) + .Finalize(produced_graph_def.add_node())); + + std::set<std::pair<string, string>> op_attr_removed; + TF_ASSERT_OK( + RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, + producer_registry, &op_attr_removed)); + + GraphDef expected_graph_def; + *expected_graph_def.mutable_library()->add_function() = + FunctionDefHelper::Create( + "my_func", {}, {}, {}, + {{{"x"}, "UsesDefault", {}, {}}, + {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}}, + {}); + TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry) + .Finalize(expected_graph_def.add_node())); + TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); + EXPECT_EQ(expected_graph_def.library().DebugString(), + produced_graph_def.library().DebugString()); + + std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}}); + EXPECT_EQ(expected_removed, op_attr_removed); +} + TEST(StrippedOpListForGraphTest, FlatTest) { // Make four ops OpList op_list; |