aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-01-23 18:13:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-23 18:26:20 -0800
commit3b4e53b0739804af7e8f51412bac366dd842a3f1 (patch)
treeacd7523cb4fdeb2f40031d83c70730b2553b1099
parentcd4a96499b5f83d82f2612f8e5aa99726cf15edb (diff)
Add an options argument to EqualGraphDef and EqualNodeDef. Currently the only option is controlling whether internal attributes (whose names start with "_") are tested for equality.
Change: 145362690
-rw-r--r--tensorflow/core/graph/equal_graph_def.cc16
-rw-r--r--tensorflow/core/graph/equal_graph_def.h13
2 files changed, 20 insertions, 9 deletions
diff --git a/tensorflow/core/graph/equal_graph_def.cc b/tensorflow/core/graph/equal_graph_def.cc
index 0c019fc5c1..21b6d55ca8 100644
--- a/tensorflow/core/graph/equal_graph_def.cc
+++ b/tensorflow/core/graph/equal_graph_def.cc
@@ -25,7 +25,7 @@ limitations under the License.
namespace tensorflow {
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
- string* diff) {
+ string* diff, const EqualGraphDefOptions& options) {
// Intentionally do not check that versions match so that this routine can
// be used for less brittle golden file tests.
@@ -44,7 +44,9 @@ bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
return false;
}
- if (!EqualNodeDef(*actual_iter->second, expected_node, diff)) return false;
+ if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) {
+ return false;
+ }
actual_index.erase(actual_iter);
}
@@ -75,8 +77,8 @@ string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
} // namespace
-bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected,
- string* diff) {
+bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
+ const EqualGraphDefOptions& options) {
if (actual.name() != expected.name()) {
if (diff != nullptr) {
*diff = strings::StrCat("Actual node name '", actual.name(),
@@ -156,13 +158,15 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected,
std::unordered_set<string> actual_attr;
for (const auto& a : actual.attr()) {
- if (!a.first.empty() && a.first[0] == '_') {
+ if (options.ignore_internal_attrs && !a.first.empty() &&
+ a.first[0] == '_') {
continue;
}
actual_attr.insert(a.first);
}
for (const auto& e : expected.attr()) {
- if (!e.first.empty() && e.first[0] == '_') {
+ if (options.ignore_internal_attrs && !e.first.empty() &&
+ e.first[0] == '_') {
continue;
}
diff --git a/tensorflow/core/graph/equal_graph_def.h b/tensorflow/core/graph/equal_graph_def.h
index 8d997fdff8..82f8bd0713 100644
--- a/tensorflow/core/graph/equal_graph_def.h
+++ b/tensorflow/core/graph/equal_graph_def.h
@@ -22,20 +22,27 @@ limitations under the License.
namespace tensorflow {
+struct EqualGraphDefOptions {
+ // Should internal attributes (attribute names that start with '_') be
+ // ignored?
+ bool ignore_internal_attrs = true;
+};
+
// Determines if actual and expected are equal, ignoring versions and ordering
// of nodes, attrs, and control inputs. If the GraphDefs are different and
// diff != nullptr, *diff is set to an explanation of the difference. Note that
// we use node names to match up nodes between the graphs, and so the naming of
// nodes must be consistent.
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
- string* diff);
+ string* diff, const EqualGraphDefOptions& options = {});
// Determines if actual and expected are equal, ignoring: ordering of
-// attrs, internal attributes, and control inputs.
+// attrs, internal attributes (if set in `options`), and control inputs.
//
// If the NodeDefs are different and
// diff != nullptr, *diff is set to an explanation of the difference.
-bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff);
+bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
+ const EqualGraphDefOptions& options = {});
#define TF_EXPECT_GRAPH_EQ(expected, actual) \
do { \