aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/core/converter_testing.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/core/converter_testing.py')
-rw-r--r--tensorflow/contrib/autograph/core/converter_testing.py60
1 files changed, 36 insertions, 24 deletions
diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/contrib/autograph/core/converter_testing.py
index 0e46aacc12..2025e32817 100644
--- a/tensorflow/contrib/autograph/core/converter_testing.py
+++ b/tensorflow/contrib/autograph/core/converter_testing.py
@@ -20,19 +20,19 @@ from __future__ import print_function
import contextlib
import imp
+import sys
+
+import six
from tensorflow.contrib.autograph import operators
from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.core import errors
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import pretty_printer
-from tensorflow.contrib.autograph.pyct import qual_names
from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
from tensorflow.python.platform import test
@@ -74,7 +74,17 @@ class TestCase(test.TestCase):
"""Base class for unit tests in this module. Contains relevant utilities."""
@contextlib.contextmanager
- def compiled(self, node, *symbols):
+ def assertPrints(self, expected_result):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ yield
+ self.assertEqual(out_capturer.getvalue(), expected_result)
+ finally:
+ sys.stdout = sys.__stdout__
+
+ @contextlib.contextmanager
+ def compiled(self, node, namespace, *symbols):
source = None
self.dynamic_calls = []
@@ -89,7 +99,11 @@ class TestCase(test.TestCase):
fake_ag = self.make_fake_mod('fake_ag', converted_call)
fake_ag.__dict__.update(operators.__dict__)
fake_ag.__dict__['utils'] = utils
+ fake_ag.__dict__['rewrite_graph_construction_error'] = (
+ errors.rewrite_graph_construction_error)
result.__dict__['ag__'] = fake_ag
+ for k, v in namespace.items():
+ result.__dict__[k] = v
yield result
except Exception: # pylint:disable=broad-except
if source is None:
@@ -98,6 +112,13 @@ class TestCase(test.TestCase):
print('Offending compiled code:\n%s' % source)
raise
+ @contextlib.contextmanager
+ def converted(self, entity, converter_module, namespace, *tf_symbols):
+ node, ctx = self.prepare(entity, namespace)
+ node = converter_module.transform(node, ctx)
+ with self.compiled(node, namespace, *tf_symbols) as result:
+ yield result
+
def make_fake_mod(self, name, *symbols):
fake_mod = imp.new_module(name)
for s in symbols:
@@ -114,17 +135,15 @@ class TestCase(test.TestCase):
for k, v in ns.items():
setattr(module, k, v)
- def parse_and_analyze(self,
- test_fn,
- namespace,
- namer=None,
- arg_types=None,
- include_type_analysis=True,
- owner_type=None,
- recursive=True,
- autograph_decorators=()):
+ def prepare(self,
+ test_fn,
+ namespace,
+ namer=None,
+ arg_types=None,
+ owner_type=None,
+ recursive=True,
+ autograph_decorators=()):
node, source = parser.parse_entity(test_fn)
-
if namer is None:
namer = FakeNamer()
program_ctx = converter.ProgramContext(
@@ -141,12 +160,5 @@ class TestCase(test.TestCase):
arg_types=arg_types,
owner_type=owner_type)
ctx = converter.EntityContext(namer, entity_info, program_ctx)
-
- node = qual_names.resolve(node)
- node = activity.resolve(node, entity_info)
- node = live_values.resolve(node, entity_info, {})
- if include_type_analysis:
- node = type_info.resolve(node, entity_info)
- node = live_values.resolve(node, entity_info, {})
- self.ctx = ctx
- return node
+ node = converter.standard_analysis(node, ctx, is_initial=True)
+ return node, ctx