aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/common
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-17 10:15:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-17 10:18:25 -0700
commit151e35680b0b2575aa8bdb6bddbb95536be4fed0 (patch)
tree069cd44db3c97ddcab54cea2b77f65c6d4fdfb7b /tensorflow/tools/common
parentf4162b7eafcb9c27292a9544b1a29f2fc7f54be6 (diff)
Change traverse_test.test_module to traverse a constructed dummy module rather than testcase itself.
PiperOrigin-RevId: 197010681
Diffstat (limited to 'tensorflow/tools/common')
-rw-r--r--tensorflow/tools/common/BUILD17
-rw-r--r--tensorflow/tools/common/test_module1.py31
-rw-r--r--tensorflow/tools/common/test_module2.py29
-rw-r--r--tensorflow/tools/common/traverse_test.py15
4 files changed, 82 insertions, 10 deletions
diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD
index b9032c046e..8c01d15a80 100644
--- a/tensorflow/tools/common/BUILD
+++ b/tensorflow/tools/common/BUILD
@@ -40,7 +40,24 @@ py_test(
srcs = ["traverse_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":test_module1",
+ ":test_module2",
":traverse",
"//tensorflow/python:platform_test",
],
)
+
+py_library(
+ name = "test_module1",
+ srcs = ["test_module1.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":test_module2",
+ ],
+)
+
+py_library(
+ name = "test_module2",
+ srcs = ["test_module2.py"],
+ srcs_version = "PY2AND3",
+)
diff --git a/tensorflow/tools/common/test_module1.py b/tensorflow/tools/common/test_module1.py
new file mode 100644
index 0000000000..cc185cf36e
--- /dev/null
+++ b/tensorflow/tools/common/test_module1.py
@@ -0,0 +1,31 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A module target for TraverseTest.test_module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.tools.common import test_module2
+
+
+class ModuleClass1(object):
+
+ def __init__(self):
+ self._m2 = test_module2.ModuleClass2()
+
+ def __model_class1_method__(self):
+ pass
+
diff --git a/tensorflow/tools/common/test_module2.py b/tensorflow/tools/common/test_module2.py
new file mode 100644
index 0000000000..d9da99d9c0
--- /dev/null
+++ b/tensorflow/tools/common/test_module2.py
@@ -0,0 +1,29 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A module target for TraverseTest.test_module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class ModuleClass2(object):
+
+ def __init__(self):
+ pass
+
+ def __model_class1_method__(self):
+ pass
+
diff --git a/tensorflow/tools/common/traverse_test.py b/tensorflow/tools/common/traverse_test.py
index eb195ec18e..ed410694ce 100644
--- a/tensorflow/tools/common/traverse_test.py
+++ b/tensorflow/tools/common/traverse_test.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import sys
-
from tensorflow.python.platform import googletest
+from tensorflow.tools.common import test_module1
+from tensorflow.tools.common import test_module2
from tensorflow.tools.common import traverse
@@ -30,10 +30,6 @@ class TestVisitor(object):
self.call_log = []
def __call__(self, path, parent, children):
- # Do not traverse googletest, it's very deep.
- for item in list(children):
- if item[1] is googletest:
- children.remove(item)
self.call_log += [(path, parent, children)]
@@ -51,13 +47,12 @@ class TraverseTest(googletest.TestCase):
def test_module(self):
visitor = TestVisitor()
- traverse.traverse(sys.modules[__name__], visitor)
+ traverse.traverse(test_module1, visitor)
called = [parent for _, parent, _ in visitor.call_log]
- self.assertIn(TestVisitor, called)
- self.assertIn(TraverseTest, called)
- self.assertIn(traverse, called)
+ self.assertIn(test_module1.ModuleClass1, called)
+ self.assertIn(test_module2.ModuleClass2, called)
def test_class(self):
visitor = TestVisitor()