diff options
author | 2017-03-10 19:41:43 -0800 | |
---|---|---|
committer | 2017-03-10 20:10:15 -0800 | |
commit | 49ecd2a53dcc7aa57dfee9d669a1a4eff8c14fad (patch) | |
tree | 5873d864a407e99e506fc38c93979df82d7a71d5 /tensorflow/contrib/graph_editor/tests | |
parent | 2e7cc48e5cce0ff5429b2d9d0ac313ce70035605 (diff) |
Automated rollback of change 149812873
Change: 149823781
Diffstat (limited to 'tensorflow/contrib/graph_editor/tests')
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/edit_test.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/match.py | 158 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/match_test.py | 21 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/reroute_test.py | 23 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/transform_test.py | 13 |
5 files changed, 192 insertions, 30 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/edit_test.py b/tensorflow/contrib/graph_editor/tests/edit_test.py index 8adaf84b42..2f669c5d20 100644 --- a/tensorflow/contrib/graph_editor/tests/edit_test.py +++ b/tensorflow/contrib/graph_editor/tests/edit_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import graph_editor as ge +from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -54,7 +55,7 @@ class EditTest(test.TestCase): ge.detach(sgv, control_ios=control_outputs) # make sure the detached graph is as expected. self.assertTrue( - ge.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op)) + match.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op)) def test_connect(self): """Test for ge.connect.""" @@ -66,13 +67,13 @@ class EditTest(test.TestCase): sgv = ge.sgv(x.op, y.op, z.op) ge.connect(sgv, ge.sgv(self.e.op).remap_inputs([0])) self.assertTrue( - ge.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op)) + match.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op)) def test_bypass(self): """Test for ge.bypass.""" ge.bypass(ge.sgv(self.f.op).remap_inputs([0])) self.assertTrue( - ge.OpMatcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")( + match.OpMatcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")( self.h.op)) diff --git a/tensorflow/contrib/graph_editor/tests/match.py b/tensorflow/contrib/graph_editor/tests/match.py new file mode 100644 index 0000000000..1bf482b6c2 --- /dev/null +++ b/tensorflow/contrib/graph_editor/tests/match.py @@ -0,0 +1,158 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simple graph matching functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six import string_types + +from tensorflow.contrib.graph_editor import select +from tensorflow.python.framework import ops as tf_ops + +__all__ = [ + "op_type", + "OpMatcher", +] + + +def _make_graph_match(graph_match): + """Convert to a OpMatcher instance.""" + if graph_match is None: + return None + if not isinstance(graph_match, OpMatcher): + graph_match = OpMatcher(graph_match) + return graph_match + + +def op_type(op_types, op=None): + """Check if an op is of the given type. + + Args: + op_types: tuple of strings containing the types to check against. + For instance: ("Add", "Const") + op: the operation to check (or None). + Returns: + if op is not None, return True if the op is of the correct type. + if op is None, return a lambda function which does the type checking. + """ + if isinstance(op_types, string_types): + op_types = (op_types) + if op is None: + return lambda op: op.node_def.op in op_types + else: + return op.node_def.op in op_types + + +class OpMatcher(object): + """Graph match class.""" + + def __init__(self, positive_filter): + """Graph match constructor.""" + self.positive_filters = [] + self.input_op_matches = None + self.control_input_op_matches = None + self.output_op_matches = None + positive_filter = self._finalize_positive_filter(positive_filter) + self.positive_filters.append(positive_filter) + + def _finalize_positive_filter(self, elem): + """Convert to a filter function.""" + if select.can_be_regex(elem): + regex_ = select.make_regex(elem) + return lambda op, regex=regex_: regex.search(op.name) is not None + elif isinstance(elem, tf_ops.Operation): + return lambda op, match_op=elem: op is match_op + elif callable(elem): + return elem + elif elem is True: + return lambda op: True + else: + raise ValueError("Cannot finalize the positive filter: {}".format(elem)) + + def __call__(self, op): + """Evaluate if the op matches or not.""" + if not isinstance(op, tf_ops.Operation): + raise TypeError("Expect tf.Operation, got: {}".format(type(op))) + for positive_filter in self.positive_filters: + if not positive_filter(op): + return False + if self.input_op_matches is not None: + if len(op.inputs) != len(self.input_op_matches): + return False + for input_t, input_op_match in zip(op.inputs, self.input_op_matches): + if input_op_match is None: + continue + if not input_op_match(input_t.op): + return False + if self.control_input_op_matches is not None: + if len(op.control_inputs) != len(self.control_input_op_matches): + return False + for cinput_op, cinput_op_match in zip(op.control_inputs, + self.control_input_op_matches): + if cinput_op_match is None: + continue + if not cinput_op_match(cinput_op): + return False + if self.output_op_matches is not None: + if len(op.outputs) != len(self.output_op_matches): + return False + for output_t, output_op_matches in zip(op.outputs, + self.output_op_matches): + if output_op_matches is None: + continue + if len(output_t.consumers()) != len(output_op_matches): + return False + for consumer_op, consumer_op_match in zip(output_t.consumers(), + output_op_matches): + if consumer_op_match is None: + continue + if not consumer_op_match(consumer_op): + return False + return True + + def input_ops(self, *args): + """Add input matches.""" + if self.input_op_matches is not None: + raise ValueError("input_op_matches is already set.") + self.input_op_matches = [] + for input_match in args: + self.input_op_matches.append(_make_graph_match(input_match)) + return self + + def control_input_ops(self, *args): + """Add input matches.""" + if self.control_input_op_matches is not None: + raise ValueError("control_input_op_matches is already set.") + self.control_input_op_matches = [] + for input_match in args: + self.control_input_op_matches.append(_make_graph_match(input_match)) + return self + + def output_ops(self, *args): + """Add output matches.""" + if self.output_op_matches is not None: + raise ValueError("output_op_matches is already set.") + self.output_op_matches = [] + for consumer_op_matches in args: + if consumer_op_matches is None: + self.output_op_matches.append(None) + if not isinstance(consumer_op_matches, list): + consumer_op_matches = [consumer_op_matches] + consumer_op_matches = [_make_graph_match(consumer_op_match) + for consumer_op_match in consumer_op_matches] + self.output_op_matches.append(consumer_op_matches) + return self diff --git a/tensorflow/contrib/graph_editor/tests/match_test.py b/tensorflow/contrib/graph_editor/tests/match_test.py index bcb8f3f0e3..d81dc34dba 100644 --- a/tensorflow/contrib/graph_editor/tests/match_test.py +++ b/tensorflow/contrib/graph_editor/tests/match_test.py @@ -17,7 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import graph_editor as ge +from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -42,20 +42,21 @@ class MatchTest(test.TestCase): self.h = math_ops.add(self.f, self.g, name="h") def test_simple_match(self): - self.assertTrue(ge.OpMatcher("^.*/f$")(self.f.op)) + self.assertTrue(match.OpMatcher("^.*/f$")(self.f.op)) self.assertTrue( - ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f.op)) - self.assertTrue(ge.OpMatcher("^.*/f$").input_ops(True, "^.*/d$")(self.f.op)) + match.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f.op)) self.assertTrue( - ge.OpMatcher("^.*/f$").input_ops( - ge.match.op_type("Add"), ge.match.op_type("Const"))(self.f.op)) + match.OpMatcher("^.*/f$").input_ops(True, "^.*/d$")(self.f.op)) self.assertTrue( - ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$") - .output_ops(ge.OpMatcher("^.*/h$") + match.OpMatcher("^.*/f$").input_ops( + match.op_type("Add"), match.op_type("Const"))(self.f.op)) + self.assertTrue( + match.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$") + .output_ops(match.OpMatcher("^.*/h$") .control_input_ops("^.*/c$"))(self.f.op)) self.assertTrue( - ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops( - ge.OpMatcher("^.*/h$").control_input_ops("^.*/c$") + match.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops( + match.OpMatcher("^.*/h$").control_input_ops("^.*/c$") .output_ops([]))(self.f.op)) diff --git a/tensorflow/contrib/graph_editor/tests/reroute_test.py b/tensorflow/contrib/graph_editor/tests/reroute_test.py index d663c8839d..3c00304add 100644 --- a/tensorflow/contrib/graph_editor/tests/reroute_test.py +++ b/tensorflow/contrib/graph_editor/tests/reroute_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import graph_editor as ge +from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -41,25 +42,25 @@ class RerouteTest(test.TestCase): def test_swap(self): ge.swap_ts([self.a0, self.b0], [self.a1, self.b1]) - self.assertTrue(ge.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) - self.assertTrue(ge.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) + self.assertTrue(match.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) + self.assertTrue(match.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) def test_multiswap(self): with self.graph.as_default(): a3 = constant_op.constant(3.0, shape=[2], name="a3") ge.swap_ios(ge.sgv(a3.op).remap_outputs([0, 0]), ge.sgv(self.a0.op, self.a1.op)) - self.assertTrue(ge.OpMatcher("c0").input_ops("a3", "b0")(self.c0.op)) - self.assertTrue(ge.OpMatcher("c1").input_ops("a3", "b1")(self.c1.op)) + self.assertTrue(match.OpMatcher("c0").input_ops("a3", "b0")(self.c0.op)) + self.assertTrue(match.OpMatcher("c1").input_ops("a3", "b1")(self.c1.op)) def test_reroute(self): ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1]) - self.assertTrue(ge.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op)) - self.assertTrue(ge.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) + self.assertTrue(match.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op)) + self.assertTrue(match.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0]) - self.assertTrue(ge.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) - self.assertTrue(ge.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op)) + self.assertTrue(match.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) + self.assertTrue(match.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op)) def test_compatibility(self): with self.assertRaises(ValueError): @@ -84,9 +85,9 @@ class RerouteTest(test.TestCase): ge.swap_outputs(sgv0, sgv1) self.assertTrue( - ge.OpMatcher("g").input_ops("a", ge.OpMatcher("c").input_ops("a", "b"))( - g.op)) - self.assertTrue(ge.OpMatcher("d").input_ops("e", "f")(d.op)) + match.OpMatcher("g").input_ops( + "a", match.OpMatcher("c").input_ops("a", "b"))(g.op)) + self.assertTrue(match.OpMatcher("d").input_ops("e", "f")(d.op)) if __name__ == "__main__": diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index 33f1217412..a4105645c6 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections import numpy as np from tensorflow.contrib import graph_editor as ge +from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops @@ -100,12 +101,12 @@ class TransformTest(test.TestCase): graph = ops.Graph() transformer(self.graph, graph, "", "") - matcher0 = ge.OpMatcher("AddNoise").input_ops( - "Noise", ge.OpMatcher("Add").input_ops("Const", "Input")) - matcher1 = ge.OpMatcher("AddNoise_1").input_ops( - "Noise_1", ge.OpMatcher("Add_1").input_ops("Const_1", matcher0)) - matcher2 = ge.OpMatcher("AddNoise_2").input_ops( - "Noise_2", ge.OpMatcher("Add_2").input_ops("Const_2", matcher1)) + matcher0 = match.OpMatcher("AddNoise").input_ops( + "Noise", match.OpMatcher("Add").input_ops("Const", "Input")) + matcher1 = match.OpMatcher("AddNoise_1").input_ops( + "Noise_1", match.OpMatcher("Add_1").input_ops("Const_1", matcher0)) + matcher2 = match.OpMatcher("AddNoise_2").input_ops( + "Noise_2", match.OpMatcher("Add_2").input_ops("Const_2", matcher1)) top = ge.select_ops("^AddNoise_2$", graph=graph)[0] self.assertTrue(matcher2(top)) |