From c1e050cc75c6ced7b68a2349a012b2e5a3d04538 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Mon, 24 Sep 2018 07:58:00 -0700 Subject: Rename source_map to create_source_map. Reorganize the tests to be clearer about the expected functionality. PiperOrigin-RevId: 214266947 --- tensorflow/python/autograph/pyct/compiler.py | 2 +- tensorflow/python/autograph/pyct/origin_info.py | 2 +- .../python/autograph/pyct/origin_info_test.py | 59 ++++++++++------------ 3 files changed, 28 insertions(+), 35 deletions(-) (limited to 'tensorflow/python/autograph') diff --git a/tensorflow/python/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/compiler.py index 9e1b6bdbe8..37f3e72f6e 100644 --- a/tensorflow/python/autograph/pyct/compiler.py +++ b/tensorflow/python/autograph/pyct/compiler.py @@ -108,7 +108,7 @@ def ast_to_object(nodes, indices = (-1,) if include_source_map: - source_map = origin_info.source_map(nodes, source, f.name, indices) + source_map = origin_info.create_source_map(nodes, source, f.name, indices) # TODO(mdan): Try flush() and delete=False instead. if delete_on_exit: diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py index 4c7c4165ef..102bd42c91 100644 --- a/tensorflow/python/autograph/pyct/origin_info.py +++ b/tensorflow/python/autograph/pyct/origin_info.py @@ -75,7 +75,7 @@ class OriginInfo( # TODO(mdan): This source map should be a class - easier to refer to. -def source_map(nodes, code, filename, indices_in_code): +def create_source_map(nodes, code, filename, indices_in_code): """Creates a source map between an annotated AST and the code it compiles to. Args: diff --git a/tensorflow/python/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py index 6b9c30dbd0..3b1d5f2040 100644 --- a/tensorflow/python/autograph/pyct/origin_info_test.py +++ b/tensorflow/python/autograph/pyct/origin_info_test.py @@ -27,49 +27,41 @@ from tensorflow.python.platform import test class OriginInfoTest(test.TestCase): - def test_source_map(self): + def test_create_source_map(self): def test_fn(x): - if x > 0: - x += 1 - return x - - node, source = parser.parse_entity(test_fn) + return x + 1 + + node, _ = parser.parse_entity(test_fn) + fake_origin = origin_info.OriginInfo( + loc=origin_info.Location('fake_filename', 3, 7), + function_name='fake_function_name', + source_code_line='fake source line', + comment=None) fn_node = node.body[0] - origin_info.resolve(fn_node, source) - - # Insert a traced line. - new_node = parser.parse_str('x = abs(x)').body[0] - anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN) - fn_node.body.insert(0, new_node) + anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin) + converted_code = compiler.ast_to_source(fn_node) - # Insert an untraced line. - fn_node.body.insert(0, parser.parse_str('x = 0').body[0]) + source_map = origin_info.create_source_map( + fn_node, converted_code, 'test_filename', [0]) - modified_source = compiler.ast_to_source(fn_node) + loc = origin_info.LineLocation('test_filename', 2) + self.assertIn(loc, source_map) + self.assertIs(source_map[loc], fake_origin) - source_map = origin_info.source_map(fn_node, modified_source, - 'test_filename', [0]) + def test_source_map_no_origin(self): - loc = origin_info.LineLocation('test_filename', 1) - origin = source_map[loc] - self.assertEqual(origin.source_code_line, 'def test_fn(x):') - self.assertEqual(origin.loc.lineno, 1) + def test_fn(x): + return x + 1 - # The untraced line, inserted second. - loc = origin_info.LineLocation('test_filename', 2) - self.assertFalse(loc in source_map) + node, _ = parser.parse_entity(test_fn) + fn_node = node.body[0] + converted_code = compiler.ast_to_source(fn_node) - # The traced line, inserted first. - loc = origin_info.LineLocation('test_filename', 3) - origin = source_map[loc] - self.assertEqual(origin.source_code_line, ' if x > 0:') - self.assertEqual(origin.loc.lineno, 2) + source_map = origin_info.create_source_map( + fn_node, converted_code, 'test_filename', [0]) - loc = origin_info.LineLocation('test_filename', 4) - origin = source_map[loc] - self.assertEqual(origin.source_code_line, ' if x > 0:') - self.assertEqual(origin.loc.lineno, 2) + self.assertEqual(len(source_map), 0) def test_resolve(self): @@ -79,6 +71,7 @@ class OriginInfoTest(test.TestCase): node, source = parser.parse_entity(test_fn) fn_node = node.body[0] + origin_info.resolve(fn_node, source) origin = anno.getanno(fn_node, anno.Basic.ORIGIN) -- cgit v1.2.3