From 1665eadb04da2446e5a14d4e5f8947f0eeab8215 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Fri, 3 Aug 2018 14:31:21 -0700 Subject: Include same-line comments in origin_info. PiperOrigin-RevId: 207325109 --- .../autograph/converters/error_handlers_test.py | 6 ++++-- tensorflow/contrib/autograph/core/errors_test.py | 3 ++- tensorflow/contrib/autograph/pyct/origin_info.py | 19 ++++++++++++++++--- tensorflow/contrib/autograph/pyct/origin_info_test.py | 3 +++ 4 files changed, 25 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py index cd74e5f18f..5d61b220af 100644 --- a/tensorflow/contrib/autograph/converters/error_handlers_test.py +++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py @@ -34,8 +34,10 @@ class ErrorHandlersTest(converter_testing.TestCase): raise ValueError() node, ctx = self.prepare(test_fn, {}) - anno.setanno(node, anno.Basic.ORIGIN, - origin_info.OriginInfo(None, None, None)) + anno.setanno( + node, anno.Basic.ORIGIN, + origin_info.OriginInfo(None, 'test_function_name', 'test_code', + 'test_comment')) node = error_handlers.transform(node, ctx) with self.compiled(node, {}) as result: with self.assertRaises(errors.GraphConstructionError): diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/contrib/autograph/core/errors_test.py index c0e2c74e47..404c1f5456 100644 --- a/tensorflow/contrib/autograph/core/errors_test.py +++ b/tensorflow/contrib/autograph/core/errors_test.py @@ -43,7 +43,8 @@ class RuntimeErrorsTest(test.TestCase): filename = tf_inspect.getsourcefile(function) lineno += line_offset loc = origin_info.LineLocation(filename, lineno) - origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code') + origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code', + 'test_comment') return loc, origin def test_improved_errors_basic(self): diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py index 9f98e48a6a..b60651a30e 100644 --- a/tensorflow/contrib/autograph/pyct/origin_info.py +++ b/tensorflow/contrib/autograph/pyct/origin_info.py @@ -18,8 +18,10 @@ from __future__ import division from __future__ import print_function import collections +import tokenize import gast +import six from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util @@ -56,13 +58,14 @@ class Location( class OriginInfo( collections.namedtuple( 'OriginInfo', - ('loc', 'function_name', 'source_code_line'))): + ('loc', 'function_name', 'source_code_line', 'comment'))): """Container for information about the source code before conversion. Attributes: loc: Location function_name: Optional[Text] source_code_line: Text + comment: Optional[Text] """ def as_frame(self): @@ -152,6 +155,15 @@ def resolve(nodes, source, function=None): function_lineno = None function_filepath = None + # TODO(mdan): Pull this to a separate utility. + code_reader = six.StringIO(source) + comment_map = {} + for token in tokenize.generate_tokens(code_reader.readline): + tok_type, tok_string, loc, _, _ = token + srow, _ = loc + if tok_type == tokenize.COMMENT: + comment_map[srow] = tok_string.strip()[1:].strip() + source_lines = source.split('\n') for node in nodes: for n in gast.walk(node): @@ -162,12 +174,13 @@ def resolve(nodes, source, function=None): source_code_line = source_lines[lineno_in_body - 1] if function: - source_lineno = function_lineno + lineno_in_body - 1 + source_lineno = function_lineno + lineno_in_body function_name = function.__name__ else: source_lineno = lineno_in_body function_name = None location = Location(function_filepath, source_lineno, n.col_offset) - origin = OriginInfo(location, function_name, source_code_line) + origin = OriginInfo(location, function_name, + source_code_line, comment_map.get(source_lineno)) anno.setanno(n, anno.Basic.ORIGIN, origin) diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/contrib/autograph/pyct/origin_info_test.py index 6d7d8b1622..eeaa13007e 100644 --- a/tensorflow/contrib/autograph/pyct/origin_info_test.py +++ b/tensorflow/contrib/autograph/pyct/origin_info_test.py @@ -85,16 +85,19 @@ class OriginInfoTest(test.TestCase): self.assertEqual(origin.loc.lineno, 1) self.assertEqual(origin.loc.col_offset, 0) self.assertEqual(origin.source_code_line, 'def test_fn(x):') + self.assertIsNone(origin.comment) origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 2) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' """Docstring."""') + self.assertIsNone(origin.comment) origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 3) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' return x # comment') + self.assertEqual(origin.comment, 'comment') if __name__ == '__main__': -- cgit v1.2.3