aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-23 08:18:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-23 08:25:11 -0800
commit542c0d6fd433d50f6558bfcc1ed2cab3cc52a4e5 (patch)
tree18cfbfc2326598caf9451749a62be683ec1d40e6
parentb0dba1afec56f39854be6cd40fe38c0d554afd13 (diff)
Fix RemoveNewDefaultAttrsFromGraphDef() for graphs with functions.
Change: 140034297
-rw-r--r--tensorflow/core/framework/graph_def_util.cc108
-rw-r--r--tensorflow/core/framework/graph_def_util_test.cc53
2 files changed, 122 insertions, 39 deletions
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc
index e6af336c7f..58fb8cf611 100644
--- a/tensorflow/core/framework/graph_def_util.cc
+++ b/tensorflow/core/framework/graph_def_util.cc
@@ -66,56 +66,86 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
return Status::OK();
}
+static Status RemoveNewDefaultAttrsFromNodeDef(
+ NodeDef* node_def, const OpRegistryInterface& consumer_op_registry,
+ const OpRegistryInterface& producer_op_registry,
+ std::set<std::pair<string, string>>* op_attr_removed) {
+ const OpDef* producer_op_def;
+ const OpDef* consumer_op_def;
+ TF_RETURN_IF_ERROR(
+ producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
+ TF_RETURN_IF_ERROR(
+ consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
+
+ std::vector<string> to_remove;
+ for (const auto& attr : node_def->attr()) {
+ // If the attr is not in consumer_op_def and doesn't start with '_'...
+ if (!StringPiece(attr.first).starts_with("_") &&
+ FindAttr(attr.first, *consumer_op_def) == nullptr) {
+ const OpDef::AttrDef* producer_attr_def =
+ FindAttr(attr.first, *producer_op_def);
+ if (producer_attr_def == nullptr) {
+ return errors::InvalidArgument(
+ "Attr '", attr.first, "' missing in producer's OpDef: ",
+ SummarizeOpDef(*producer_op_def), " but found in node: ",
+ SummarizeNodeDef(*node_def));
+ }
+ // ...and it has the same value as the default in producer,
+ if (producer_attr_def->has_default_value() &&
+ AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) {
+ // then we will remove it below.
+ to_remove.emplace_back(attr.first);
+ }
+ }
+ }
+ // We separate identifying which attrs should be removed from
+ // actually removing them to avoid invalidating the loop iterators
+ // above.
+ for (const string& attr_name : to_remove) {
+ node_def->mutable_attr()->erase(attr_name);
+ if (op_attr_removed != nullptr) {
+ op_attr_removed->insert(std::make_pair(node_def->op(), attr_name));
+ }
+ }
+
+ return Status::OK();
+}
+
+static bool IsFunction(const GraphDef& graph_def, const string& op_name) {
+ for (const auto& func_def : graph_def.library().function()) {
+ if (op_name == func_def.signature().name()) return true;
+ }
+ return false;
+}
+
Status RemoveNewDefaultAttrsFromGraphDef(
GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
const OpRegistryInterface& producer_op_registry,
std::set<std::pair<string, string>>* op_attr_removed) {
- Status s;
- std::vector<string> to_remove;
+ // TODO(joshL): Make IsFunction() faster by collecting the names of
+ // all functions as a preprocessing step.
for (int n = 0; n < graph_def->node_size(); ++n) {
NodeDef* node_def = graph_def->mutable_node(n);
- const OpDef* producer_op_def;
- const OpDef* consumer_op_def;
-
- TF_RETURN_IF_ERROR(
- producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
- TF_RETURN_IF_ERROR(
- consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
-
- for (const auto& attr : node_def->attr()) {
- // If the attr is not in consumer_op_def and doesn't start with '_'...
- if (!StringPiece(attr.first).starts_with("_") &&
- FindAttr(attr.first, *consumer_op_def) == nullptr) {
- const OpDef::AttrDef* producer_attr_def =
- FindAttr(attr.first, *producer_op_def);
- if (producer_attr_def == nullptr) {
- return errors::InvalidArgument(
- "Attr '", attr.first, "' missing in producer's OpDef: ",
- SummarizeOpDef(*producer_op_def), " but found in node: ",
- SummarizeNodeDef(*node_def));
- }
- // ...and it has the same value as the default in producer,
- if (producer_attr_def->has_default_value() &&
- AreAttrValuesEqual(producer_attr_def->default_value(),
- attr.second)) {
- // then we will remove it below.
- to_remove.emplace_back(attr.first);
- }
- }
+ if (!IsFunction(*graph_def, node_def->op())) {
+ TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
+ node_def, consumer_op_registry, producer_op_registry,
+ op_attr_removed));
}
- // We separate identifying which attrs should be removed from
- // actually removing them to avoid invalidating the loop iterators
- // above.
- for (const string& attr_name : to_remove) {
- node_def->mutable_attr()->erase(attr_name);
- if (op_attr_removed != nullptr) {
- op_attr_removed->insert(std::make_pair(node_def->op(), attr_name));
+ }
+ for (int f = 0; f < graph_def->library().function_size(); ++f) {
+ FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f);
+ for (int n = 0; n < func_def->node_def_size(); ++n) {
+ NodeDef* node_def = func_def->mutable_node_def(n);
+ if (!IsFunction(*graph_def, node_def->op())) {
+ // TODO(josh11b): Better handling of attrs with placeholder values.
+ TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
+ node_def, consumer_op_registry, producer_op_registry,
+ op_attr_removed));
}
}
- to_remove.clear();
}
- return s;
+ return Status::OK();
}
void OpsUsedByGraph(const GraphDef& graph_def,
diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc
index 98f6e9b89b..dacfab93e9 100644
--- a/tensorflow/core/framework/graph_def_util_test.cc
+++ b/tensorflow/core/framework/graph_def_util_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
@@ -171,6 +172,58 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) {
EXPECT_EQ(op_attr_removed.size(), 0);
}
+TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
+ OpList consumer_op_list;
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op()));
+ TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
+ consumer_op_list.add_op()));
+ OpListOpRegistry consumer_registry(&consumer_op_list);
+
+ OpList producer_op_list;
+ TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
+ producer_op_list.add_op()));
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
+ producer_op_list.add_op()));
+ OpListOpRegistry producer_registry(&producer_op_list);
+
+ GraphDef produced_graph_def;
+ *produced_graph_def.mutable_library()->add_function() =
+ FunctionDefHelper::Create(
+ "my_func", {}, {}, {},
+ {{{"x"}, "UsesDefault", {}, {{"a", 17}}},
+ {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
+ {});
+ OpList function_op_list;
+ *function_op_list.add_op() =
+ produced_graph_def.library().function(0).signature();
+ OpListOpRegistry function_registry(&function_op_list);
+ TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry)
+ .Finalize(produced_graph_def.add_node()));
+
+ std::set<std::pair<string, string>> op_attr_removed;
+ TF_ASSERT_OK(
+ RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
+ producer_registry, &op_attr_removed));
+
+ GraphDef expected_graph_def;
+ *expected_graph_def.mutable_library()->add_function() =
+ FunctionDefHelper::Create(
+ "my_func", {}, {}, {},
+ {{{"x"}, "UsesDefault", {}, {}},
+ {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
+ {});
+ TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry)
+ .Finalize(expected_graph_def.add_node()));
+ TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
+ EXPECT_EQ(expected_graph_def.library().DebugString(),
+ produced_graph_def.library().DebugString());
+
+ std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}});
+ EXPECT_EQ(expected_removed, op_attr_removed);
+}
+
TEST(StrippedOpListForGraphTest, FlatTest) {
// Make four ops
OpList op_list;