aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor/tests
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-03-10 19:41:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-10 20:10:15 -0800
commit49ecd2a53dcc7aa57dfee9d669a1a4eff8c14fad (patch)
tree5873d864a407e99e506fc38c93979df82d7a71d5 /tensorflow/contrib/graph_editor/tests
parent2e7cc48e5cce0ff5429b2d9d0ac313ce70035605 (diff)
Automated rollback of change 149812873
Change: 149823781
Diffstat (limited to 'tensorflow/contrib/graph_editor/tests')
-rw-r--r--tensorflow/contrib/graph_editor/tests/edit_test.py7
-rw-r--r--tensorflow/contrib/graph_editor/tests/match.py158
-rw-r--r--tensorflow/contrib/graph_editor/tests/match_test.py21
-rw-r--r--tensorflow/contrib/graph_editor/tests/reroute_test.py23
-rw-r--r--tensorflow/contrib/graph_editor/tests/transform_test.py13
5 files changed, 192 insertions, 30 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/edit_test.py b/tensorflow/contrib/graph_editor/tests/edit_test.py
index 8adaf84b42..2f669c5d20 100644
--- a/tensorflow/contrib/graph_editor/tests/edit_test.py
+++ b/tensorflow/contrib/graph_editor/tests/edit_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import graph_editor as ge
+from tensorflow.contrib.graph_editor.tests import match
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
@@ -54,7 +55,7 @@ class EditTest(test.TestCase):
ge.detach(sgv, control_ios=control_outputs)
# make sure the detached graph is as expected.
self.assertTrue(
- ge.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op))
+ match.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op))
def test_connect(self):
"""Test for ge.connect."""
@@ -66,13 +67,13 @@ class EditTest(test.TestCase):
sgv = ge.sgv(x.op, y.op, z.op)
ge.connect(sgv, ge.sgv(self.e.op).remap_inputs([0]))
self.assertTrue(
- ge.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op))
+ match.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op))
def test_bypass(self):
"""Test for ge.bypass."""
ge.bypass(ge.sgv(self.f.op).remap_inputs([0]))
self.assertTrue(
- ge.OpMatcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")(
+ match.OpMatcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")(
self.h.op))
diff --git a/tensorflow/contrib/graph_editor/tests/match.py b/tensorflow/contrib/graph_editor/tests/match.py
new file mode 100644
index 0000000000..1bf482b6c2
--- /dev/null
+++ b/tensorflow/contrib/graph_editor/tests/match.py
@@ -0,0 +1,158 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Simple graph matching functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six import string_types
+
+from tensorflow.contrib.graph_editor import select
+from tensorflow.python.framework import ops as tf_ops
+
+__all__ = [
+ "op_type",
+ "OpMatcher",
+]
+
+
+def _make_graph_match(graph_match):
+ """Convert to a OpMatcher instance."""
+ if graph_match is None:
+ return None
+ if not isinstance(graph_match, OpMatcher):
+ graph_match = OpMatcher(graph_match)
+ return graph_match
+
+
+def op_type(op_types, op=None):
+ """Check if an op is of the given type.
+
+ Args:
+ op_types: tuple of strings containing the types to check against.
+ For instance: ("Add", "Const")
+ op: the operation to check (or None).
+ Returns:
+ if op is not None, return True if the op is of the correct type.
+ if op is None, return a lambda function which does the type checking.
+ """
+ if isinstance(op_types, string_types):
+ op_types = (op_types)
+ if op is None:
+ return lambda op: op.node_def.op in op_types
+ else:
+ return op.node_def.op in op_types
+
+
+class OpMatcher(object):
+ """Graph match class."""
+
+ def __init__(self, positive_filter):
+ """Graph match constructor."""
+ self.positive_filters = []
+ self.input_op_matches = None
+ self.control_input_op_matches = None
+ self.output_op_matches = None
+ positive_filter = self._finalize_positive_filter(positive_filter)
+ self.positive_filters.append(positive_filter)
+
+ def _finalize_positive_filter(self, elem):
+ """Convert to a filter function."""
+ if select.can_be_regex(elem):
+ regex_ = select.make_regex(elem)
+ return lambda op, regex=regex_: regex.search(op.name) is not None
+ elif isinstance(elem, tf_ops.Operation):
+ return lambda op, match_op=elem: op is match_op
+ elif callable(elem):
+ return elem
+ elif elem is True:
+ return lambda op: True
+ else:
+ raise ValueError("Cannot finalize the positive filter: {}".format(elem))
+
+ def __call__(self, op):
+ """Evaluate if the op matches or not."""
+ if not isinstance(op, tf_ops.Operation):
+ raise TypeError("Expect tf.Operation, got: {}".format(type(op)))
+ for positive_filter in self.positive_filters:
+ if not positive_filter(op):
+ return False
+ if self.input_op_matches is not None:
+ if len(op.inputs) != len(self.input_op_matches):
+ return False
+ for input_t, input_op_match in zip(op.inputs, self.input_op_matches):
+ if input_op_match is None:
+ continue
+ if not input_op_match(input_t.op):
+ return False
+ if self.control_input_op_matches is not None:
+ if len(op.control_inputs) != len(self.control_input_op_matches):
+ return False
+ for cinput_op, cinput_op_match in zip(op.control_inputs,
+ self.control_input_op_matches):
+ if cinput_op_match is None:
+ continue
+ if not cinput_op_match(cinput_op):
+ return False
+ if self.output_op_matches is not None:
+ if len(op.outputs) != len(self.output_op_matches):
+ return False
+ for output_t, output_op_matches in zip(op.outputs,
+ self.output_op_matches):
+ if output_op_matches is None:
+ continue
+ if len(output_t.consumers()) != len(output_op_matches):
+ return False
+ for consumer_op, consumer_op_match in zip(output_t.consumers(),
+ output_op_matches):
+ if consumer_op_match is None:
+ continue
+ if not consumer_op_match(consumer_op):
+ return False
+ return True
+
+ def input_ops(self, *args):
+ """Add input matches."""
+ if self.input_op_matches is not None:
+ raise ValueError("input_op_matches is already set.")
+ self.input_op_matches = []
+ for input_match in args:
+ self.input_op_matches.append(_make_graph_match(input_match))
+ return self
+
+ def control_input_ops(self, *args):
+ """Add input matches."""
+ if self.control_input_op_matches is not None:
+ raise ValueError("control_input_op_matches is already set.")
+ self.control_input_op_matches = []
+ for input_match in args:
+ self.control_input_op_matches.append(_make_graph_match(input_match))
+ return self
+
+ def output_ops(self, *args):
+ """Add output matches."""
+ if self.output_op_matches is not None:
+ raise ValueError("output_op_matches is already set.")
+ self.output_op_matches = []
+ for consumer_op_matches in args:
+ if consumer_op_matches is None:
+ self.output_op_matches.append(None)
+ if not isinstance(consumer_op_matches, list):
+ consumer_op_matches = [consumer_op_matches]
+ consumer_op_matches = [_make_graph_match(consumer_op_match)
+ for consumer_op_match in consumer_op_matches]
+ self.output_op_matches.append(consumer_op_matches)
+ return self
diff --git a/tensorflow/contrib/graph_editor/tests/match_test.py b/tensorflow/contrib/graph_editor/tests/match_test.py
index bcb8f3f0e3..d81dc34dba 100644
--- a/tensorflow/contrib/graph_editor/tests/match_test.py
+++ b/tensorflow/contrib/graph_editor/tests/match_test.py
@@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import graph_editor as ge
+from tensorflow.contrib.graph_editor.tests import match
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
@@ -42,20 +42,21 @@ class MatchTest(test.TestCase):
self.h = math_ops.add(self.f, self.g, name="h")
def test_simple_match(self):
- self.assertTrue(ge.OpMatcher("^.*/f$")(self.f.op))
+ self.assertTrue(match.OpMatcher("^.*/f$")(self.f.op))
self.assertTrue(
- ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f.op))
- self.assertTrue(ge.OpMatcher("^.*/f$").input_ops(True, "^.*/d$")(self.f.op))
+ match.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f.op))
self.assertTrue(
- ge.OpMatcher("^.*/f$").input_ops(
- ge.match.op_type("Add"), ge.match.op_type("Const"))(self.f.op))
+ match.OpMatcher("^.*/f$").input_ops(True, "^.*/d$")(self.f.op))
self.assertTrue(
- ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")
- .output_ops(ge.OpMatcher("^.*/h$")
+ match.OpMatcher("^.*/f$").input_ops(
+ match.op_type("Add"), match.op_type("Const"))(self.f.op))
+ self.assertTrue(
+ match.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")
+ .output_ops(match.OpMatcher("^.*/h$")
.control_input_ops("^.*/c$"))(self.f.op))
self.assertTrue(
- ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops(
- ge.OpMatcher("^.*/h$").control_input_ops("^.*/c$")
+ match.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops(
+ match.OpMatcher("^.*/h$").control_input_ops("^.*/c$")
.output_ops([]))(self.f.op))
diff --git a/tensorflow/contrib/graph_editor/tests/reroute_test.py b/tensorflow/contrib/graph_editor/tests/reroute_test.py
index d663c8839d..3c00304add 100644
--- a/tensorflow/contrib/graph_editor/tests/reroute_test.py
+++ b/tensorflow/contrib/graph_editor/tests/reroute_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import graph_editor as ge
+from tensorflow.contrib.graph_editor.tests import match
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
@@ -41,25 +42,25 @@ class RerouteTest(test.TestCase):
def test_swap(self):
ge.swap_ts([self.a0, self.b0], [self.a1, self.b1])
- self.assertTrue(ge.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
- self.assertTrue(ge.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))
+ self.assertTrue(match.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
+ self.assertTrue(match.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))
def test_multiswap(self):
with self.graph.as_default():
a3 = constant_op.constant(3.0, shape=[2], name="a3")
ge.swap_ios(ge.sgv(a3.op).remap_outputs([0, 0]),
ge.sgv(self.a0.op, self.a1.op))
- self.assertTrue(ge.OpMatcher("c0").input_ops("a3", "b0")(self.c0.op))
- self.assertTrue(ge.OpMatcher("c1").input_ops("a3", "b1")(self.c1.op))
+ self.assertTrue(match.OpMatcher("c0").input_ops("a3", "b0")(self.c0.op))
+ self.assertTrue(match.OpMatcher("c1").input_ops("a3", "b1")(self.c1.op))
def test_reroute(self):
ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
- self.assertTrue(ge.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op))
- self.assertTrue(ge.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))
+ self.assertTrue(match.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op))
+ self.assertTrue(match.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))
ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
- self.assertTrue(ge.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
- self.assertTrue(ge.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op))
+ self.assertTrue(match.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
+ self.assertTrue(match.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op))
def test_compatibility(self):
with self.assertRaises(ValueError):
@@ -84,9 +85,9 @@ class RerouteTest(test.TestCase):
ge.swap_outputs(sgv0, sgv1)
self.assertTrue(
- ge.OpMatcher("g").input_ops("a", ge.OpMatcher("c").input_ops("a", "b"))(
- g.op))
- self.assertTrue(ge.OpMatcher("d").input_ops("e", "f")(d.op))
+ match.OpMatcher("g").input_ops(
+ "a", match.OpMatcher("c").input_ops("a", "b"))(g.op))
+ self.assertTrue(match.OpMatcher("d").input_ops("e", "f")(d.op))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py
index 33f1217412..a4105645c6 100644
--- a/tensorflow/contrib/graph_editor/tests/transform_test.py
+++ b/tensorflow/contrib/graph_editor/tests/transform_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import numpy as np
from tensorflow.contrib import graph_editor as ge
+from tensorflow.contrib.graph_editor.tests import match
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
@@ -100,12 +101,12 @@ class TransformTest(test.TestCase):
graph = ops.Graph()
transformer(self.graph, graph, "", "")
- matcher0 = ge.OpMatcher("AddNoise").input_ops(
- "Noise", ge.OpMatcher("Add").input_ops("Const", "Input"))
- matcher1 = ge.OpMatcher("AddNoise_1").input_ops(
- "Noise_1", ge.OpMatcher("Add_1").input_ops("Const_1", matcher0))
- matcher2 = ge.OpMatcher("AddNoise_2").input_ops(
- "Noise_2", ge.OpMatcher("Add_2").input_ops("Const_2", matcher1))
+ matcher0 = match.OpMatcher("AddNoise").input_ops(
+ "Noise", match.OpMatcher("Add").input_ops("Const", "Input"))
+ matcher1 = match.OpMatcher("AddNoise_1").input_ops(
+ "Noise_1", match.OpMatcher("Add_1").input_ops("Const_1", matcher0))
+ matcher2 = match.OpMatcher("AddNoise_2").input_ops(
+ "Noise_2", match.OpMatcher("Add_2").input_ops("Const_2", matcher1))
top = ge.select_ops("^AddNoise_2$", graph=graph)[0]
self.assertTrue(matcher2(top))