From b24037513f12a5812a21b7ea92ff904ee9ea6cd8 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Wed, 25 Jul 2018 18:01:50 -0700 Subject: Fix bug in parallel_walk, along with a few structural bugs that this fix revealed: 1. The conversion process was inconsistently packaging the final output into modules or lists. This CL uniformly uses a list of nodes as output from all *_to_graph functions. As a side effect, converter_testing.py asserts that the output is always a single node and extracts it, so there is no need for tests to unpack it any more. Modify the compiler to skip generating a source map by default. 2. The class converter was incorrectly saving the superclass value to the string 'object' instead of the symbol `object`. Additional refactoring that was caught along: Simplify the source mapping code, move it to origin_info.py, add tests and additional checks. Slightly simplify the error rewriting mechanism. PiperOrigin-RevId: 206087110 --- .../contrib/autograph/converters/asserts_test.py | 2 +- .../autograph/converters/directives_test.py | 6 +- .../autograph/converters/error_handlers_test.py | 6 +- .../contrib/autograph/converters/lists_test.py | 4 +- .../converters/side_effect_guards_test.py | 12 +- .../contrib/autograph/converters/slices_test.py | 6 +- .../contrib/autograph/core/converter_testing.py | 4 +- tensorflow/contrib/autograph/core/errors.py | 116 ++++++++--------- tensorflow/contrib/autograph/core/errors_test.py | 108 ++++++++-------- tensorflow/contrib/autograph/impl/api.py | 21 ++-- tensorflow/contrib/autograph/impl/conversion.py | 34 ++--- .../contrib/autograph/impl/conversion_test.py | 29 ++--- tensorflow/contrib/autograph/pyct/BUILD | 10 ++ tensorflow/contrib/autograph/pyct/ast_util.py | 87 ++++++++----- tensorflow/contrib/autograph/pyct/ast_util_test.py | 62 ++++++---- tensorflow/contrib/autograph/pyct/cfg.py | 6 +- .../autograph/pyct/common_transformers/anf_test.py | 2 +- tensorflow/contrib/autograph/pyct/compiler.py | 127 ++++++++----------- tensorflow/contrib/autograph/pyct/compiler_test.py | 2 +- tensorflow/contrib/autograph/pyct/origin_info.py | 137 ++++++++++++++++----- .../contrib/autograph/pyct/origin_info_test.py | 101 +++++++++++++++ tensorflow/contrib/autograph/pyct/parser.py | 1 + 22 files changed, 539 insertions(+), 344 deletions(-) create mode 100644 tensorflow/contrib/autograph/pyct/origin_info_test.py diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py index 9c58ae3acc..38faba45df 100644 --- a/tensorflow/contrib/autograph/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -35,7 +35,7 @@ class AssertsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = asserts.transform(node, ctx) - self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call)) + self.assertTrue(isinstance(node.body[0].value, gast.Call)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/contrib/autograph/converters/directives_test.py index 5f798a5b76..a573ba5850 100644 --- a/tensorflow/contrib/autograph/converters/directives_test.py +++ b/tensorflow/contrib/autograph/converters/directives_test.py @@ -38,7 +38,7 @@ class DirectivesTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {'directives': directives}) node = directives_converter.transform(node, ctx) - def_, = anno.getanno(node.body[0].body[0].targets[0], + def_, = anno.getanno(node.body[0].targets[0], anno.Static.DEFINITIONS) d = def_.directives[directives.set_element_type] self.assertEqual(d['dtype'].s, 'a') @@ -52,7 +52,7 @@ class DirectivesTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {'directives': directives}) node = directives_converter.transform(node, ctx) - def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS) + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) d = def_.directives[directives.set_element_type] self.assertEqual(d['dtype'].n, 1) self.assertEqual(d['shape'].n, 2) @@ -67,7 +67,7 @@ class DirectivesTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {'directives': directives}) node = directives_converter.transform(node, ctx) - d = anno.getanno(node.body[0].body[1], AgAnno.DIRECTIVES) + d = anno.getanno(node.body[1], AgAnno.DIRECTIVES) d = d[directives.set_loop_options] self.assertEqual(d['parallel_iterations'].n, 10) self.assertEqual(d['back_prop'].id, 'a') diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py index 878526c8b4..cd74e5f18f 100644 --- a/tensorflow/contrib/autograph/converters/error_handlers_test.py +++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py @@ -34,11 +34,13 @@ class ErrorHandlersTest(converter_testing.TestCase): raise ValueError() node, ctx = self.prepare(test_fn, {}) - anno.setanno(node.body[0], anno.Basic.ORIGIN, - origin_info.OriginInfo('test_path', None, None, None, None)) + anno.setanno(node, anno.Basic.ORIGIN, + origin_info.OriginInfo(None, None, None)) node = error_handlers.transform(node, ctx) with self.compiled(node, {}) as result: with self.assertRaises(errors.GraphConstructionError): + # Here we just assert that the handler works. Its correctness is + # verified by errors_test.py. result.test_fn() def test_no_origin_annotation(self): diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index f906918ac0..996e99ee61 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -79,7 +79,7 @@ class ListTest(converter_testing.TestCase): ns = {'special_functions': special_functions} node, ctx = self.prepare(test_fn, ns) - def_, = anno.getanno(node.body[0].body[0].targets[0], + def_, = anno.getanno(node.body[0].targets[0], anno.Static.ORIG_DEFINITIONS) def_.directives[directives.set_element_type] = { 'dtype': parser.parse_expression('tf.int32'), @@ -114,7 +114,7 @@ class ListTest(converter_testing.TestCase): return tf.stack(l) node, ctx = self.prepare(test_fn, {}) - def_, = anno.getanno(node.body[0].body[0].targets[0], + def_, = anno.getanno(node.body[0].targets[0], anno.Static.ORIG_DEFINITIONS) def_.directives[directives.set_element_type] = { 'dtype': parser.parse_expression('tf.int32') diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py index de1874321e..bee512abbc 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -43,7 +43,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) with self.compiled(node, {}, state_ops.assign) as result: with self.test_session() as sess: @@ -64,7 +64,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) with self.compiled(node, {}, state_ops.assign) as result: with self.test_session() as sess: @@ -84,7 +84,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) with self.compiled(node, {}, control_flow_ops.Assert) as result: with self.test_session() as sess: @@ -104,7 +104,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) with self.compiled(node, {}, state_ops.assign_add) as result: with self.test_session() as sess: @@ -125,7 +125,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body[0].body), 1) + self.assertEqual(len(node.body[0].body), 1) with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result: with self.test_session() as sess: @@ -147,7 +147,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) with self.compiled(node, {}, state_ops.assign, state_ops.assign_add) as result: diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py index 3c0f81e8bc..c822d53a4a 100644 --- a/tensorflow/contrib/autograph/converters/slices_test.py +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -38,7 +38,7 @@ class SliceTest(converter_testing.TestCase): return l[1] node, ctx = self.prepare(test_fn, {}) - def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS) + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) def_.directives[directives.set_element_type] = { 'dtype': parser.parse_expression('tf.int32') } @@ -59,11 +59,11 @@ class SliceTest(converter_testing.TestCase): return l[1] node, ctx = self.prepare(test_fn, {}) - def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS) + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) def_.directives[directives.set_element_type] = { 'dtype': parser.parse_expression('tf.int32') } - def_, = anno.getanno(node.body[0].body[0].body[0].targets[0], + def_, = anno.getanno(node.body[0].body[0].targets[0], anno.Static.DEFINITIONS) def_.directives[directives.set_element_type] = { 'dtype': parser.parse_expression('tf.float32') diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/contrib/autograph/core/converter_testing.py index 2025e32817..5ee2c3fffd 100644 --- a/tensorflow/contrib/autograph/core/converter_testing.py +++ b/tensorflow/contrib/autograph/core/converter_testing.py @@ -94,7 +94,8 @@ class TestCase(test.TestCase): return 7 try: - result, source = compiler.ast_to_object(node) + result, source = compiler.ast_to_object(node, include_source_map=True) + result.tf = self.make_fake_mod('fake_tf', *symbols) fake_ag = self.make_fake_mod('fake_ag', converted_call) fake_ag.__dict__.update(operators.__dict__) @@ -144,6 +145,7 @@ class TestCase(test.TestCase): recursive=True, autograph_decorators=()): node, source = parser.parse_entity(test_fn) + node = node.body[0] if namer is None: namer = FakeNamer() program_ctx = converter.ProgramContext( diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/contrib/autograph/core/errors.py index e58745337a..c219b372c1 100644 --- a/tensorflow/contrib/autograph/core/errors.py +++ b/tensorflow/contrib/autograph/core/errors.py @@ -31,11 +31,14 @@ import logging import sys import traceback -from tensorflow.contrib.autograph.pyct.origin_info import CodeLocation +from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.python.framework import errors_impl from tensorflow.python.util import tf_inspect +# TODO(mdan): Add a superclass common to all errors. + + class GraphConstructionError(Exception): """Error for graph construction errors from AutoGraph generated code.""" @@ -65,27 +68,35 @@ class TfRuntimeError(Exception): return message + ''.join(traceback.format_list(self.custom_traceback)) -def _rewrite_frame(source_map, cleaned_traceback, stack_frame_indices): - """Rewrites the stack frames at the given indices using the given source map. +def _rewrite_tb(source_map, tb, filter_function_name=None): + """Rewrites code references in a traceback. Args: - source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and - AG generated code. - cleaned_traceback: List[Tuple[text, text, text, text]], the current - traceback. - stack_frame_indices: Iterable[Int], frame indices to possibly rewrite if - there are matching source mapping keys. - + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping + locations to their origin + tb: List[Tuple[Text, Text, Text, Text]], consistent with + traceback.extract_tb + filter_function_name: Optional[Text], allows restricting restricts the + frames to rewrite to a particular function name Returns: - None + List[Tuple[Text, Text, Text, Text]], the rewritten traceback """ - for frame_index in stack_frame_indices: - # (file_path, line number, function name, code) - file_path, line_number, _, _ = cleaned_traceback[frame_index] - source_map_key = CodeLocation(file_path=file_path, line_number=line_number) - found_mapping = source_map_key in source_map - if found_mapping: - cleaned_traceback[frame_index] = source_map[source_map_key].as_frame() + new_tb = [] + for frame in tb: + filename, lineno, function_name, _ = frame + loc = origin_info.LineLocation(filename, lineno) + origin = source_map.get(loc) + # TODO(mdan): We shouldn't need the function name at all. + # filename + lineno should be sufficient, even if there are multiple source + # maps. + if origin is not None: + if filter_function_name == function_name or filter_function_name is None: + new_tb.append(origin.as_frame()) + else: + new_tb.append(frame) + else: + new_tb.append(frame) + return new_tb # TODO(znado): Make more robust to name changes in the rewriting logic. @@ -98,18 +109,20 @@ def _remove_rewrite_frames(tb): return cleaned_tb +# TODO(mdan): rename to raise_* def rewrite_graph_construction_error(source_map): """Rewrites errors raised by non-AG APIs inside AG generated code. - Meant to be called from the try/except block inside each AutoGraph generated - function. Only rewrites the traceback frames corresponding to the function - that this is called from. When we raise a GraphConstructionError at the end - it is then caught by calling functions, where they can be responsible for - rewriting their own frames. + This is called from the except handler inside an AutoGraph generated function + (that is, during exception handling). Only rewrites the frames corresponding + to the function that this is called from, so each function is responsible + to call this to have its own frames rewritten. + + This function always raises an error. Args: - source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and - AG generated code. + source_map: Dict[origin_info.Location, origin_info.OriginInfo], the source + map belonging to the calling function Raises: GraphConstructionError: The rewritten underlying error. @@ -120,31 +133,19 @@ def rewrite_graph_construction_error(source_map): assert original_error is not None try: _, _, _, func_name, _, _ = tf_inspect.stack()[1] - # The latest function call is added to the beginning of a traceback, but - # when rewriting the traceback of multiple function calls the previous - # functions' except blocks may have already rewritten their own frames so - # we want to copy over all of the previous frames. We may have rewritten - # previous frames only if the error is a GraphConstructionError. if isinstance(original_error, GraphConstructionError): + # TODO(mdan): This is incomplete. + # The error might have bubbled through a non-converted function. cleaned_traceback = traceback.extract_tb(e_traceback) previous_traceback = original_error.custom_traceback cleaned_traceback = [cleaned_traceback[0]] + previous_traceback else: cleaned_traceback = traceback.extract_tb(e_traceback) - cleaned_traceback = _remove_rewrite_frames(cleaned_traceback) - - current_frame_indices = [] - # This code is meant to be called from the try/except block that wraps a - # function body. Here we look for all frames that came from the function - # that this wraps, look for any matching line numbers in the source - # mapping, and then rewrite them if matches are found. - for fi, frame in enumerate(cleaned_traceback): - _, _, frame_func_name, _ = frame - if frame_func_name == func_name: - current_frame_indices.append(fi) - break - if current_frame_indices: - _rewrite_frame(source_map, cleaned_traceback, current_frame_indices) + + # Remove the frame corresponding to this function call. + cleaned_traceback = cleaned_traceback[1:] + + cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback, func_name) if isinstance(original_error, GraphConstructionError): original_error.custom_traceback = cleaned_traceback @@ -153,6 +154,7 @@ def rewrite_graph_construction_error(source_map): new_error = GraphConstructionError(original_error, cleaned_traceback) except Exception: logging.exception('Error while rewriting AutoGraph error:') + # TODO(mdan): Should reraise here, removing the top frame as well. raise original_error else: raise new_error @@ -161,18 +163,17 @@ def rewrite_graph_construction_error(source_map): del e_traceback +# TODO(mdan): This should be consistent with rewrite_graph_construction_error +# Both should either raise or return. def rewrite_tf_runtime_error(error, source_map): """Rewrites TensorFlow runtime errors raised by ops created in AG code. Args: - error: error_impl.OpError, an TensorFlow error that will have its traceback - rewritten. - source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and - AG generated code. + error: tf.OpError + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo] Returns: - A TfRuntimeError with a traceback rewritten according to the given - source mapping. + TfRuntimeError, the rewritten underlying error. """ # Check for cases where we leave a user method and re-enter it in the # traceback. This is done by looking at the function names when the @@ -198,15 +199,16 @@ def rewrite_tf_runtime_error(error, source_map): # The source map keys are (file_path, line_number) so get the set of all user # file_paths. try: - all_user_files = set(k.file_path for k in source_map) + all_user_files = set(loc.filename for loc in source_map) cleaned_traceback = [] last_user_frame_index = None last_user_user_file_path = None last_user_user_fn_name = None + # TODO(mdan): Simplify this logic. for fi, frame in enumerate(error.op.traceback): - frame_file_path, frame_line_number, _, _ = frame - src_map_key = CodeLocation( - file_path=frame_file_path, line_number=frame_line_number) + frame_file_path, lineno, _, _ = frame + lineno -= 1 # Frame line numbers are 1-based. + src_map_key = origin_info.LineLocation(frame_file_path, lineno) if frame_file_path in all_user_files: if src_map_key in source_map: original_fn_name = source_map[src_map_key].function_name @@ -223,8 +225,8 @@ def rewrite_tf_runtime_error(error, source_map): last_user_user_file_path = frame_file_path cleaned_traceback.append(frame) - for fi in range(len(cleaned_traceback)): - _rewrite_frame(source_map, cleaned_traceback, [fi]) + cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) + op_name = error.op.name op_message = error.message rewritten_error = TfRuntimeError(op_name, op_message, cleaned_traceback) @@ -263,7 +265,7 @@ def improved_errors(converted_function): ValueError: If converted_function is not generated by AutoGraph """ if (getattr(converted_function, 'ag_source_map', None) is None or - not converted_function.ag_source_map): + not isinstance(converted_function.ag_source_map, dict)): raise ValueError( 'converted_function must be the result of an autograph.to_graph call') try: diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/contrib/autograph/core/errors_test.py index 7be54563a1..c0e2c74e47 100644 --- a/tensorflow/contrib/autograph/core/errors_test.py +++ b/tensorflow/contrib/autograph/core/errors_test.py @@ -28,88 +28,76 @@ from tensorflow.python.util import tf_inspect def zero_div(): - return array_ops.constant(10, dtype=dtypes.int32) // 0 + x = array_ops.constant(10, dtype=dtypes.int32) + return x // 0 def zero_div_caller(): - a = zero_div() + 2 - return a + return zero_div() class RuntimeErrorsTest(test.TestCase): - def setUp(self): - self._fake_origin = origin_info.OriginInfo('new file', 'new func', 96, 0, - 'print("hello world!")') - - def test_error_replacement(self): - _, zero_div_lineno = tf_inspect.getsourcelines(zero_div) - src_map = { - errors.CodeLocation( - file_path=__file__, line_number=zero_div_lineno + 1): - self._fake_origin - } + def fake_origin(self, function, line_offset): + _, lineno = tf_inspect.getsourcelines(function) + filename = tf_inspect.getsourcefile(function) + lineno += line_offset + loc = origin_info.LineLocation(filename, lineno) + origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code') + return loc, origin + + def test_improved_errors_basic(self): + loc, origin = self.fake_origin(zero_div, 2) + zero_div_caller.ag_source_map = {loc: origin} + + ops = zero_div_caller() with self.assertRaises(errors.TfRuntimeError) as cm: - z = zero_div_caller() - zero_div_caller.ag_source_map = src_map with errors.improved_errors(zero_div_caller): with self.test_session() as sess: - sess.run(z) - expected = cm.exception - current_traceback = expected.custom_traceback - for frame in current_traceback: - self.assertNotEqual('zero_div', frame[2]) - self.assertTrue( - any(self._fake_origin.as_frame() == frame - for frame in current_traceback)) - - def test_error_not_found(self): - src_map = { - errors.CodeLocation(file_path=__file__, line_number=-1): - self._fake_origin - } + sess.run(ops) + + for frame in cm.exception.custom_traceback: + _, _, function_name, _ = frame + self.assertNotEqual('zero_div', function_name) + self.assertIn(origin.as_frame(), set(cm.exception.custom_traceback)) + + def test_improved_errors_no_matching_lineno(self): + loc, origin = self.fake_origin(zero_div, -1) + zero_div_caller.ag_source_map = {loc: origin} + + ops = zero_div_caller() with self.assertRaises(errors.TfRuntimeError) as cm: - z = zero_div_caller() - zero_div_caller.ag_source_map = src_map with errors.improved_errors(zero_div_caller): with self.test_session() as sess: - sess.run(z) - expected = cm.exception - current_traceback = expected.custom_traceback - self.assertTrue(any('zero_div' in frame[2] for frame in current_traceback)) - for frame in current_traceback: - self.assertNotEqual(frame, self._fake_origin.as_frame()) - - def test_rewriting_error(self): - _, zero_div_lineno = tf_inspect.getsourcelines(zero_div) - src_map = { - errors.CodeLocation( - file_path=__file__, line_number=zero_div_lineno + 1): - None - } - with self.assertRaisesRegexp(tf_errors.InvalidArgumentError, - 'Integer division by zero'): - z = zero_div_caller() - zero_div_caller.ag_source_map = src_map + sess.run(ops) + + all_function_names = set() + for frame in cm.exception.custom_traceback: + _, _, function_name, _ = frame + all_function_names.add(function_name) + self.assertNotEqual('test_function_name', function_name) + self.assertIn('zero_div', all_function_names) + + def test_improved_errors_failures(self): + loc, _ = self.fake_origin(zero_div, 2) + zero_div_caller.ag_source_map = {loc: 'bogus object'} + + ops = zero_div_caller() + with self.assertRaises(tf_errors.InvalidArgumentError): with errors.improved_errors(zero_div_caller): with self.test_session() as sess: - sess.run(z) + sess.run(ops) - def test_no_ag_source_map(self): + def test_improved_errors_validation(self): with self.assertRaisesRegexp( ValueError, 'converted_function must be the result of an autograph.to_graph call'): - with errors.improved_errors(None): - pass - - def test_bad_ag_source_map(self): + errors.improved_errors(zero_div).__enter__() with self.assertRaisesRegexp( ValueError, 'converted_function must be the result of an autograph.to_graph call'): - src_map = None - zero_div_caller.ag_source_map = src_map - with errors.improved_errors(None): - pass + zero_div_caller.ag_source_map = 'not a dict' + errors.improved_errors(zero_div_caller).__enter__() if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index f7fe3de5da..ee71f4f9ac 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -23,7 +23,6 @@ from functools import wraps from enum import Enum # pylint:disable=g-bad-import-order -import gast import six # pylint:enable=g-bad-import-order @@ -245,19 +244,21 @@ def to_graph(e, _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) - module = gast.Module([]) + nodes = [] for dep in reversed(program_ctx.dependency_cache.values()): - module.body.append(dep) - compiled_node, compiled_src = compiler.ast_to_object( - module, source_prefix=program_ctx.required_imports) + nodes.extend(dep) + compiled_module, compiled_src = compiler.ast_to_object( + nodes, + source_prefix=program_ctx.required_imports, + include_source_map=True) # The compiled code should see everything the entry entity saw. # TODO(mdan): This might not work well if the call tree spans modules? for key, val in namespace.items(): # Avoid overwriting entities that have been transformed. - if key not in compiled_node.__dict__: - compiled_node.__dict__[key] = val - compiled_fn = getattr(compiled_node, name) + if key not in compiled_module.__dict__: + compiled_module.__dict__[key] = val + compiled_fn = getattr(compiled_module, name) # Need this so the source_mapping attribute is available for the context # manager to access for runtime errors. @@ -270,7 +271,7 @@ def to_graph(e, '"%s", which is reserved for AutoGraph.' % (compiled_fn, source_map_attribute_name)) setattr(compiled_fn, source_map_attribute_name, - compiled_node.__dict__['ag_source_map__']) + compiled_module.__dict__['ag_source_map__']) if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) @@ -308,7 +309,7 @@ def to_code(e, conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) code = '\n'.join( - compiler.ast_to_source(dep, indentation)[0] + compiler.ast_to_source(dep, indentation) for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache)))) return program_ctx.required_imports + '\n\n' + code diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 7bd0ba3f2d..57ec739a80 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -164,7 +164,7 @@ def class_to_graph(c, program_ctx): class_namespace = namespace else: class_namespace.update(namespace) - converted_members[m] = node + converted_members[m] = node[0] namer = program_ctx.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) @@ -175,10 +175,10 @@ def class_to_graph(c, program_ctx): # program_ctx.update_name_map(namer)). output_nodes = [] renames = {} - bases = [] + base_names = [] for base in c.__bases__: if isinstance(object, base): - bases.append('object') + base_names.append('object') continue if is_whitelisted_for_graph(base): alias = namer.new_symbol(base.__name__, ()) @@ -190,28 +190,28 @@ def class_to_graph(c, program_ctx): else: # This will trigger a conversion into a class with this name. alias = namer.compiled_class_name(base.__name__, base) - bases.append(alias) + base_names.append(alias) renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) program_ctx.update_name_map(namer) # Generate the definition of the converted class. - output_nodes.append( - gast.ClassDef( - class_name, - bases=bases, - keywords=[], - body=list(converted_members.values()), - decorator_list=[])) - node = gast.Module(output_nodes) - + bases = [gast.Name(n, gast.Load(), None) for n in base_names] + class_def = gast.ClassDef( + class_name, + bases=bases, + keywords=[], + body=list(converted_members.values()), + decorator_list=[]) # Make a final pass to replace references to the class or its base classes. # Most commonly, this occurs when making super().__init__() calls. # TODO(mdan): Making direct references to superclass' superclass will fail. - node = qual_names.resolve(node) + class_def = qual_names.resolve(class_def) renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) - node = ast_util.rename_symbols(node, renames) + class_def = ast_util.rename_symbols(class_def, renames) + + output_nodes.append(class_def) - return node, class_name, class_namespace + return output_nodes, class_name, class_namespace def _add_reserved_symbol(namespace, name, entity): @@ -279,7 +279,7 @@ def function_to_graph(f, program_ctx.update_name_map(namer) # TODO(mdan): Use this at compilation. - return node, new_name, namespace + return (node,), new_name, namespace def node_to_graph(node, context, rewrite_errors=True): diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index 207225a1ac..bfc51365a3 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -60,10 +60,11 @@ class ConversionTest(test.TestCase): return a + b program_ctx = self._simple_program_ctx() - ast, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) - self.assertTrue(isinstance(ast, gast.FunctionDef), ast) + nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) + fn_node, = nodes + self.assertIsInstance(fn_node, gast.FunctionDef) self.assertEqual('tf__f', name) - self.assertTrue(ns['b'] is b) + self.assertIs(ns['b'], b) def test_entity_to_graph_call_tree(self): @@ -78,14 +79,11 @@ class ConversionTest(test.TestCase): self.assertTrue(f in program_ctx.dependency_cache) self.assertTrue(g in program_ctx.dependency_cache) - self.assertEqual('tf__f', program_ctx.dependency_cache[f].name) - # need one extra .body[0] in order to step past the try/except wrapper that - # is added automatically, the other for the with tf.name_scope('f') that is - # added automatically - self.assertEqual( - 'tf__g', - program_ctx.dependency_cache[f].body[0].body[0].body[0].value.func.id) - self.assertEqual('tf__g', program_ctx.dependency_cache[g].name) + f_node = program_ctx.dependency_cache[f][0] + g_node = program_ctx.dependency_cache[g][0] + self.assertEqual('tf__f', f_node.name) + self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id) + self.assertEqual('tf__g', g_node.name) def test_entity_to_graph_class_hierarchy(self): @@ -118,9 +116,9 @@ class ConversionTest(test.TestCase): self.assertTrue(TestBase in program_ctx.dependency_cache) self.assertTrue(TestSubclass in program_ctx.dependency_cache) self.assertEqual('TfTestBase', - program_ctx.dependency_cache[TestBase].body[-1].name) + program_ctx.dependency_cache[TestBase][-1].name) self.assertEqual('TfTestSubclass', - program_ctx.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestSubclass][-1].name) def test_entity_to_graph_class_hierarchy_whitelisted(self): @@ -139,10 +137,9 @@ class ConversionTest(test.TestCase): self.assertTrue(TestSubclass in program_ctx.dependency_cache) self.assertFalse(training.Model in program_ctx.dependency_cache) self.assertEqual( - 'Model', - program_ctx.dependency_cache[TestSubclass].body[0].names[0].name) + 'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name) self.assertEqual('TfTestSubclass', - program_ctx.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestSubclass][-1].name) def test_entity_to_graph_lambda(self): f = lambda a: a diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index f77a6ab392..ddadc6b96e 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -99,6 +99,16 @@ py_test( ], ) +py_test( + name = "origin_info_test", + srcs = ["origin_info_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "parser_test", srcs = ["parser_test.py"], diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py index 86e3f56a64..d7453b0781 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -20,7 +20,6 @@ from __future__ import print_function import ast -import collections import gast from tensorflow.contrib.autograph.pyct import anno @@ -185,6 +184,7 @@ class PatternMatcher(gast.NodeVisitor): if v != p: return self.no_match() + def matches(node, pattern): """Basic pattern matcher for AST. @@ -253,30 +253,61 @@ def apply_to_single_assignments(targets, values, apply_fn): apply_fn(target, values) -def iter_fields(node): - for field in sorted(node._fields): - try: - yield getattr(node, field) - except AttributeError: - pass - - -def iter_child_nodes(node): - for field in iter_fields(node): - if isinstance(field, gast.AST): - yield field - elif isinstance(field, list): - for item in field: - if isinstance(item, gast.AST): - yield item - - -def parallel_walk(node_a, node_b): - todo_a = collections.deque([node_a]) - todo_b = collections.deque([node_b]) - while todo_a and todo_b: - node_a = todo_a.popleft() - node_b = todo_b.popleft() - todo_a.extend(iter_child_nodes(node_a)) - todo_b.extend(iter_child_nodes(node_b)) - yield node_a, node_b +def parallel_walk(node, other): + """Walks two ASTs in parallel. + + The two trees must have identical structure. + + Args: + node: Union[ast.AST, Iterable[ast.AST]] + other: Union[ast.AST, Iterable[ast.AST]] + Yields: + Tuple[ast.AST, ast.AST] + Raises: + ValueError: if the two trees don't have identical structure. + """ + if isinstance(node, (list, tuple)): + node_stack = list(node) + else: + node_stack = [node] + + if isinstance(other, (list, tuple)): + other_stack = list(other) + else: + other_stack = [other] + + while node_stack and other_stack: + assert len(node_stack) == len(other_stack) + n = node_stack.pop() + o = other_stack.pop() + + if (not isinstance(n, (ast.AST, gast.AST)) or + not isinstance(o, (ast.AST, gast.AST)) or + n.__class__.__name__ != o.__class__.__name__): + raise ValueError('inconsistent nodes: {} and {}'.format(n, o)) + + yield n, o + + for f in n._fields: + n_child = getattr(n, f, None) + o_child = getattr(o, f, None) + if f.startswith('__') or n_child is None or o_child is None: + continue + + if isinstance(n_child, (list, tuple)): + if (not isinstance(o_child, (list, tuple)) or + len(n_child) != len(o_child)): + raise ValueError( + 'inconsistent values for field {}: {} and {}'.format( + f, n_child, o_child)) + node_stack.extend(n_child) + other_stack.extend(o_child) + + elif isinstance(n_child, (gast.AST, ast.AST)): + node_stack.append(n_child) + other_stack.append(o_child) + + elif n_child != o_child: + raise ValueError( + 'inconsistent values for field {}: {} and {}'.format( + f, n_child, o_child)) diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py index 981e398b93..2293c89720 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util_test.py +++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py @@ -44,7 +44,7 @@ class AstUtilTest(test.TestCase): node, {qual_names.QN('a'): qual_names.QN('renamed_a')}) self.assertIsInstance(node.body[0].value.left.id, str) - source, _ = compiler.ast_to_source(node) + source = compiler.ast_to_source(node) self.assertEqual(source.strip(), 'renamed_a + b') def test_rename_symbols_attributes(self): @@ -54,7 +54,7 @@ class AstUtilTest(test.TestCase): node = ast_util.rename_symbols( node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}) - source, _ = compiler.ast_to_source(node) + source = compiler.ast_to_source(node) self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d') def test_rename_symbols_annotations(self): @@ -97,10 +97,10 @@ class AstUtilTest(test.TestCase): d = ast_util.keywords_to_dict(keywords) # Make sure we generate a usable dict node by attaching it to a variable and # compiling everything. - output = parser.parse_str('b = 3') - output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),) - result, _ = compiler.ast_to_object(output) - self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'}) + node = parser.parse_str('def f(b): pass').body[0] + node.body.append(ast.Return(d)) + result, _ = compiler.ast_to_object(node) + self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'}) def assertMatch(self, target_str, pattern_str): node = parser.parse_expression(target_str) @@ -130,8 +130,8 @@ class AstUtilTest(test.TestCase): 'super(Bar, _).__init__(_)') def _mock_apply_fn(self, target, source): - target, _ = compiler.ast_to_source(target) - source, _ = compiler.ast_to_source(source) + target = compiler.ast_to_source(target) + source = compiler.ast_to_source(source) self._invocation_counts[(target.strip(), source.strip())] += 1 def test_apply_to_single_assignments_dynamic_unpack(self): @@ -157,24 +157,40 @@ class AstUtilTest(test.TestCase): }) def test_parallel_walk(self): - ret = ast.Return( - ast.BinOp( - op=ast.Add(), - left=ast.Name(id='a', ctx=ast.Load()), - right=ast.Num(1))) - node = ast.FunctionDef( - name='f', - args=ast.arguments( - args=[ast.Name(id='a', ctx=ast.Param())], - vararg=None, - kwarg=None, - defaults=[]), - body=[ret], - decorator_list=[], - returns=None) + node = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) for child_a, child_b in ast_util.parallel_walk(node, node): self.assertEqual(child_a, child_b) + def test_parallel_walk_inconsistent_trees(self): + node_1 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) + node_2 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + (a * 2) + """)) + node_3 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 2 + """)) + with self.assertRaises(ValueError): + for _ in ast_util.parallel_walk(node_1, node_2): + pass + # There is not particular reason to reject trees that differ only in the + # value of a constant. + # TODO(mdan): This should probably be allowed. + with self.assertRaises(ValueError): + for _ in ast_util.parallel_walk(node_1, node_3): + pass + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py index 25fec7fd53..ba51dcf285 100644 --- a/tensorflow/contrib/autograph/pyct/cfg.py +++ b/tensorflow/contrib/autograph/pyct/cfg.py @@ -67,10 +67,8 @@ class Node(object): if isinstance(self.ast_node, gast.FunctionDef): return 'def %s' % self.ast_node.name elif isinstance(self.ast_node, gast.withitem): - source, _ = compiler.ast_to_source(self.ast_node.context_expr) - return source.strip() - source, _ = compiler.ast_to_source(self.ast_node) - return source.strip() + return compiler.ast_to_source(self.ast_node.context_expr).strip() + return compiler.ast_to_source(self.ast_node).strip() class Graph( diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py index 81983a5ecb..aefbc69d8c 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py @@ -43,7 +43,7 @@ class AnfTransformerTest(test.TestCase): return a node, _ = parser.parse_entity(test_function) - node = anf.transform(node, self._simple_source_info()) + node = anf.transform(node.body[0], self._simple_source_info()) result, _ = compiler.ast_to_object(node) self.assertEqual(test_function(), result.test_function()) diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py index c90a5e89c2..f9cee10962 100644 --- a/tensorflow/contrib/autograph/pyct/compiler.py +++ b/tensorflow/contrib/autograph/pyct/compiler.py @@ -30,44 +30,7 @@ import tempfile import astor import gast -from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import origin_info -from tensorflow.contrib.autograph.pyct import parser - - -def _build_source_map(node, code): - """Return the Python objects represented by given AST. - - Compiling the AST code this way ensures that the source code is readable by - e.g. `pdb` or `inspect`. - - Args: - node: An AST node of the original generated code, before the source code is - generated. - code: The string representation of the source code for the newly generated - code. - - Returns: - Dict[CodeLocation, OriginInfo], a mapping between the user and AutoGraph - generated code. - """ - # After we have the final generated code we reparse it to get the final line - # numbers. Then we walk through the generated and original ASTs in parallel - # to build the mapping between the user and generated code. - new_node = parser.parse_str(code) - origin_info.resolve(new_node, code) - source_mapping = {} - for before, after in ast_util.parallel_walk(node, new_node): - # Need both checks because if origin information is ever copied over to new - # nodes then we need to rely on the fact that only the original user code - # has the origin annotation. - if (anno.hasanno(before, anno.Basic.ORIGIN) and - anno.hasanno(after, anno.Basic.ORIGIN)): - source_info = anno.getanno(before, anno.Basic.ORIGIN) - new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number - source_mapping[new_line_number] = source_info - return source_mapping def ast_to_source(node, indentation=' '): @@ -81,24 +44,28 @@ def ast_to_source(node, indentation=' '): code: The source code generated from the AST object source_mapping: A mapping between the user and AutoGraph generated code. """ - original_node = node - if isinstance(node, gast.AST): - node = gast.gast_to_ast(node) + if not isinstance(node, (list, tuple)): + node = (node,) generator = astor.codegen.SourceGenerator(indentation, False, astor.string_repr.pretty_string) - generator.visit(node) - generator.result.append('\n') + + for n in node: + if isinstance(n, gast.AST): + n = gast.gast_to_ast(n) + generator.visit(n) + generator.result.append('\n') + # In some versions of Python, literals may appear as actual values. This # ensures everything is string. code = map(str, generator.result) code = astor.source_repr.pretty_source(code).lstrip() - source_mapping = _build_source_map(original_node, code) - return code, source_mapping + return code -def ast_to_object(node, +def ast_to_object(nodes, indentation=' ', + include_source_map=False, source_prefix=None, delete_on_exit=True): """Return the Python objects represented by given AST. @@ -107,42 +74,46 @@ def ast_to_object(node, e.g. `pdb` or `inspect`. Args: - node: The code to compile, as an AST object. - indentation: The string to use for indentation. - source_prefix: Optional string to print as-is into the source file. - delete_on_exit: Whether to delete the temporary file used for compilation on - exit. + nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST + object. + indentation: Text, the string to use for indentation. + include_source_map: bool, whether to attach a source map to the compiled + object. Also see origin_info.py. + source_prefix: Optional[Text], string to print as-is into the source file. + delete_on_exit: bool, whether to delete the temporary file used for + compilation on exit. Returns: - compiled_node: A module object containing the compiled source code. + compiled_nodes: A module object containing the compiled source code. source: The source code of the compiled object Raises: ValueError: If ag_source_map__ is already in the namespace of the compiled - node. + nodes. """ - # code_source_mapping does not yet include the offsets from import statements. - source, code_source_mapping = ast_to_source(node, indentation=indentation) + if not isinstance(nodes, (list, tuple)): + nodes = (nodes,) + + source = ast_to_source(nodes, indentation=indentation) + + if source_prefix: + source = source_prefix + '\n' + source with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: - # TODO(znado): move into an _offset_source_map() helper function. - # Need to offset the generated line numbers by the number of import lines. - if source_prefix: - num_import_lines = source_prefix.count('\n') + 1 - else: - num_import_lines = 0 - source_mapping = {} - for line_number, original_position in code_source_mapping.items(): - source_map_key = origin_info.CodeLocation( - file_path=f.name, line_number=line_number + num_import_lines) - source_mapping[source_map_key] = original_position module_name = os.path.basename(f.name[:-3]) - if source_prefix: - f.write(source_prefix) - f.write('\n') f.write(source) + + if isinstance(nodes, (list, tuple)): + indices = range(-len(nodes), 0) + else: + indices = (-1,) + + if include_source_map: + source_map = origin_info.source_map(nodes, source, f.name, indices) + + # TODO(mdan): Try flush() and delete=False instead. if delete_on_exit: atexit.register(lambda: os.remove(f.name)) - compiled_node = imp.load_source(module_name, f.name) + compiled_nodes = imp.load_source(module_name, f.name) # TODO(znado): Clean this up so we don't need to attach it to the namespace. # TODO(znado): This does not work for classes because their methods share a @@ -158,11 +129,13 @@ def ast_to_object(node, # is hard, and this cleanly fixes the # issues encountered with nested functions because this is attached to the # outermost one. - source_map_name = 'ag_source_map__' - if source_map_name in compiled_node.__dict__: - raise ValueError('cannot convert %s because is has namespace attribute ' - '"%s", which is reserved for AutoGraph.' % - (compiled_node, source_map_name)) - compiled_node.__dict__[source_map_name] = source_mapping - - return compiled_node, source + if include_source_map: + # TODO(mdan): This name should be decided by the caller. + source_map_name = 'ag_source_map__' + if source_map_name in compiled_nodes.__dict__: + raise ValueError('cannot convert %s because is has namespace attribute ' + '"%s", which is reserved for AutoGraph.' % + (compiled_nodes, source_map_name)) + compiled_nodes.__dict__[source_map_name] = source_map + + return compiled_nodes, source diff --git a/tensorflow/contrib/autograph/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py index e29fa9324c..cf783da6a3 100644 --- a/tensorflow/contrib/autograph/pyct/compiler_test.py +++ b/tensorflow/contrib/autograph/pyct/compiler_test.py @@ -59,7 +59,7 @@ class CompilerTest(test.TestCase): value=gast.Str('c')) ]) - source, _ = compiler.ast_to_source(node, indentation=' ') + source = compiler.ast_to_source(node, indentation=' ') self.assertEqual( textwrap.dedent(""" if 1: diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py index 614e346634..1aad2f47df 100644 --- a/tensorflow/contrib/autograph/pyct/origin_info.py +++ b/tensorflow/contrib/autograph/pyct/origin_info.py @@ -22,49 +22,115 @@ import collections import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.util import tf_inspect -class CodeLocation( - collections.namedtuple('CodeLocation', ('file_path', 'line_number'))): - """Location of a line of code. +class LineLocation( + collections.namedtuple('LineLocation', ('filename', 'lineno'))): + """Similar to Location, but without column information. Attributes: - file_path: text, the full path to the file containing the code. - line_number: Int, the 1-based line number of the code in its file. + filename: Text + lineno: int, 1-based """ pass +class Location( + collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))): + """Encodes code location information. + + Attributes: + filename: Text + lineno: int, 1-based + col_offset: int + """ + + @property + def line_loc(self): + return LineLocation(self.filename, self.lineno) + + class OriginInfo( - collections.namedtuple('OriginInfo', - ('file_path', 'function_name', 'line_number', - 'column_offset', 'source_code_line'))): + collections.namedtuple( + 'OriginInfo', + ('loc', 'function_name', 'source_code_line'))): """Container for information about the source code before conversion. - Instances of this class contain information about the source code that - transformed code originated from. Examples include: - * line number - * file name - * original user code + Attributes: + loc: Location + function_name: Optional[Text] + source_code_line: Text """ def as_frame(self): - """Makes a traceback frame tuple. - - Returns: - A tuple of (file_path, line_number, function_name, source_code_line). - """ - return (self.file_path, self.line_number, self.function_name, + """Returns a 4-tuple consistent with the return of traceback.extract_tb.""" + return (self.loc.filename, self.loc.lineno, self.function_name, self.source_code_line) +# TODO(mdan): This source map should be a class - easier to refer to. +def source_map(nodes, code, filename, indices_in_code): + """Creates a source map between an annotated AST and the code it compiles to. + + Args: + nodes: Iterable[ast.AST, ...] + code: Text + filename: Optional[Text] + indices_in_code: Union[int, Iterable[int, ...]], the positions at which + nodes appear in code. The parser always returns a module when parsing + code. This argument indicates the position in that module's body at + which the corresponding of node should appear. + + Returns: + Dict[CodeLocation, OriginInfo], mapping locations in code to locations + indicated by origin annotations in node. + """ + reparsed_nodes = parser.parse_str(code) + reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code] + + resolve(reparsed_nodes, code) + result = {} + + for before, after in ast_util.parallel_walk(nodes, reparsed_nodes): + # Note: generated code might not be mapped back to its origin. + # TODO(mdan): Generated code should always be mapped to something. + origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None) + final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None) + if origin_info is None or final_info is None: + continue + + line_loc = LineLocation(filename, final_info.loc.lineno) + + existing_origin = result.get(line_loc) + if existing_origin is not None: + # Overlaps may exist because of child nodes, but almost never to + # different line locations. Exception make decorated functions, where + # both lines are mapped to the same line in the AST. + + # Line overlaps: keep bottom node. + if existing_origin.loc.line_loc == origin_info.loc.line_loc: + if existing_origin.loc.lineno >= origin_info.loc.lineno: + continue + + # In case of overlaps, keep the leftmost node. + if existing_origin.loc.col_offset <= origin_info.loc.col_offset: + continue + + result[line_loc] = origin_info + + return result + + # TODO(znado): Consider refactoring this into a Visitor. -def resolve(node, source, function=None): +# TODO(mdan): Does this work correctly with inner functions? +def resolve(nodes, source, function=None): """Adds an origin information to all nodes inside the body of function. Args: - node: The AST node for the function whose body nodes will be annotated. + nodes: Union[ast.AST, Iterable[ast.AST, ...]] source: Text, the source code string for the function whose body nodes will be annotated. function: Callable, the function that will have all nodes inside of it @@ -76,25 +142,32 @@ def resolve(node, source, function=None): A tuple of the AST node for function and a String containing its source code. """ + if not isinstance(nodes, (list, tuple)): + nodes = (nodes,) + if function: _, function_lineno = tf_inspect.getsourcelines(function) function_filepath = tf_inspect.getsourcefile(function) else: function_lineno = None function_filepath = None + source_lines = source.split('\n') - for n in gast.walk(node): - if hasattr(n, 'lineno'): - # n.lineno is relative to the start of the enclosing function, so need to - # offset it by the line of the function. - source_code_line = source_lines[n.lineno - 1] + for node in nodes: + for n in gast.walk(node): + if not hasattr(n, 'lineno'): + continue + + lineno_in_body = n.lineno + + source_code_line = source_lines[lineno_in_body - 1] if function: - source_lineno = n.lineno + function_lineno - 1 + source_lineno = function_lineno + lineno_in_body function_name = function.__name__ else: - source_lineno = n.lineno + source_lineno = lineno_in_body function_name = None - anno.setanno( - n, anno.Basic.ORIGIN, - OriginInfo(function_filepath, function_name, source_lineno, - n.col_offset, source_code_line)) + + location = Location(function_filepath, source_lineno, n.col_offset) + origin = OriginInfo(location, function_name, source_code_line) + anno.setanno(n, anno.Basic.ORIGIN, origin) diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/contrib/autograph/pyct/origin_info_test.py new file mode 100644 index 0000000000..6d7d8b1622 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/origin_info_test.py @@ -0,0 +1,101 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for origin_info module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class OriginInfoTest(test.TestCase): + + def test_source_map(self): + + def test_fn(x): + if x > 0: + x += 1 + return x + + node, source = parser.parse_entity(test_fn) + fn_node = node.body[0] + origin_info.resolve(fn_node, source) + + # Insert a traced line. + new_node = parser.parse_str('x = abs(x)').body[0] + anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN) + fn_node.body.insert(0, new_node) + + # Insert an untraced line. + fn_node.body.insert(0, parser.parse_str('x = 0').body[0]) + + modified_source = compiler.ast_to_source(fn_node) + + source_map = origin_info.source_map(fn_node, modified_source, + 'test_filename', [0]) + + loc = origin_info.LineLocation('test_filename', 1) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, 'def test_fn(x):') + self.assertEqual(origin.loc.lineno, 1) + + # The untraced line, inserted second. + loc = origin_info.LineLocation('test_filename', 2) + self.assertFalse(loc in source_map) + + # The traced line, inserted first. + loc = origin_info.LineLocation('test_filename', 3) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, ' if x > 0:') + self.assertEqual(origin.loc.lineno, 2) + + loc = origin_info.LineLocation('test_filename', 4) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, ' if x > 0:') + self.assertEqual(origin.loc.lineno, 2) + + def test_resolve(self): + + def test_fn(x): + """Docstring.""" + return x # comment + + node, source = parser.parse_entity(test_fn) + fn_node = node.body[0] + origin_info.resolve(fn_node, source) + + origin = anno.getanno(fn_node, anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 1) + self.assertEqual(origin.loc.col_offset, 0) + self.assertEqual(origin.source_code_line, 'def test_fn(x):') + + origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 2) + self.assertEqual(origin.loc.col_offset, 2) + self.assertEqual(origin.source_code_line, ' """Docstring."""') + + origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 3) + self.assertEqual(origin.loc.col_offset, 2) + self.assertEqual(origin.source_code_line, ' return x # comment') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/parser.py b/tensorflow/contrib/autograph/pyct/parser.py index c961efa892..112ed46a1e 100644 --- a/tensorflow/contrib/autograph/pyct/parser.py +++ b/tensorflow/contrib/autograph/pyct/parser.py @@ -37,6 +37,7 @@ def parse_entity(entity): def parse_str(src): """Returns the AST of given piece of code.""" + # TODO(mdan): This should exclude the module things are autowrapped in. return gast.parse(src) -- cgit v1.2.3