aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-08-03 14:31:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-03 14:42:16 -0700
commit1665eadb04da2446e5a14d4e5f8947f0eeab8215 (patch)
treec2a60d564cae8cc9b1522d1796d114b43501cd2e /tensorflow/contrib
parenteaa3e88ec3322fd0aa4224040215c3c29a752613 (diff)
Include same-line comments in origin_info.
PiperOrigin-RevId: 207325109
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/autograph/converters/error_handlers_test.py6
-rw-r--r--tensorflow/contrib/autograph/core/errors_test.py3
-rw-r--r--tensorflow/contrib/autograph/pyct/origin_info.py19
-rw-r--r--tensorflow/contrib/autograph/pyct/origin_info_test.py3
4 files changed, 25 insertions, 6 deletions
diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py
index cd74e5f18f..5d61b220af 100644
--- a/tensorflow/contrib/autograph/converters/error_handlers_test.py
+++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py
@@ -34,8 +34,10 @@ class ErrorHandlersTest(converter_testing.TestCase):
raise ValueError()
node, ctx = self.prepare(test_fn, {})
- anno.setanno(node, anno.Basic.ORIGIN,
- origin_info.OriginInfo(None, None, None))
+ anno.setanno(
+ node, anno.Basic.ORIGIN,
+ origin_info.OriginInfo(None, 'test_function_name', 'test_code',
+ 'test_comment'))
node = error_handlers.transform(node, ctx)
with self.compiled(node, {}) as result:
with self.assertRaises(errors.GraphConstructionError):
diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/contrib/autograph/core/errors_test.py
index c0e2c74e47..404c1f5456 100644
--- a/tensorflow/contrib/autograph/core/errors_test.py
+++ b/tensorflow/contrib/autograph/core/errors_test.py
@@ -43,7 +43,8 @@ class RuntimeErrorsTest(test.TestCase):
filename = tf_inspect.getsourcefile(function)
lineno += line_offset
loc = origin_info.LineLocation(filename, lineno)
- origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code')
+ origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code',
+ 'test_comment')
return loc, origin
def test_improved_errors_basic(self):
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py
index 9f98e48a6a..b60651a30e 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info.py
+++ b/tensorflow/contrib/autograph/pyct/origin_info.py
@@ -18,8 +18,10 @@ from __future__ import division
from __future__ import print_function
import collections
+import tokenize
import gast
+import six
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
@@ -56,13 +58,14 @@ class Location(
class OriginInfo(
collections.namedtuple(
'OriginInfo',
- ('loc', 'function_name', 'source_code_line'))):
+ ('loc', 'function_name', 'source_code_line', 'comment'))):
"""Container for information about the source code before conversion.
Attributes:
loc: Location
function_name: Optional[Text]
source_code_line: Text
+ comment: Optional[Text]
"""
def as_frame(self):
@@ -152,6 +155,15 @@ def resolve(nodes, source, function=None):
function_lineno = None
function_filepath = None
+ # TODO(mdan): Pull this to a separate utility.
+ code_reader = six.StringIO(source)
+ comment_map = {}
+ for token in tokenize.generate_tokens(code_reader.readline):
+ tok_type, tok_string, loc, _, _ = token
+ srow, _ = loc
+ if tok_type == tokenize.COMMENT:
+ comment_map[srow] = tok_string.strip()[1:].strip()
+
source_lines = source.split('\n')
for node in nodes:
for n in gast.walk(node):
@@ -162,12 +174,13 @@ def resolve(nodes, source, function=None):
source_code_line = source_lines[lineno_in_body - 1]
if function:
- source_lineno = function_lineno + lineno_in_body - 1
+ source_lineno = function_lineno + lineno_in_body
function_name = function.__name__
else:
source_lineno = lineno_in_body
function_name = None
location = Location(function_filepath, source_lineno, n.col_offset)
- origin = OriginInfo(location, function_name, source_code_line)
+ origin = OriginInfo(location, function_name,
+ source_code_line, comment_map.get(source_lineno))
anno.setanno(n, anno.Basic.ORIGIN, origin)
diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/contrib/autograph/pyct/origin_info_test.py
index 6d7d8b1622..eeaa13007e 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info_test.py
+++ b/tensorflow/contrib/autograph/pyct/origin_info_test.py
@@ -85,16 +85,19 @@ class OriginInfoTest(test.TestCase):
self.assertEqual(origin.loc.lineno, 1)
self.assertEqual(origin.loc.col_offset, 0)
self.assertEqual(origin.source_code_line, 'def test_fn(x):')
+ self.assertIsNone(origin.comment)
origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 2)
self.assertEqual(origin.loc.col_offset, 2)
self.assertEqual(origin.source_code_line, ' """Docstring."""')
+ self.assertIsNone(origin.comment)
origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 3)
self.assertEqual(origin.loc.col_offset, 2)
self.assertEqual(origin.source_code_line, ' return x # comment')
+ self.assertEqual(origin.comment, 'comment')
if __name__ == '__main__':