aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-07-25 18:01:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 18:08:37 -0700
commitb24037513f12a5812a21b7ea92ff904ee9ea6cd8 (patch)
treeb747b66cbe2ef8397b885eb073f77bed483a53b6 /tensorflow/contrib/autograph
parent59305b118a9ed56d733b414e8ee1a272dc66466a (diff)
Fix bug in parallel_walk, along with a few structural bugs that this fix revealed:
1. The conversion process was inconsistently packaging the final output into modules or lists. This CL uniformly uses a list of nodes as output from all *_to_graph functions. As a side effect, converter_testing.py asserts that the output is always a single node and extracts it, so there is no need for tests to unpack it any more. Modify the compiler to skip generating a source map by default.
2. The class converter was incorrectly saving the superclass value to the string 'object' instead of the symbol `object`.
Additional refactoring done along the way: Simplify the source mapping code, move it to origin_info.py, add tests and additional checks. Slightly simplify the error rewriting mechanism.
PiperOrigin-RevId: 206087110
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/converters/asserts_test.py2
-rw-r--r--tensorflow/contrib/autograph/converters/directives_test.py6
-rw-r--r--tensorflow/contrib/autograph/converters/error_handlers_test.py6
-rw-r--r--tensorflow/contrib/autograph/converters/lists_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/side_effect_guards_test.py12
-rw-r--r--tensorflow/contrib/autograph/converters/slices_test.py6
-rw-r--r--tensorflow/contrib/autograph/core/converter_testing.py4
-rw-r--r--tensorflow/contrib/autograph/core/errors.py116
-rw-r--r--tensorflow/contrib/autograph/core/errors_test.py108
-rw-r--r--tensorflow/contrib/autograph/impl/api.py21
-rw-r--r--tensorflow/contrib/autograph/impl/conversion.py34
-rw-r--r--tensorflow/contrib/autograph/impl/conversion_test.py29
-rw-r--r--tensorflow/contrib/autograph/pyct/BUILD10
-rw-r--r--tensorflow/contrib/autograph/pyct/ast_util.py87
-rw-r--r--tensorflow/contrib/autograph/pyct/ast_util_test.py62
-rw-r--r--tensorflow/contrib/autograph/pyct/cfg.py6
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py2
-rw-r--r--tensorflow/contrib/autograph/pyct/compiler.py127
-rw-r--r--tensorflow/contrib/autograph/pyct/compiler_test.py2
-rw-r--r--tensorflow/contrib/autograph/pyct/origin_info.py137
-rw-r--r--tensorflow/contrib/autograph/pyct/origin_info_test.py101
-rw-r--r--tensorflow/contrib/autograph/pyct/parser.py1
22 files changed, 539 insertions, 344 deletions
diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py
index 9c58ae3acc..38faba45df 100644
--- a/tensorflow/contrib/autograph/converters/asserts_test.py
+++ b/tensorflow/contrib/autograph/converters/asserts_test.py
@@ -35,7 +35,7 @@ class AssertsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = asserts.transform(node, ctx)
- self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
+ self.assertTrue(isinstance(node.body[0].value, gast.Call))
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/contrib/autograph/converters/directives_test.py
index 5f798a5b76..a573ba5850 100644
--- a/tensorflow/contrib/autograph/converters/directives_test.py
+++ b/tensorflow/contrib/autograph/converters/directives_test.py
@@ -38,7 +38,7 @@ class DirectivesTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {'directives': directives})
node = directives_converter.transform(node, ctx)
- def_, = anno.getanno(node.body[0].body[0].targets[0],
+ def_, = anno.getanno(node.body[0].targets[0],
anno.Static.DEFINITIONS)
d = def_.directives[directives.set_element_type]
self.assertEqual(d['dtype'].s, 'a')
@@ -52,7 +52,7 @@ class DirectivesTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {'directives': directives})
node = directives_converter.transform(node, ctx)
- def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
+ def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
d = def_.directives[directives.set_element_type]
self.assertEqual(d['dtype'].n, 1)
self.assertEqual(d['shape'].n, 2)
@@ -67,7 +67,7 @@ class DirectivesTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {'directives': directives})
node = directives_converter.transform(node, ctx)
- d = anno.getanno(node.body[0].body[1], AgAnno.DIRECTIVES)
+ d = anno.getanno(node.body[1], AgAnno.DIRECTIVES)
d = d[directives.set_loop_options]
self.assertEqual(d['parallel_iterations'].n, 10)
self.assertEqual(d['back_prop'].id, 'a')
diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py
index 878526c8b4..cd74e5f18f 100644
--- a/tensorflow/contrib/autograph/converters/error_handlers_test.py
+++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py
@@ -34,11 +34,13 @@ class ErrorHandlersTest(converter_testing.TestCase):
raise ValueError()
node, ctx = self.prepare(test_fn, {})
- anno.setanno(node.body[0], anno.Basic.ORIGIN,
- origin_info.OriginInfo('test_path', None, None, None, None))
+ anno.setanno(node, anno.Basic.ORIGIN,
+ origin_info.OriginInfo(None, None, None))
node = error_handlers.transform(node, ctx)
with self.compiled(node, {}) as result:
with self.assertRaises(errors.GraphConstructionError):
+ # Here we just assert that the handler works. Its correctness is
+ # verified by errors_test.py.
result.test_fn()
def test_no_origin_annotation(self):
diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py
index f906918ac0..996e99ee61 100644
--- a/tensorflow/contrib/autograph/converters/lists_test.py
+++ b/tensorflow/contrib/autograph/converters/lists_test.py
@@ -79,7 +79,7 @@ class ListTest(converter_testing.TestCase):
ns = {'special_functions': special_functions}
node, ctx = self.prepare(test_fn, ns)
- def_, = anno.getanno(node.body[0].body[0].targets[0],
+ def_, = anno.getanno(node.body[0].targets[0],
anno.Static.ORIG_DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32'),
@@ -114,7 +114,7 @@ class ListTest(converter_testing.TestCase):
return tf.stack(l)
node, ctx = self.prepare(test_fn, {})
- def_, = anno.getanno(node.body[0].body[0].targets[0],
+ def_, = anno.getanno(node.body[0].targets[0],
anno.Static.ORIG_DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
index de1874321e..bee512abbc 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
+++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
@@ -43,7 +43,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
- self.assertEqual(len(node.body[0].body), 1)
+ self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign) as result:
with self.test_session() as sess:
@@ -64,7 +64,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
- self.assertEqual(len(node.body[0].body), 1)
+ self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign) as result:
with self.test_session() as sess:
@@ -84,7 +84,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
- self.assertEqual(len(node.body[0].body), 1)
+ self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, control_flow_ops.Assert) as result:
with self.test_session() as sess:
@@ -104,7 +104,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
- self.assertEqual(len(node.body[0].body), 1)
+ self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign_add) as result:
with self.test_session() as sess:
@@ -125,7 +125,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
- self.assertEqual(len(node.body[0].body[0].body), 1)
+ self.assertEqual(len(node.body[0].body), 1)
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
with self.test_session() as sess:
@@ -147,7 +147,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
- self.assertEqual(len(node.body[0].body), 1)
+ self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign,
state_ops.assign_add) as result:
diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py
index 3c0f81e8bc..c822d53a4a 100644
--- a/tensorflow/contrib/autograph/converters/slices_test.py
+++ b/tensorflow/contrib/autograph/converters/slices_test.py
@@ -38,7 +38,7 @@ class SliceTest(converter_testing.TestCase):
return l[1]
node, ctx = self.prepare(test_fn, {})
- def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
+ def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
}
@@ -59,11 +59,11 @@ class SliceTest(converter_testing.TestCase):
return l[1]
node, ctx = self.prepare(test_fn, {})
- def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
+ def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
}
- def_, = anno.getanno(node.body[0].body[0].body[0].targets[0],
+ def_, = anno.getanno(node.body[0].body[0].targets[0],
anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.float32')
diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/contrib/autograph/core/converter_testing.py
index 2025e32817..5ee2c3fffd 100644
--- a/tensorflow/contrib/autograph/core/converter_testing.py
+++ b/tensorflow/contrib/autograph/core/converter_testing.py
@@ -94,7 +94,8 @@ class TestCase(test.TestCase):
return 7
try:
- result, source = compiler.ast_to_object(node)
+ result, source = compiler.ast_to_object(node, include_source_map=True)
+
result.tf = self.make_fake_mod('fake_tf', *symbols)
fake_ag = self.make_fake_mod('fake_ag', converted_call)
fake_ag.__dict__.update(operators.__dict__)
@@ -144,6 +145,7 @@ class TestCase(test.TestCase):
recursive=True,
autograph_decorators=()):
node, source = parser.parse_entity(test_fn)
+ node = node.body[0]
if namer is None:
namer = FakeNamer()
program_ctx = converter.ProgramContext(
diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/contrib/autograph/core/errors.py
index e58745337a..c219b372c1 100644
--- a/tensorflow/contrib/autograph/core/errors.py
+++ b/tensorflow/contrib/autograph/core/errors.py
@@ -31,11 +31,14 @@ import logging
import sys
import traceback
-from tensorflow.contrib.autograph.pyct.origin_info import CodeLocation
+from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import tf_inspect
+# TODO(mdan): Add a superclass common to all errors.
+
+
class GraphConstructionError(Exception):
"""Error for graph construction errors from AutoGraph generated code."""
@@ -65,27 +68,35 @@ class TfRuntimeError(Exception):
return message + ''.join(traceback.format_list(self.custom_traceback))
-def _rewrite_frame(source_map, cleaned_traceback, stack_frame_indices):
- """Rewrites the stack frames at the given indices using the given source map.
+def _rewrite_tb(source_map, tb, filter_function_name=None):
+ """Rewrites code references in a traceback.
Args:
- source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
- AG generated code.
- cleaned_traceback: List[Tuple[text, text, text, text]], the current
- traceback.
- stack_frame_indices: Iterable[Int], frame indices to possibly rewrite if
- there are matching source mapping keys.
-
+ source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping
+ locations to their origin
+ tb: List[Tuple[Text, Text, Text, Text]], consistent with
+ traceback.extract_tb
+ filter_function_name: Optional[Text], allows restricting the
+ frames to rewrite to a particular function name
Returns:
- None
+ List[Tuple[Text, Text, Text, Text]], the rewritten traceback
"""
- for frame_index in stack_frame_indices:
- # (file_path, line number, function name, code)
- file_path, line_number, _, _ = cleaned_traceback[frame_index]
- source_map_key = CodeLocation(file_path=file_path, line_number=line_number)
- found_mapping = source_map_key in source_map
- if found_mapping:
- cleaned_traceback[frame_index] = source_map[source_map_key].as_frame()
+ new_tb = []
+ for frame in tb:
+ filename, lineno, function_name, _ = frame
+ loc = origin_info.LineLocation(filename, lineno)
+ origin = source_map.get(loc)
+ # TODO(mdan): We shouldn't need the function name at all.
+ # filename + lineno should be sufficient, even if there are multiple source
+ # maps.
+ if origin is not None:
+ if filter_function_name == function_name or filter_function_name is None:
+ new_tb.append(origin.as_frame())
+ else:
+ new_tb.append(frame)
+ else:
+ new_tb.append(frame)
+ return new_tb
# TODO(znado): Make more robust to name changes in the rewriting logic.
@@ -98,18 +109,20 @@ def _remove_rewrite_frames(tb):
return cleaned_tb
+# TODO(mdan): rename to raise_*
def rewrite_graph_construction_error(source_map):
"""Rewrites errors raised by non-AG APIs inside AG generated code.
- Meant to be called from the try/except block inside each AutoGraph generated
- function. Only rewrites the traceback frames corresponding to the function
- that this is called from. When we raise a GraphConstructionError at the end
- it is then caught by calling functions, where they can be responsible for
- rewriting their own frames.
+ This is called from the except handler inside an AutoGraph generated function
+ (that is, during exception handling). Only rewrites the frames corresponding
+ to the function that this is called from, so each function is responsible
+ for calling this to have its own frames rewritten.
+
+ This function always raises an error.
Args:
- source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
- AG generated code.
+ source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], the source
+ map belonging to the calling function
Raises:
GraphConstructionError: The rewritten underlying error.
@@ -120,31 +133,19 @@ def rewrite_graph_construction_error(source_map):
assert original_error is not None
try:
_, _, _, func_name, _, _ = tf_inspect.stack()[1]
- # The latest function call is added to the beginning of a traceback, but
- # when rewriting the traceback of multiple function calls the previous
- # functions' except blocks may have already rewritten their own frames so
- # we want to copy over all of the previous frames. We may have rewritten
- # previous frames only if the error is a GraphConstructionError.
if isinstance(original_error, GraphConstructionError):
+ # TODO(mdan): This is incomplete.
+ # The error might have bubbled through a non-converted function.
cleaned_traceback = traceback.extract_tb(e_traceback)
previous_traceback = original_error.custom_traceback
cleaned_traceback = [cleaned_traceback[0]] + previous_traceback
else:
cleaned_traceback = traceback.extract_tb(e_traceback)
- cleaned_traceback = _remove_rewrite_frames(cleaned_traceback)
-
- current_frame_indices = []
- # This code is meant to be called from the try/except block that wraps a
- # function body. Here we look for all frames that came from the function
- # that this wraps, look for any matching line numbers in the source
- # mapping, and then rewrite them if matches are found.
- for fi, frame in enumerate(cleaned_traceback):
- _, _, frame_func_name, _ = frame
- if frame_func_name == func_name:
- current_frame_indices.append(fi)
- break
- if current_frame_indices:
- _rewrite_frame(source_map, cleaned_traceback, current_frame_indices)
+
+ # Remove the frame corresponding to this function call.
+ cleaned_traceback = cleaned_traceback[1:]
+
+ cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback, func_name)
if isinstance(original_error, GraphConstructionError):
original_error.custom_traceback = cleaned_traceback
@@ -153,6 +154,7 @@ def rewrite_graph_construction_error(source_map):
new_error = GraphConstructionError(original_error, cleaned_traceback)
except Exception:
logging.exception('Error while rewriting AutoGraph error:')
+ # TODO(mdan): Should reraise here, removing the top frame as well.
raise original_error
else:
raise new_error
@@ -161,18 +163,17 @@ def rewrite_graph_construction_error(source_map):
del e_traceback
+# TODO(mdan): This should be consistent with rewrite_graph_construction_error
+# Both should either raise or return.
def rewrite_tf_runtime_error(error, source_map):
"""Rewrites TensorFlow runtime errors raised by ops created in AG code.
Args:
- error: error_impl.OpError, an TensorFlow error that will have its traceback
- rewritten.
- source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
- AG generated code.
+ error: tf.OpError
+ source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo]
Returns:
- A TfRuntimeError with a traceback rewritten according to the given
- source mapping.
+ TfRuntimeError, the rewritten underlying error.
"""
# Check for cases where we leave a user method and re-enter it in the
# traceback. This is done by looking at the function names when the
@@ -198,15 +199,16 @@ def rewrite_tf_runtime_error(error, source_map):
# The source map keys are (file_path, line_number) so get the set of all user
# file_paths.
try:
- all_user_files = set(k.file_path for k in source_map)
+ all_user_files = set(loc.filename for loc in source_map)
cleaned_traceback = []
last_user_frame_index = None
last_user_user_file_path = None
last_user_user_fn_name = None
+ # TODO(mdan): Simplify this logic.
for fi, frame in enumerate(error.op.traceback):
- frame_file_path, frame_line_number, _, _ = frame
- src_map_key = CodeLocation(
- file_path=frame_file_path, line_number=frame_line_number)
+ frame_file_path, lineno, _, _ = frame
+ lineno -= 1 # Frame line numbers are 1-based.
+ src_map_key = origin_info.LineLocation(frame_file_path, lineno)
if frame_file_path in all_user_files:
if src_map_key in source_map:
original_fn_name = source_map[src_map_key].function_name
@@ -223,8 +225,8 @@ def rewrite_tf_runtime_error(error, source_map):
last_user_user_file_path = frame_file_path
cleaned_traceback.append(frame)
- for fi in range(len(cleaned_traceback)):
- _rewrite_frame(source_map, cleaned_traceback, [fi])
+ cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)
+
op_name = error.op.name
op_message = error.message
rewritten_error = TfRuntimeError(op_name, op_message, cleaned_traceback)
@@ -263,7 +265,7 @@ def improved_errors(converted_function):
ValueError: If converted_function is not generated by AutoGraph
"""
if (getattr(converted_function, 'ag_source_map', None) is None or
- not converted_function.ag_source_map):
+ not isinstance(converted_function.ag_source_map, dict)):
raise ValueError(
'converted_function must be the result of an autograph.to_graph call')
try:
diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/contrib/autograph/core/errors_test.py
index 7be54563a1..c0e2c74e47 100644
--- a/tensorflow/contrib/autograph/core/errors_test.py
+++ b/tensorflow/contrib/autograph/core/errors_test.py
@@ -28,88 +28,76 @@ from tensorflow.python.util import tf_inspect
def zero_div():
- return array_ops.constant(10, dtype=dtypes.int32) // 0
+ x = array_ops.constant(10, dtype=dtypes.int32)
+ return x // 0
def zero_div_caller():
- a = zero_div() + 2
- return a
+ return zero_div()
class RuntimeErrorsTest(test.TestCase):
- def setUp(self):
- self._fake_origin = origin_info.OriginInfo('new file', 'new func', 96, 0,
- 'print("hello world!")')
-
- def test_error_replacement(self):
- _, zero_div_lineno = tf_inspect.getsourcelines(zero_div)
- src_map = {
- errors.CodeLocation(
- file_path=__file__, line_number=zero_div_lineno + 1):
- self._fake_origin
- }
+ def fake_origin(self, function, line_offset):
+ _, lineno = tf_inspect.getsourcelines(function)
+ filename = tf_inspect.getsourcefile(function)
+ lineno += line_offset
+ loc = origin_info.LineLocation(filename, lineno)
+ origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code')
+ return loc, origin
+
+ def test_improved_errors_basic(self):
+ loc, origin = self.fake_origin(zero_div, 2)
+ zero_div_caller.ag_source_map = {loc: origin}
+
+ ops = zero_div_caller()
with self.assertRaises(errors.TfRuntimeError) as cm:
- z = zero_div_caller()
- zero_div_caller.ag_source_map = src_map
with errors.improved_errors(zero_div_caller):
with self.test_session() as sess:
- sess.run(z)
- expected = cm.exception
- current_traceback = expected.custom_traceback
- for frame in current_traceback:
- self.assertNotEqual('zero_div', frame[2])
- self.assertTrue(
- any(self._fake_origin.as_frame() == frame
- for frame in current_traceback))
-
- def test_error_not_found(self):
- src_map = {
- errors.CodeLocation(file_path=__file__, line_number=-1):
- self._fake_origin
- }
+ sess.run(ops)
+
+ for frame in cm.exception.custom_traceback:
+ _, _, function_name, _ = frame
+ self.assertNotEqual('zero_div', function_name)
+ self.assertIn(origin.as_frame(), set(cm.exception.custom_traceback))
+
+ def test_improved_errors_no_matching_lineno(self):
+ loc, origin = self.fake_origin(zero_div, -1)
+ zero_div_caller.ag_source_map = {loc: origin}
+
+ ops = zero_div_caller()
with self.assertRaises(errors.TfRuntimeError) as cm:
- z = zero_div_caller()
- zero_div_caller.ag_source_map = src_map
with errors.improved_errors(zero_div_caller):
with self.test_session() as sess:
- sess.run(z)
- expected = cm.exception
- current_traceback = expected.custom_traceback
- self.assertTrue(any('zero_div' in frame[2] for frame in current_traceback))
- for frame in current_traceback:
- self.assertNotEqual(frame, self._fake_origin.as_frame())
-
- def test_rewriting_error(self):
- _, zero_div_lineno = tf_inspect.getsourcelines(zero_div)
- src_map = {
- errors.CodeLocation(
- file_path=__file__, line_number=zero_div_lineno + 1):
- None
- }
- with self.assertRaisesRegexp(tf_errors.InvalidArgumentError,
- 'Integer division by zero'):
- z = zero_div_caller()
- zero_div_caller.ag_source_map = src_map
+ sess.run(ops)
+
+ all_function_names = set()
+ for frame in cm.exception.custom_traceback:
+ _, _, function_name, _ = frame
+ all_function_names.add(function_name)
+ self.assertNotEqual('test_function_name', function_name)
+ self.assertIn('zero_div', all_function_names)
+
+ def test_improved_errors_failures(self):
+ loc, _ = self.fake_origin(zero_div, 2)
+ zero_div_caller.ag_source_map = {loc: 'bogus object'}
+
+ ops = zero_div_caller()
+ with self.assertRaises(tf_errors.InvalidArgumentError):
with errors.improved_errors(zero_div_caller):
with self.test_session() as sess:
- sess.run(z)
+ sess.run(ops)
- def test_no_ag_source_map(self):
+ def test_improved_errors_validation(self):
with self.assertRaisesRegexp(
ValueError,
'converted_function must be the result of an autograph.to_graph call'):
- with errors.improved_errors(None):
- pass
-
- def test_bad_ag_source_map(self):
+ errors.improved_errors(zero_div).__enter__()
with self.assertRaisesRegexp(
ValueError,
'converted_function must be the result of an autograph.to_graph call'):
- src_map = None
- zero_div_caller.ag_source_map = src_map
- with errors.improved_errors(None):
- pass
+ zero_div_caller.ag_source_map = 'not a dict'
+ errors.improved_errors(zero_div_caller).__enter__()
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index f7fe3de5da..ee71f4f9ac 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -23,7 +23,6 @@ from functools import wraps
from enum import Enum
# pylint:disable=g-bad-import-order
-import gast
import six
# pylint:enable=g-bad-import-order
@@ -245,19 +244,21 @@ def to_graph(e,
_, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
arg_types)
- module = gast.Module([])
+ nodes = []
for dep in reversed(program_ctx.dependency_cache.values()):
- module.body.append(dep)
- compiled_node, compiled_src = compiler.ast_to_object(
- module, source_prefix=program_ctx.required_imports)
+ nodes.extend(dep)
+ compiled_module, compiled_src = compiler.ast_to_object(
+ nodes,
+ source_prefix=program_ctx.required_imports,
+ include_source_map=True)
# The compiled code should see everything the entry entity saw.
# TODO(mdan): This might not work well if the call tree spans modules?
for key, val in namespace.items():
# Avoid overwriting entities that have been transformed.
- if key not in compiled_node.__dict__:
- compiled_node.__dict__[key] = val
- compiled_fn = getattr(compiled_node, name)
+ if key not in compiled_module.__dict__:
+ compiled_module.__dict__[key] = val
+ compiled_fn = getattr(compiled_module, name)
# Need this so the source_mapping attribute is available for the context
# manager to access for runtime errors.
@@ -270,7 +271,7 @@ def to_graph(e,
'"%s", which is reserved for AutoGraph.' %
(compiled_fn, source_map_attribute_name))
setattr(compiled_fn, source_map_attribute_name,
- compiled_node.__dict__['ag_source_map__'])
+ compiled_module.__dict__['ag_source_map__'])
if verbose:
logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
@@ -308,7 +309,7 @@ def to_code(e,
conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
code = '\n'.join(
- compiler.ast_to_source(dep, indentation)[0]
+ compiler.ast_to_source(dep, indentation)
for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))))
return program_ctx.required_imports + '\n\n' + code
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index 7bd0ba3f2d..57ec739a80 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -164,7 +164,7 @@ def class_to_graph(c, program_ctx):
class_namespace = namespace
else:
class_namespace.update(namespace)
- converted_members[m] = node
+ converted_members[m] = node[0]
namer = program_ctx.new_namer(class_namespace)
class_name = namer.compiled_class_name(c.__name__, c)
@@ -175,10 +175,10 @@ def class_to_graph(c, program_ctx):
# program_ctx.update_name_map(namer)).
output_nodes = []
renames = {}
- bases = []
+ base_names = []
for base in c.__bases__:
if isinstance(object, base):
- bases.append('object')
+ base_names.append('object')
continue
if is_whitelisted_for_graph(base):
alias = namer.new_symbol(base.__name__, ())
@@ -190,28 +190,28 @@ def class_to_graph(c, program_ctx):
else:
# This will trigger a conversion into a class with this name.
alias = namer.compiled_class_name(base.__name__, base)
- bases.append(alias)
+ base_names.append(alias)
renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
program_ctx.update_name_map(namer)
# Generate the definition of the converted class.
- output_nodes.append(
- gast.ClassDef(
- class_name,
- bases=bases,
- keywords=[],
- body=list(converted_members.values()),
- decorator_list=[]))
- node = gast.Module(output_nodes)
-
+ bases = [gast.Name(n, gast.Load(), None) for n in base_names]
+ class_def = gast.ClassDef(
+ class_name,
+ bases=bases,
+ keywords=[],
+ body=list(converted_members.values()),
+ decorator_list=[])
# Make a final pass to replace references to the class or its base classes.
# Most commonly, this occurs when making super().__init__() calls.
# TODO(mdan): Making direct references to superclass' superclass will fail.
- node = qual_names.resolve(node)
+ class_def = qual_names.resolve(class_def)
renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
- node = ast_util.rename_symbols(node, renames)
+ class_def = ast_util.rename_symbols(class_def, renames)
+
+ output_nodes.append(class_def)
- return node, class_name, class_namespace
+ return output_nodes, class_name, class_namespace
def _add_reserved_symbol(namespace, name, entity):
@@ -279,7 +279,7 @@ def function_to_graph(f,
program_ctx.update_name_map(namer)
# TODO(mdan): Use this at compilation.
- return node, new_name, namespace
+ return (node,), new_name, namespace
def node_to_graph(node, context, rewrite_errors=True):
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py
index 207225a1ac..bfc51365a3 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/contrib/autograph/impl/conversion_test.py
@@ -60,10 +60,11 @@ class ConversionTest(test.TestCase):
return a + b
program_ctx = self._simple_program_ctx()
- ast, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
- self.assertTrue(isinstance(ast, gast.FunctionDef), ast)
+ nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
+ fn_node, = nodes
+ self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name)
- self.assertTrue(ns['b'] is b)
+ self.assertIs(ns['b'], b)
def test_entity_to_graph_call_tree(self):
@@ -78,14 +79,11 @@ class ConversionTest(test.TestCase):
self.assertTrue(f in program_ctx.dependency_cache)
self.assertTrue(g in program_ctx.dependency_cache)
- self.assertEqual('tf__f', program_ctx.dependency_cache[f].name)
- # need one extra .body[0] in order to step past the try/except wrapper that
- # is added automatically, the other for the with tf.name_scope('f') that is
- # added automatically
- self.assertEqual(
- 'tf__g',
- program_ctx.dependency_cache[f].body[0].body[0].body[0].value.func.id)
- self.assertEqual('tf__g', program_ctx.dependency_cache[g].name)
+ f_node = program_ctx.dependency_cache[f][0]
+ g_node = program_ctx.dependency_cache[g][0]
+ self.assertEqual('tf__f', f_node.name)
+ self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id)
+ self.assertEqual('tf__g', g_node.name)
def test_entity_to_graph_class_hierarchy(self):
@@ -118,9 +116,9 @@ class ConversionTest(test.TestCase):
self.assertTrue(TestBase in program_ctx.dependency_cache)
self.assertTrue(TestSubclass in program_ctx.dependency_cache)
self.assertEqual('TfTestBase',
- program_ctx.dependency_cache[TestBase].body[-1].name)
+ program_ctx.dependency_cache[TestBase][-1].name)
self.assertEqual('TfTestSubclass',
- program_ctx.dependency_cache[TestSubclass].body[-1].name)
+ program_ctx.dependency_cache[TestSubclass][-1].name)
def test_entity_to_graph_class_hierarchy_whitelisted(self):
@@ -139,10 +137,9 @@ class ConversionTest(test.TestCase):
self.assertTrue(TestSubclass in program_ctx.dependency_cache)
self.assertFalse(training.Model in program_ctx.dependency_cache)
self.assertEqual(
- 'Model',
- program_ctx.dependency_cache[TestSubclass].body[0].names[0].name)
+ 'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name)
self.assertEqual('TfTestSubclass',
- program_ctx.dependency_cache[TestSubclass].body[-1].name)
+ program_ctx.dependency_cache[TestSubclass][-1].name)
def test_entity_to_graph_lambda(self):
f = lambda a: a
diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD
index f77a6ab392..ddadc6b96e 100644
--- a/tensorflow/contrib/autograph/pyct/BUILD
+++ b/tensorflow/contrib/autograph/pyct/BUILD
@@ -100,6 +100,16 @@ py_test(
)
py_test(
+ name = "origin_info_test",
+ srcs = ["origin_info_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "parser_test",
srcs = ["parser_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py
index 86e3f56a64..d7453b0781 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util.py
+++ b/tensorflow/contrib/autograph/pyct/ast_util.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import ast
-import collections
import gast
from tensorflow.contrib.autograph.pyct import anno
@@ -185,6 +184,7 @@ class PatternMatcher(gast.NodeVisitor):
if v != p:
return self.no_match()
+
def matches(node, pattern):
"""Basic pattern matcher for AST.
@@ -253,30 +253,61 @@ def apply_to_single_assignments(targets, values, apply_fn):
apply_fn(target, values)
-def iter_fields(node):
- for field in sorted(node._fields):
- try:
- yield getattr(node, field)
- except AttributeError:
- pass
-
-
-def iter_child_nodes(node):
- for field in iter_fields(node):
- if isinstance(field, gast.AST):
- yield field
- elif isinstance(field, list):
- for item in field:
- if isinstance(item, gast.AST):
- yield item
-
-
-def parallel_walk(node_a, node_b):
- todo_a = collections.deque([node_a])
- todo_b = collections.deque([node_b])
- while todo_a and todo_b:
- node_a = todo_a.popleft()
- node_b = todo_b.popleft()
- todo_a.extend(iter_child_nodes(node_a))
- todo_b.extend(iter_child_nodes(node_b))
- yield node_a, node_b
+def parallel_walk(node, other):
+ """Walks two ASTs in parallel.
+
+ The two trees must have identical structure.
+
+ Args:
+ node: Union[ast.AST, Iterable[ast.AST]]
+ other: Union[ast.AST, Iterable[ast.AST]]
+ Yields:
+ Tuple[ast.AST, ast.AST]
+ Raises:
+ ValueError: if the two trees don't have identical structure.
+ """
+ if isinstance(node, (list, tuple)):
+ node_stack = list(node)
+ else:
+ node_stack = [node]
+
+ if isinstance(other, (list, tuple)):
+ other_stack = list(other)
+ else:
+ other_stack = [other]
+
+ while node_stack and other_stack:
+ assert len(node_stack) == len(other_stack)
+ n = node_stack.pop()
+ o = other_stack.pop()
+
+ if (not isinstance(n, (ast.AST, gast.AST)) or
+ not isinstance(o, (ast.AST, gast.AST)) or
+ n.__class__.__name__ != o.__class__.__name__):
+ raise ValueError('inconsistent nodes: {} and {}'.format(n, o))
+
+ yield n, o
+
+ for f in n._fields:
+ n_child = getattr(n, f, None)
+ o_child = getattr(o, f, None)
+ if f.startswith('__') or n_child is None or o_child is None:
+ continue
+
+ if isinstance(n_child, (list, tuple)):
+ if (not isinstance(o_child, (list, tuple)) or
+ len(n_child) != len(o_child)):
+ raise ValueError(
+ 'inconsistent values for field {}: {} and {}'.format(
+ f, n_child, o_child))
+ node_stack.extend(n_child)
+ other_stack.extend(o_child)
+
+ elif isinstance(n_child, (gast.AST, ast.AST)):
+ node_stack.append(n_child)
+ other_stack.append(o_child)
+
+ elif n_child != o_child:
+ raise ValueError(
+ 'inconsistent values for field {}: {} and {}'.format(
+ f, n_child, o_child))
diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py
index 981e398b93..2293c89720 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util_test.py
+++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py
@@ -44,7 +44,7 @@ class AstUtilTest(test.TestCase):
node, {qual_names.QN('a'): qual_names.QN('renamed_a')})
self.assertIsInstance(node.body[0].value.left.id, str)
- source, _ = compiler.ast_to_source(node)
+ source = compiler.ast_to_source(node)
self.assertEqual(source.strip(), 'renamed_a + b')
def test_rename_symbols_attributes(self):
@@ -54,7 +54,7 @@ class AstUtilTest(test.TestCase):
node = ast_util.rename_symbols(
node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})
- source, _ = compiler.ast_to_source(node)
+ source = compiler.ast_to_source(node)
self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_annotations(self):
@@ -97,10 +97,10 @@ class AstUtilTest(test.TestCase):
d = ast_util.keywords_to_dict(keywords)
# Make sure we generate a usable dict node by attaching it to a variable and
# compiling everything.
- output = parser.parse_str('b = 3')
- output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),)
- result, _ = compiler.ast_to_object(output)
- self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'})
+ node = parser.parse_str('def f(b): pass').body[0]
+ node.body.append(ast.Return(d))
+ result, _ = compiler.ast_to_object(node)
+ self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'})
def assertMatch(self, target_str, pattern_str):
node = parser.parse_expression(target_str)
@@ -130,8 +130,8 @@ class AstUtilTest(test.TestCase):
'super(Bar, _).__init__(_)')
def _mock_apply_fn(self, target, source):
- target, _ = compiler.ast_to_source(target)
- source, _ = compiler.ast_to_source(source)
+ target = compiler.ast_to_source(target)
+ source = compiler.ast_to_source(source)
self._invocation_counts[(target.strip(), source.strip())] += 1
def test_apply_to_single_assignments_dynamic_unpack(self):
@@ -157,24 +157,40 @@ class AstUtilTest(test.TestCase):
})
def test_parallel_walk(self):
- ret = ast.Return(
- ast.BinOp(
- op=ast.Add(),
- left=ast.Name(id='a', ctx=ast.Load()),
- right=ast.Num(1)))
- node = ast.FunctionDef(
- name='f',
- args=ast.arguments(
- args=[ast.Name(id='a', ctx=ast.Param())],
- vararg=None,
- kwarg=None,
- defaults=[]),
- body=[ret],
- decorator_list=[],
- returns=None)
+ node = parser.parse_str(
+ textwrap.dedent("""
+ def f(a):
+ return a + 1
+ """))
for child_a, child_b in ast_util.parallel_walk(node, node):
self.assertEqual(child_a, child_b)
+ def test_parallel_walk_inconsistent_trees(self):
+ node_1 = parser.parse_str(
+ textwrap.dedent("""
+ def f(a):
+ return a + 1
+ """))
+ node_2 = parser.parse_str(
+ textwrap.dedent("""
+ def f(a):
+ return a + (a * 2)
+ """))
+ node_3 = parser.parse_str(
+ textwrap.dedent("""
+ def f(a):
+ return a + 2
+ """))
+ with self.assertRaises(ValueError):
+ for _ in ast_util.parallel_walk(node_1, node_2):
+ pass
+ # There is no particular reason to reject trees that differ only in the
+ # value of a constant.
+ # TODO(mdan): This should probably be allowed.
+ with self.assertRaises(ValueError):
+ for _ in ast_util.parallel_walk(node_1, node_3):
+ pass
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py
index 25fec7fd53..ba51dcf285 100644
--- a/tensorflow/contrib/autograph/pyct/cfg.py
+++ b/tensorflow/contrib/autograph/pyct/cfg.py
@@ -67,10 +67,8 @@ class Node(object):
if isinstance(self.ast_node, gast.FunctionDef):
return 'def %s' % self.ast_node.name
elif isinstance(self.ast_node, gast.withitem):
- source, _ = compiler.ast_to_source(self.ast_node.context_expr)
- return source.strip()
- source, _ = compiler.ast_to_source(self.ast_node)
- return source.strip()
+ return compiler.ast_to_source(self.ast_node.context_expr).strip()
+ return compiler.ast_to_source(self.ast_node).strip()
class Graph(
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
index 81983a5ecb..aefbc69d8c 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
@@ -43,7 +43,7 @@ class AnfTransformerTest(test.TestCase):
return a
node, _ = parser.parse_entity(test_function)
- node = anf.transform(node, self._simple_source_info())
+ node = anf.transform(node.body[0], self._simple_source_info())
result, _ = compiler.ast_to_object(node)
self.assertEqual(test_function(), result.test_function())
diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py
index c90a5e89c2..f9cee10962 100644
--- a/tensorflow/contrib/autograph/pyct/compiler.py
+++ b/tensorflow/contrib/autograph/pyct/compiler.py
@@ -30,44 +30,7 @@ import tempfile
import astor
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import origin_info
-from tensorflow.contrib.autograph.pyct import parser
-
-
-def _build_source_map(node, code):
- """Return the Python objects represented by given AST.
-
- Compiling the AST code this way ensures that the source code is readable by
- e.g. `pdb` or `inspect`.
-
- Args:
- node: An AST node of the original generated code, before the source code is
- generated.
- code: The string representation of the source code for the newly generated
- code.
-
- Returns:
- Dict[CodeLocation, OriginInfo], a mapping between the user and AutoGraph
- generated code.
- """
- # After we have the final generated code we reparse it to get the final line
- # numbers. Then we walk through the generated and original ASTs in parallel
- # to build the mapping between the user and generated code.
- new_node = parser.parse_str(code)
- origin_info.resolve(new_node, code)
- source_mapping = {}
- for before, after in ast_util.parallel_walk(node, new_node):
- # Need both checks because if origin information is ever copied over to new
- # nodes then we need to rely on the fact that only the original user code
- # has the origin annotation.
- if (anno.hasanno(before, anno.Basic.ORIGIN) and
- anno.hasanno(after, anno.Basic.ORIGIN)):
- source_info = anno.getanno(before, anno.Basic.ORIGIN)
- new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
- source_mapping[new_line_number] = source_info
- return source_mapping
def ast_to_source(node, indentation=' '):
@@ -81,24 +44,28 @@ def ast_to_source(node, indentation=' '):
code: The source code generated from the AST object
source_mapping: A mapping between the user and AutoGraph generated code.
"""
- original_node = node
- if isinstance(node, gast.AST):
- node = gast.gast_to_ast(node)
+ if not isinstance(node, (list, tuple)):
+ node = (node,)
generator = astor.codegen.SourceGenerator(indentation, False,
astor.string_repr.pretty_string)
- generator.visit(node)
- generator.result.append('\n')
+
+ for n in node:
+ if isinstance(n, gast.AST):
+ n = gast.gast_to_ast(n)
+ generator.visit(n)
+ generator.result.append('\n')
+
# In some versions of Python, literals may appear as actual values. This
# ensures everything is string.
code = map(str, generator.result)
code = astor.source_repr.pretty_source(code).lstrip()
- source_mapping = _build_source_map(original_node, code)
- return code, source_mapping
+ return code
-def ast_to_object(node,
+def ast_to_object(nodes,
indentation=' ',
+ include_source_map=False,
source_prefix=None,
delete_on_exit=True):
"""Return the Python objects represented by given AST.
@@ -107,42 +74,46 @@ def ast_to_object(node,
e.g. `pdb` or `inspect`.
Args:
- node: The code to compile, as an AST object.
- indentation: The string to use for indentation.
- source_prefix: Optional string to print as-is into the source file.
- delete_on_exit: Whether to delete the temporary file used for compilation on
- exit.
+ nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST
+ object.
+ indentation: Text, the string to use for indentation.
+ include_source_map: bool, whether to attach a source map to the compiled
+ object. Also see origin_info.py.
+ source_prefix: Optional[Text], string to print as-is into the source file.
+ delete_on_exit: bool, whether to delete the temporary file used for
+ compilation on exit.
Returns:
- compiled_node: A module object containing the compiled source code.
+ compiled_nodes: A module object containing the compiled source code.
source: The source code of the compiled object
Raises:
ValueError: If ag_source_map__ is already in the namespace of the compiled
- node.
+ nodes.
"""
- # code_source_mapping does not yet include the offsets from import statements.
- source, code_source_mapping = ast_to_source(node, indentation=indentation)
+ if not isinstance(nodes, (list, tuple)):
+ nodes = (nodes,)
+
+ source = ast_to_source(nodes, indentation=indentation)
+
+ if source_prefix:
+ source = source_prefix + '\n' + source
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
- # TODO(znado): move into an _offset_source_map() helper function.
- # Need to offset the generated line numbers by the number of import lines.
- if source_prefix:
- num_import_lines = source_prefix.count('\n') + 1
- else:
- num_import_lines = 0
- source_mapping = {}
- for line_number, original_position in code_source_mapping.items():
- source_map_key = origin_info.CodeLocation(
- file_path=f.name, line_number=line_number + num_import_lines)
- source_mapping[source_map_key] = original_position
module_name = os.path.basename(f.name[:-3])
- if source_prefix:
- f.write(source_prefix)
- f.write('\n')
f.write(source)
+
+ if isinstance(nodes, (list, tuple)):
+ indices = range(-len(nodes), 0)
+ else:
+ indices = (-1,)
+
+ if include_source_map:
+ source_map = origin_info.source_map(nodes, source, f.name, indices)
+
+ # TODO(mdan): Try flush() and delete=False instead.
if delete_on_exit:
atexit.register(lambda: os.remove(f.name))
- compiled_node = imp.load_source(module_name, f.name)
+ compiled_nodes = imp.load_source(module_name, f.name)
# TODO(znado): Clean this up so we don't need to attach it to the namespace.
# TODO(znado): This does not work for classes because their methods share a
@@ -158,11 +129,13 @@ def ast_to_object(node,
# is hard, and this cleanly fixes the
# issues encountered with nested functions because this is attached to the
# outermost one.
- source_map_name = 'ag_source_map__'
- if source_map_name in compiled_node.__dict__:
- raise ValueError('cannot convert %s because is has namespace attribute '
- '"%s", which is reserved for AutoGraph.' %
- (compiled_node, source_map_name))
- compiled_node.__dict__[source_map_name] = source_mapping
-
- return compiled_node, source
+ if include_source_map:
+ # TODO(mdan): This name should be decided by the caller.
+ source_map_name = 'ag_source_map__'
+ if source_map_name in compiled_nodes.__dict__:
+ raise ValueError('cannot convert %s because is has namespace attribute '
+ '"%s", which is reserved for AutoGraph.' %
+ (compiled_nodes, source_map_name))
+ compiled_nodes.__dict__[source_map_name] = source_map
+
+ return compiled_nodes, source
diff --git a/tensorflow/contrib/autograph/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py
index e29fa9324c..cf783da6a3 100644
--- a/tensorflow/contrib/autograph/pyct/compiler_test.py
+++ b/tensorflow/contrib/autograph/pyct/compiler_test.py
@@ -59,7 +59,7 @@ class CompilerTest(test.TestCase):
value=gast.Str('c'))
])
- source, _ = compiler.ast_to_source(node, indentation=' ')
+ source = compiler.ast_to_source(node, indentation=' ')
self.assertEqual(
textwrap.dedent("""
if 1:
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py
index 614e346634..1aad2f47df 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info.py
+++ b/tensorflow/contrib/autograph/pyct/origin_info.py
@@ -22,49 +22,115 @@ import collections
import gast
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import ast_util
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.python.util import tf_inspect
-class CodeLocation(
- collections.namedtuple('CodeLocation', ('file_path', 'line_number'))):
- """Location of a line of code.
+class LineLocation(
+ collections.namedtuple('LineLocation', ('filename', 'lineno'))):
+ """Similar to Location, but without column information.
Attributes:
- file_path: text, the full path to the file containing the code.
- line_number: Int, the 1-based line number of the code in its file.
+ filename: Text
+ lineno: int, 1-based
"""
pass
+class Location(
+ collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))):
+ """Encodes code location information.
+
+ Attributes:
+ filename: Text
+ lineno: int, 1-based
+ col_offset: int
+ """
+
+ @property
+ def line_loc(self):
+ return LineLocation(self.filename, self.lineno)
+
+
class OriginInfo(
- collections.namedtuple('OriginInfo',
- ('file_path', 'function_name', 'line_number',
- 'column_offset', 'source_code_line'))):
+ collections.namedtuple(
+ 'OriginInfo',
+ ('loc', 'function_name', 'source_code_line'))):
"""Container for information about the source code before conversion.
- Instances of this class contain information about the source code that
- transformed code originated from. Examples include:
- * line number
- * file name
- * original user code
+ Attributes:
+ loc: Location
+ function_name: Optional[Text]
+ source_code_line: Text
"""
def as_frame(self):
- """Makes a traceback frame tuple.
-
- Returns:
- A tuple of (file_path, line_number, function_name, source_code_line).
- """
- return (self.file_path, self.line_number, self.function_name,
+ """Returns a 4-tuple consistent with the return of traceback.extract_tb."""
+ return (self.loc.filename, self.loc.lineno, self.function_name,
self.source_code_line)
+# TODO(mdan): This source map should be a class - easier to refer to.
+def source_map(nodes, code, filename, indices_in_code):
+ """Creates a source map between an annotated AST and the code it compiles to.
+
+ Args:
+ nodes: Iterable[ast.AST, ...]
+ code: Text
+ filename: Optional[Text]
+ indices_in_code: Union[int, Iterable[int, ...]], the positions at which
+ nodes appear in code. The parser always returns a module when parsing
+ code. This argument indicates the position in that module's body at
+ which the corresponding node should appear.
+
+ Returns:
+ Dict[LineLocation, OriginInfo], mapping locations in code to locations
+ indicated by origin annotations in node.
+ """
+ reparsed_nodes = parser.parse_str(code)
+ reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]
+
+ resolve(reparsed_nodes, code)
+ result = {}
+
+ for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
+ # Note: generated code might not be mapped back to its origin.
+ # TODO(mdan): Generated code should always be mapped to something.
+ origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
+ final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
+ if origin_info is None or final_info is None:
+ continue
+
+ line_loc = LineLocation(filename, final_info.loc.lineno)
+
+ existing_origin = result.get(line_loc)
+ if existing_origin is not None:
+ # Overlaps may exist because of child nodes, but almost never to
+ # different line locations. The exception is decorated functions, where
+ # both lines are mapped to the same line in the AST.
+
+ # Line overlaps: keep bottom node.
+ if existing_origin.loc.line_loc == origin_info.loc.line_loc:
+ if existing_origin.loc.lineno >= origin_info.loc.lineno:
+ continue
+
+ # In case of overlaps, keep the leftmost node.
+ if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
+ continue
+
+ result[line_loc] = origin_info
+
+ return result
+
+
# TODO(znado): Consider refactoring this into a Visitor.
-def resolve(node, source, function=None):
+# TODO(mdan): Does this work correctly with inner functions?
+def resolve(nodes, source, function=None):
"""Adds an origin information to all nodes inside the body of function.
Args:
- node: The AST node for the function whose body nodes will be annotated.
+ nodes: Union[ast.AST, Iterable[ast.AST, ...]]
source: Text, the source code string for the function whose body nodes will
be annotated.
function: Callable, the function that will have all nodes inside of it
@@ -76,25 +142,32 @@ def resolve(node, source, function=None):
A tuple of the AST node for function and a String containing its source
code.
"""
+ if not isinstance(nodes, (list, tuple)):
+ nodes = (nodes,)
+
if function:
_, function_lineno = tf_inspect.getsourcelines(function)
function_filepath = tf_inspect.getsourcefile(function)
else:
function_lineno = None
function_filepath = None
+
source_lines = source.split('\n')
- for n in gast.walk(node):
- if hasattr(n, 'lineno'):
- # n.lineno is relative to the start of the enclosing function, so need to
- # offset it by the line of the function.
- source_code_line = source_lines[n.lineno - 1]
+ for node in nodes:
+ for n in gast.walk(node):
+ if not hasattr(n, 'lineno'):
+ continue
+
+ lineno_in_body = n.lineno
+
+ source_code_line = source_lines[lineno_in_body - 1]
if function:
- source_lineno = n.lineno + function_lineno - 1
+ source_lineno = function_lineno + lineno_in_body
function_name = function.__name__
else:
- source_lineno = n.lineno
+ source_lineno = lineno_in_body
function_name = None
- anno.setanno(
- n, anno.Basic.ORIGIN,
- OriginInfo(function_filepath, function_name, source_lineno,
- n.col_offset, source_code_line))
+
+ location = Location(function_filepath, source_lineno, n.col_offset)
+ origin = OriginInfo(location, function_name, source_code_line)
+ anno.setanno(n, anno.Basic.ORIGIN, origin)
diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/contrib/autograph/pyct/origin_info_test.py
new file mode 100644
index 0000000000..6d7d8b1622
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/origin_info_test.py
@@ -0,0 +1,101 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for origin_info module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import compiler
+from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
+class OriginInfoTest(test.TestCase):
+
+ def test_source_map(self):
+
+ def test_fn(x):
+ if x > 0:
+ x += 1
+ return x
+
+ node, source = parser.parse_entity(test_fn)
+ fn_node = node.body[0]
+ origin_info.resolve(fn_node, source)
+
+ # Insert a traced line.
+ new_node = parser.parse_str('x = abs(x)').body[0]
+ anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
+ fn_node.body.insert(0, new_node)
+
+ # Insert an untraced line.
+ fn_node.body.insert(0, parser.parse_str('x = 0').body[0])
+
+ modified_source = compiler.ast_to_source(fn_node)
+
+ source_map = origin_info.source_map(fn_node, modified_source,
+ 'test_filename', [0])
+
+ loc = origin_info.LineLocation('test_filename', 1)
+ origin = source_map[loc]
+ self.assertEqual(origin.source_code_line, 'def test_fn(x):')
+ self.assertEqual(origin.loc.lineno, 1)
+
+ # The untraced line, inserted second.
+ loc = origin_info.LineLocation('test_filename', 2)
+ self.assertFalse(loc in source_map)
+
+ # The traced line, inserted first.
+ loc = origin_info.LineLocation('test_filename', 3)
+ origin = source_map[loc]
+ self.assertEqual(origin.source_code_line, ' if x > 0:')
+ self.assertEqual(origin.loc.lineno, 2)
+
+ loc = origin_info.LineLocation('test_filename', 4)
+ origin = source_map[loc]
+ self.assertEqual(origin.source_code_line, ' if x > 0:')
+ self.assertEqual(origin.loc.lineno, 2)
+
+ def test_resolve(self):
+
+ def test_fn(x):
+ """Docstring."""
+ return x # comment
+
+ node, source = parser.parse_entity(test_fn)
+ fn_node = node.body[0]
+ origin_info.resolve(fn_node, source)
+
+ origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
+ self.assertEqual(origin.loc.lineno, 1)
+ self.assertEqual(origin.loc.col_offset, 0)
+ self.assertEqual(origin.source_code_line, 'def test_fn(x):')
+
+ origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
+ self.assertEqual(origin.loc.lineno, 2)
+ self.assertEqual(origin.loc.col_offset, 2)
+ self.assertEqual(origin.source_code_line, ' """Docstring."""')
+
+ origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
+ self.assertEqual(origin.loc.lineno, 3)
+ self.assertEqual(origin.loc.col_offset, 2)
+ self.assertEqual(origin.source_code_line, ' return x # comment')
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/pyct/parser.py b/tensorflow/contrib/autograph/pyct/parser.py
index c961efa892..112ed46a1e 100644
--- a/tensorflow/contrib/autograph/pyct/parser.py
+++ b/tensorflow/contrib/autograph/pyct/parser.py
@@ -37,6 +37,7 @@ def parse_entity(entity):
def parse_str(src):
"""Returns the AST of given piece of code."""
+ # TODO(mdan): This should exclude the module things are autowrapped in.
return gast.parse(src)