diff options
Diffstat (limited to 'tensorflow/contrib/autograph/core/converter_testing.py')
-rw-r--r-- | tensorflow/contrib/autograph/core/converter_testing.py | 60 |
1 files changed, 36 insertions, 24 deletions
diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/contrib/autograph/core/converter_testing.py index 0e46aacc12..2025e32817 100644 --- a/tensorflow/contrib/autograph/core/converter_testing.py +++ b/tensorflow/contrib/autograph/core/converter_testing.py @@ -20,19 +20,19 @@ from __future__ import print_function import contextlib import imp +import sys + +import six from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.core import errors from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import pretty_printer -from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.contrib.autograph.pyct.static_analysis import activity -from tensorflow.contrib.autograph.pyct.static_analysis import live_values -from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.platform import test @@ -74,7 +74,17 @@ class TestCase(test.TestCase): """Base class for unit tests in this module. Contains relevant utilities.""" @contextlib.contextmanager - def compiled(self, node, *symbols): + def assertPrints(self, expected_result): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + yield + self.assertEqual(out_capturer.getvalue(), expected_result) + finally: + sys.stdout = sys.__stdout__ + + @contextlib.contextmanager + def compiled(self, node, namespace, *symbols): source = None self.dynamic_calls = [] @@ -89,7 +99,11 @@ class TestCase(test.TestCase): fake_ag = self.make_fake_mod('fake_ag', converted_call) fake_ag.__dict__.update(operators.__dict__) fake_ag.__dict__['utils'] = utils + fake_ag.__dict__['rewrite_graph_construction_error'] = ( + errors.rewrite_graph_construction_error) result.__dict__['ag__'] = fake_ag + for k, v in namespace.items(): + result.__dict__[k] = v yield result except Exception: # pylint:disable=broad-except if source is None: @@ -98,6 +112,13 @@ class TestCase(test.TestCase): print('Offending compiled code:\n%s' % source) raise + @contextlib.contextmanager + def converted(self, entity, converter_module, namespace, *tf_symbols): + node, ctx = self.prepare(entity, namespace) + node = converter_module.transform(node, ctx) + with self.compiled(node, namespace, *tf_symbols) as result: + yield result + def make_fake_mod(self, name, *symbols): fake_mod = imp.new_module(name) for s in symbols: @@ -114,17 +135,15 @@ class TestCase(test.TestCase): for k, v in ns.items(): setattr(module, k, v) - def parse_and_analyze(self, - test_fn, - namespace, - namer=None, - arg_types=None, - include_type_analysis=True, - owner_type=None, - recursive=True, - autograph_decorators=()): + def prepare(self, + test_fn, + namespace, + namer=None, + arg_types=None, + owner_type=None, + recursive=True, + autograph_decorators=()): node, source = parser.parse_entity(test_fn) - if namer is None: namer = FakeNamer() program_ctx = converter.ProgramContext( @@ -141,12 +160,5 @@ class TestCase(test.TestCase): arg_types=arg_types, owner_type=owner_type) ctx = converter.EntityContext(namer, entity_info, program_ctx) - - node = qual_names.resolve(node) - node = activity.resolve(node, entity_info) - node = live_values.resolve(node, entity_info, {}) - if include_type_analysis: - node = type_info.resolve(node, entity_info) - node = live_values.resolve(node, entity_info, {}) - self.ctx = ctx - return node + node = converter.standard_analysis(node, ctx, is_initial=True) + return node, ctx |