aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-24 07:58:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 08:01:53 -0700
commitc1e050cc75c6ced7b68a2349a012b2e5a3d04538 (patch)
treef90a0fed33d2a052ace011b2fb293d463ec32234 /tensorflow/python/autograph
parent64498def97852cc359209576703c7b788ba839e9 (diff)
Rename source_map to create_source_map. Reorganize the tests to be clearer about the expected functionality.
PiperOrigin-RevId: 214266947
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/pyct/compiler.py2
-rw-r--r--tensorflow/python/autograph/pyct/origin_info.py2
-rw-r--r--tensorflow/python/autograph/pyct/origin_info_test.py59
3 files changed, 28 insertions, 35 deletions
diff --git a/tensorflow/python/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/compiler.py
index 9e1b6bdbe8..37f3e72f6e 100644
--- a/tensorflow/python/autograph/pyct/compiler.py
+++ b/tensorflow/python/autograph/pyct/compiler.py
@@ -108,7 +108,7 @@ def ast_to_object(nodes,
indices = (-1,)
if include_source_map:
- source_map = origin_info.source_map(nodes, source, f.name, indices)
+ source_map = origin_info.create_source_map(nodes, source, f.name, indices)
# TODO(mdan): Try flush() and delete=False instead.
if delete_on_exit:
diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
index 4c7c4165ef..102bd42c91 100644
--- a/tensorflow/python/autograph/pyct/origin_info.py
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -75,7 +75,7 @@ class OriginInfo(
# TODO(mdan): This source map should be a class - easier to refer to.
-def source_map(nodes, code, filename, indices_in_code):
+def create_source_map(nodes, code, filename, indices_in_code):
"""Creates a source map between an annotated AST and the code it compiles to.
Args:
diff --git a/tensorflow/python/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py
index 6b9c30dbd0..3b1d5f2040 100644
--- a/tensorflow/python/autograph/pyct/origin_info_test.py
+++ b/tensorflow/python/autograph/pyct/origin_info_test.py
@@ -27,49 +27,41 @@ from tensorflow.python.platform import test
class OriginInfoTest(test.TestCase):
- def test_source_map(self):
+ def test_create_source_map(self):
def test_fn(x):
- if x > 0:
- x += 1
- return x
-
- node, source = parser.parse_entity(test_fn)
+ return x + 1
+
+ node, _ = parser.parse_entity(test_fn)
+ fake_origin = origin_info.OriginInfo(
+ loc=origin_info.Location('fake_filename', 3, 7),
+ function_name='fake_function_name',
+ source_code_line='fake source line',
+ comment=None)
fn_node = node.body[0]
- origin_info.resolve(fn_node, source)
-
- # Insert a traced line.
- new_node = parser.parse_str('x = abs(x)').body[0]
- anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
- fn_node.body.insert(0, new_node)
+ anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin)
+ converted_code = compiler.ast_to_source(fn_node)
- # Insert an untraced line.
- fn_node.body.insert(0, parser.parse_str('x = 0').body[0])
+ source_map = origin_info.create_source_map(
+ fn_node, converted_code, 'test_filename', [0])
- modified_source = compiler.ast_to_source(fn_node)
+ loc = origin_info.LineLocation('test_filename', 2)
+ self.assertIn(loc, source_map)
+ self.assertIs(source_map[loc], fake_origin)
- source_map = origin_info.source_map(fn_node, modified_source,
- 'test_filename', [0])
+ def test_source_map_no_origin(self):
- loc = origin_info.LineLocation('test_filename', 1)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, 'def test_fn(x):')
- self.assertEqual(origin.loc.lineno, 1)
+ def test_fn(x):
+ return x + 1
- # The untraced line, inserted second.
- loc = origin_info.LineLocation('test_filename', 2)
- self.assertFalse(loc in source_map)
+ node, _ = parser.parse_entity(test_fn)
+ fn_node = node.body[0]
+ converted_code = compiler.ast_to_source(fn_node)
- # The traced line, inserted first.
- loc = origin_info.LineLocation('test_filename', 3)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, ' if x > 0:')
- self.assertEqual(origin.loc.lineno, 2)
+ source_map = origin_info.create_source_map(
+ fn_node, converted_code, 'test_filename', [0])
- loc = origin_info.LineLocation('test_filename', 4)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, ' if x > 0:')
- self.assertEqual(origin.loc.lineno, 2)
+ self.assertEqual(len(source_map), 0)
def test_resolve(self):
@@ -79,6 +71,7 @@ class OriginInfoTest(test.TestCase):
node, source = parser.parse_entity(test_fn)
fn_node = node.body[0]
+
origin_info.resolve(fn_node, source)
origin = anno.getanno(fn_node, anno.Basic.ORIGIN)