diff options
author | Dandelion Mané <dandelion@google.com> | 2017-03-10 16:41:15 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-10 17:08:07 -0800 |
commit | ba7dd8ae0ae01a94826f4e21e38d6e5d12979915 (patch) | |
tree | aac24499e6c4d586b1288933aad30d67b37173da /tensorflow/contrib/graph_editor | |
parent | fc112a6b53d782eacb46eb357a8720d6b5a5d3cc (diff) |
Automated rollback of change 149741516
Change: 149812873
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r-- | tensorflow/contrib/graph_editor/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/__init__.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/match.py (renamed from tensorflow/contrib/graph_editor/tests/match.py) | 0 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/edit_test.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/match_test.py | 21 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/reroute_test.py | 23 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/transform_test.py | 13 |
7 files changed, 32 insertions, 34 deletions
diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index 18ee568c16..28e367463c 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -14,6 +14,7 @@ py_library( srcs = [ "__init__.py", "edit.py", + "match.py", "reroute.py", "select.py", "subgraph.py", diff --git a/tensorflow/contrib/graph_editor/__init__.py b/tensorflow/contrib/graph_editor/__init__.py index 51b7f45274..49111d5437 100644 --- a/tensorflow/contrib/graph_editor/__init__.py +++ b/tensorflow/contrib/graph_editor/__init__.py @@ -23,6 +23,7 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.graph_editor.edit import * +from tensorflow.contrib.graph_editor.match import * from tensorflow.contrib.graph_editor.reroute import * from tensorflow.contrib.graph_editor.select import * from tensorflow.contrib.graph_editor.subgraph import * diff --git a/tensorflow/contrib/graph_editor/tests/match.py b/tensorflow/contrib/graph_editor/match.py index 1bf482b6c2..1bf482b6c2 100644 --- a/tensorflow/contrib/graph_editor/tests/match.py +++ b/tensorflow/contrib/graph_editor/match.py diff --git a/tensorflow/contrib/graph_editor/tests/edit_test.py b/tensorflow/contrib/graph_editor/tests/edit_test.py index 2f669c5d20..8adaf84b42 100644 --- a/tensorflow/contrib/graph_editor/tests/edit_test.py +++ b/tensorflow/contrib/graph_editor/tests/edit_test.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import graph_editor as ge -from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -55,7 +54,7 @@ class EditTest(test.TestCase): ge.detach(sgv, control_ios=control_outputs) # make sure the detached graph is as expected. self.assertTrue( - match.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op)) + ge.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op)) def test_connect(self): """Test for ge.connect.""" @@ -67,13 +66,13 @@ class EditTest(test.TestCase): sgv = ge.sgv(x.op, y.op, z.op) ge.connect(sgv, ge.sgv(self.e.op).remap_inputs([0])) self.assertTrue( - match.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op)) + ge.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op)) def test_bypass(self): """Test for ge.bypass.""" ge.bypass(ge.sgv(self.f.op).remap_inputs([0])) self.assertTrue( - match.OpMatcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")( + ge.OpMatcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")( self.h.op)) diff --git a/tensorflow/contrib/graph_editor/tests/match_test.py b/tensorflow/contrib/graph_editor/tests/match_test.py index d81dc34dba..bcb8f3f0e3 100644 --- a/tensorflow/contrib/graph_editor/tests/match_test.py +++ b/tensorflow/contrib/graph_editor/tests/match_test.py @@ -17,7 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.graph_editor.tests import match +from tensorflow.contrib import graph_editor as ge from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -42,21 +42,20 @@ class MatchTest(test.TestCase): self.h = math_ops.add(self.f, self.g, name="h") def test_simple_match(self): - self.assertTrue(match.OpMatcher("^.*/f$")(self.f.op)) + self.assertTrue(ge.OpMatcher("^.*/f$")(self.f.op)) self.assertTrue( - match.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f.op)) + ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f.op)) + self.assertTrue(ge.OpMatcher("^.*/f$").input_ops(True, "^.*/d$")(self.f.op)) self.assertTrue( - match.OpMatcher("^.*/f$").input_ops(True, "^.*/d$")(self.f.op)) + ge.OpMatcher("^.*/f$").input_ops( + ge.match.op_type("Add"), ge.match.op_type("Const"))(self.f.op)) self.assertTrue( - match.OpMatcher("^.*/f$").input_ops( - match.op_type("Add"), match.op_type("Const"))(self.f.op)) - self.assertTrue( - match.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$") - .output_ops(match.OpMatcher("^.*/h$") + ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$") + .output_ops(ge.OpMatcher("^.*/h$") .control_input_ops("^.*/c$"))(self.f.op)) self.assertTrue( - match.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops( - match.OpMatcher("^.*/h$").control_input_ops("^.*/c$") + ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops( + ge.OpMatcher("^.*/h$").control_input_ops("^.*/c$") .output_ops([]))(self.f.op)) diff --git a/tensorflow/contrib/graph_editor/tests/reroute_test.py b/tensorflow/contrib/graph_editor/tests/reroute_test.py index 3c00304add..d663c8839d 100644 --- a/tensorflow/contrib/graph_editor/tests/reroute_test.py +++ b/tensorflow/contrib/graph_editor/tests/reroute_test.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import graph_editor as ge -from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -42,25 +41,25 @@ class RerouteTest(test.TestCase): def test_swap(self): ge.swap_ts([self.a0, self.b0], [self.a1, self.b1]) - self.assertTrue(match.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) - self.assertTrue(match.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) + self.assertTrue(ge.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) + self.assertTrue(ge.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) def test_multiswap(self): with self.graph.as_default(): a3 = constant_op.constant(3.0, shape=[2], name="a3") ge.swap_ios(ge.sgv(a3.op).remap_outputs([0, 0]), ge.sgv(self.a0.op, self.a1.op)) - self.assertTrue(match.OpMatcher("c0").input_ops("a3", "b0")(self.c0.op)) - self.assertTrue(match.OpMatcher("c1").input_ops("a3", "b1")(self.c1.op)) + self.assertTrue(ge.OpMatcher("c0").input_ops("a3", "b0")(self.c0.op)) + self.assertTrue(ge.OpMatcher("c1").input_ops("a3", "b1")(self.c1.op)) def test_reroute(self): ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1]) - self.assertTrue(match.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op)) - self.assertTrue(match.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) + self.assertTrue(ge.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op)) + self.assertTrue(ge.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0]) - self.assertTrue(match.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) - self.assertTrue(match.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op)) + self.assertTrue(ge.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) + self.assertTrue(ge.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op)) def test_compatibility(self): with self.assertRaises(ValueError): @@ -85,9 +84,9 @@ class RerouteTest(test.TestCase): ge.swap_outputs(sgv0, sgv1) self.assertTrue( - match.OpMatcher("g").input_ops( - "a", match.OpMatcher("c").input_ops("a", "b"))(g.op)) - self.assertTrue(match.OpMatcher("d").input_ops("e", "f")(d.op)) + ge.OpMatcher("g").input_ops("a", ge.OpMatcher("c").input_ops("a", "b"))( + g.op)) + self.assertTrue(ge.OpMatcher("d").input_ops("e", "f")(d.op)) if __name__ == "__main__": diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index a4105645c6..33f1217412 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import collections import numpy as np from tensorflow.contrib import graph_editor as ge -from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops @@ -101,12 +100,12 @@ class TransformTest(test.TestCase): graph = ops.Graph() transformer(self.graph, graph, "", "") - matcher0 = match.OpMatcher("AddNoise").input_ops( - "Noise", match.OpMatcher("Add").input_ops("Const", "Input")) - matcher1 = match.OpMatcher("AddNoise_1").input_ops( - "Noise_1", match.OpMatcher("Add_1").input_ops("Const_1", matcher0)) - matcher2 = match.OpMatcher("AddNoise_2").input_ops( - "Noise_2", match.OpMatcher("Add_2").input_ops("Const_2", matcher1)) + matcher0 = ge.OpMatcher("AddNoise").input_ops( + "Noise", ge.OpMatcher("Add").input_ops("Const", "Input")) + matcher1 = ge.OpMatcher("AddNoise_1").input_ops( + "Noise_1", ge.OpMatcher("Add_1").input_ops("Const_1", matcher0)) + matcher2 = ge.OpMatcher("AddNoise_2").input_ops( + "Noise_2", ge.OpMatcher("Add_2").input_ops("Const_2", matcher1)) top = ge.select_ops("^AddNoise_2$", graph=graph)[0] self.assertTrue(matcher2(top)) |