aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-31 10:36:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-31 10:39:25 -0700
commit1c498283e01137923b526d12ac77d3788f4fa912 (patch)
tree1e1138e306e7ce029f0889aaafea9a0a316601d9 /tensorflow/contrib/autograph
parent7acebcce1839dc3a22bf27d0814e80ea68cedde3 (diff)
Remove an unneeded check and fix an off by one error in the rewritten line numbers. Adds tests to check for this and other basic error rewriting cases.
PiperOrigin-RevId: 206786091
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/core/errors.py146
-rw-r--r--tensorflow/contrib/autograph/examples/integration_tests/BUILD12
-rw-r--r--tensorflow/contrib/autograph/examples/integration_tests/errors_test.py162
-rw-r--r--tensorflow/contrib/autograph/pyct/origin_info.py2
4 files changed, 240 insertions, 82 deletions
diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/contrib/autograph/core/errors.py
index c219b372c1..5a57d57e7d 100644
--- a/tensorflow/contrib/autograph/core/errors.py
+++ b/tensorflow/contrib/autograph/core/errors.py
@@ -33,8 +33,6 @@ import traceback
from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.python.framework import errors_impl
-from tensorflow.python.util import tf_inspect
-
# TODO(mdan): Add a superclass common to all errors.
@@ -68,47 +66,29 @@ class TfRuntimeError(Exception):
return message + ''.join(traceback.format_list(self.custom_traceback))
-def _rewrite_tb(source_map, tb, filter_function_name=None):
+def _rewrite_tb(source_map, tb):
"""Rewrites code references in a traceback.
Args:
source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping
locations to their origin
tb: List[Tuple[Text, Text, Text, Text]], consistent with
- traceback.extract_tb
- filter_function_name: Optional[Text], allows restricting restricts the
- frames to rewrite to a particular function name
+ traceback.extract_tb.
Returns:
List[Tuple[Text, Text, Text, Text]], the rewritten traceback
"""
new_tb = []
for frame in tb:
- filename, lineno, function_name, _ = frame
+ filename, lineno, _, _ = frame
loc = origin_info.LineLocation(filename, lineno)
origin = source_map.get(loc)
- # TODO(mdan): We shouldn't need the function name at all.
- # filename + lineno should be sufficient, even if there are multiple source
- # maps.
if origin is not None:
- if filter_function_name == function_name or filter_function_name is None:
- new_tb.append(origin.as_frame())
- else:
- new_tb.append(frame)
+ new_tb.append(origin.as_frame())
else:
new_tb.append(frame)
return new_tb
-# TODO(znado): Make more robust to name changes in the rewriting logic.
-def _remove_rewrite_frames(tb):
- """Remove stack frames containing the error rewriting logic."""
- cleaned_tb = []
- for f in tb:
- if 'ag__.rewrite_graph_construction_error' not in f[3]:
- cleaned_tb.append(f)
- return cleaned_tb
-
-
# TODO(mdan): rename to raise_*
def rewrite_graph_construction_error(source_map):
"""Rewrites errors raised by non-AG APIs inside AG generated code.
@@ -132,20 +112,17 @@ def rewrite_graph_construction_error(source_map):
_, original_error, e_traceback = error_info
assert original_error is not None
try:
- _, _, _, func_name, _, _ = tf_inspect.stack()[1]
+ current_traceback = _cut_traceback_loops(source_map,
+ traceback.extract_tb(e_traceback))
if isinstance(original_error, GraphConstructionError):
# TODO(mdan): This is incomplete.
# The error might have bubbled through a non-converted function.
- cleaned_traceback = traceback.extract_tb(e_traceback)
previous_traceback = original_error.custom_traceback
- cleaned_traceback = [cleaned_traceback[0]] + previous_traceback
+ cleaned_traceback = [current_traceback[0]] + previous_traceback
else:
- cleaned_traceback = traceback.extract_tb(e_traceback)
+ cleaned_traceback = current_traceback
- # Remove the frame corresponding to this function call.
- cleaned_traceback = cleaned_traceback[1:]
-
- cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback, func_name)
+ cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)
if isinstance(original_error, GraphConstructionError):
original_error.custom_traceback = cleaned_traceback
@@ -163,6 +140,60 @@ def rewrite_graph_construction_error(source_map):
del e_traceback
+def _cut_traceback_loops(source_map, original_traceback):
+ """Check for cases where we leave a user method and re-enter it.
+
+ This is done by looking at the function names when the filenames are from any
+ files the user code is in. If we find a case where we return to a user method
+ after leaving it then we cut out the frames in between because we assume this
+ means these in between frames are from internal AutoGraph code that shouldn't
+ be included.
+
+ An example of this is:
+
+ File "file1.py", line 57, in my_func
+ ...
+ File "control_flow_ops.py", line 231, in cond
+ ...
+ File "control_flow_ops.py", line 1039, in inner_cond
+ ...
+ File "file1.py", line 68, in my_func
+ ...
+
+ Where we would remove the control_flow_ops.py frames because we re-enter
+ my_func in file1.py.
+
+ The source map keys are (file_path, line_number) so get the set of all user
+ file_paths.
+
+ Args:
+ source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping
+ locations to their origin
+ original_traceback: List[Tuple[Text, Text, Text, Text]], consistent with
+ traceback.extract_tb.
+
+ Returns:
+ List[Tuple[Text, Text, Text, Text]], the traceback with any loops removed.
+ """
+ all_user_files = set(loc.filename for loc in source_map)
+ cleaned_traceback = []
+ last_user_frame_index = None
+ last_user_user_file_path = None
+ # TODO(mdan): Simplify this logic.
+ for fi, frame in enumerate(original_traceback):
+ frame_file_path, lineno, _, _ = frame
+ src_map_key = origin_info.LineLocation(frame_file_path, lineno)
+ if frame_file_path in all_user_files:
+ if src_map_key in source_map:
+ if (last_user_frame_index is not None and
+ last_user_user_file_path == frame_file_path):
+ cleaned_traceback = cleaned_traceback[:last_user_frame_index]
+ last_user_frame_index = fi
+ last_user_user_file_path = frame_file_path
+ cleaned_traceback.append(frame)
+ return cleaned_traceback
+
+
# TODO(mdan): This should be consistent with rewrite_graph_construction_error
# Both should either raise or return.
def rewrite_tf_runtime_error(error, source_map):
@@ -175,56 +206,9 @@ def rewrite_tf_runtime_error(error, source_map):
Returns:
TfRuntimeError, the rewritten underlying error.
"""
- # Check for cases where we leave a user method and re-enter it in the
- # traceback. This is done by looking at the function names when the
- # filenames are from any files the user code is in. If we find a case where
- # we return to a user method after leaving it then we cut out the frames in
- # between because we assume this means these in between frames are from
- # internal AutoGraph code that shouldn't be included.
- #
- # An example of this is:
- #
- # File "file1.py", line 57, in my_func
- # ...
- # File "control_flow_ops.py", line 231, in cond
- # ...
- # File "control_flow_ops.py", line 1039, in inner_cond
- # ...
- # File "file1.py", line 68, in my_func
- # ...
- #
- # Where we would remove the control_flow_ops.py frames because we re-enter
- # my_func in file1.py.
- #
- # The source map keys are (file_path, line_number) so get the set of all user
- # file_paths.
try:
- all_user_files = set(loc.filename for loc in source_map)
- cleaned_traceback = []
- last_user_frame_index = None
- last_user_user_file_path = None
- last_user_user_fn_name = None
- # TODO(mdan): Simplify this logic.
- for fi, frame in enumerate(error.op.traceback):
- frame_file_path, lineno, _, _ = frame
- lineno -= 1 # Frame line numbers are 1-based.
- src_map_key = origin_info.LineLocation(frame_file_path, lineno)
- if frame_file_path in all_user_files:
- if src_map_key in source_map:
- original_fn_name = source_map[src_map_key].function_name
- if (last_user_frame_index is not None and
- last_user_user_file_path == frame_file_path):
- if last_user_user_fn_name == original_fn_name:
- cleaned_traceback = cleaned_traceback[:last_user_frame_index]
- else:
- cleaned_traceback = cleaned_traceback[:last_user_frame_index + 1]
- last_user_user_fn_name = original_fn_name
- else:
- last_user_user_fn_name = None
- last_user_frame_index = fi
- last_user_user_file_path = frame_file_path
- cleaned_traceback.append(frame)
-
+ cleaned_traceback = _cut_traceback_loops(source_map, error.op.traceback)
+ # cleaned_traceback = error.op.traceback
cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)
op_name = error.op.name
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
index d20c17b63b..0ab4e2eb5e 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/BUILD
+++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
@@ -17,6 +17,18 @@ filegroup(
)
py_test(
+ name = "errors_test",
+ srcs = [
+ "errors_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
name = "keras_test",
srcs = [
"keras_test.py",
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py
new file mode 100644
index 0000000000..f4b9159942
--- /dev/null
+++ b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py
@@ -0,0 +1,162 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Error traceback rewriting integration tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib import autograph as ag
+from tensorflow.python.util import tf_inspect
+
+
+class ErrorsTest(tf.test.TestCase):
+
+ def test_graph_construction_error_rewriting_call_tree(self):
+
+ def innermost(x):
+ if x > 0:
+ return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
+ return tf.zeros((2, 3))
+
+ def inner_caller():
+ return innermost(1.0)
+
+ def caller():
+ return inner_caller()
+
+ with self.assertRaises(ag.GraphConstructionError) as error:
+ graph = ag.to_graph(caller)
+ graph()
+ expected = error.exception
+ custom_traceback = expected.custom_traceback
+ found_correct_filename = False
+ num_innermost_names = 0
+ num_inner_caller_names = 0
+ num_caller_names = 0
+ ag_output_filename = tf_inspect.getsourcefile(graph)
+ for frame in custom_traceback:
+ filename, _, fn_name, _ = frame
+ self.assertFalse('control_flow_ops.py' in filename)
+ self.assertFalse(ag_output_filename in filename)
+ found_correct_filename |= __file__ in filename
+ self.assertNotEqual('tf__test_fn', fn_name)
+ num_innermost_names += int('innermost' == fn_name)
+ self.assertNotEqual('tf__inner_caller', fn_name)
+ num_inner_caller_names += int('inner_caller' == fn_name)
+ self.assertNotEqual('tf__caller', fn_name)
+ num_caller_names += int('caller' == fn_name)
+ self.assertTrue(found_correct_filename)
+ self.assertEqual(num_innermost_names, 1)
+ self.assertEqual(num_inner_caller_names, 1)
+ self.assertEqual(num_caller_names, 1)
+
+ def test_graph_construction_error_rewriting_class(self):
+
+ class TestClass(object):
+
+ def test_fn(self):
+ return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
+
+ def inner_caller(self):
+ return self.test_fn()
+
+ def caller(self):
+ return self.inner_caller()
+
+ # Note we expect a TypeError here because the traceback will not be
+ # rewritten for classes.
+ with self.assertRaises(TypeError):
+ graph = ag.to_graph(TestClass)
+ graph().caller()
+
+ def test_runtime_error_rewriting(self):
+
+ def g(x, s):
+ while tf.reduce_sum(x) > s:
+ x //= 0
+ return x
+
+ def test_fn(x):
+ return g(x, 10)
+
+ compiled_fn = ag.to_graph(test_fn)
+
+ with self.assertRaises(ag.TfRuntimeError) as error:
+ with self.test_session() as sess:
+ x = compiled_fn(tf.constant([4, 8]))
+ with ag.improved_errors(compiled_fn):
+ sess.run(x)
+ expected = error.exception
+ custom_traceback = expected.custom_traceback
+ found_correct_filename = False
+ num_test_fn_frames = 0
+ num_g_frames = 0
+ ag_output_filename = tf_inspect.getsourcefile(compiled_fn)
+ for frame in custom_traceback:
+ filename, _, fn_name, source_code = frame
+ self.assertFalse(ag_output_filename in filename)
+ self.assertFalse('control_flow_ops.py' in filename)
+ self.assertFalse('ag__.' in fn_name)
+ self.assertFalse('tf__g' in fn_name)
+ self.assertFalse('tf__test_fn' in fn_name)
+ found_correct_filename |= __file__ in filename
+ num_test_fn_frames += int('test_fn' == fn_name and
+ 'return g(x, 10)' in source_code)
+ # This makes sure that the code is correctly rewritten from "x_1 //= 0" to
+ # "x //= 0".
+ num_g_frames += int('g' == fn_name and 'x //= 0' in source_code)
+ self.assertTrue(found_correct_filename)
+ self.assertEqual(num_test_fn_frames, 1)
+ self.assertEqual(num_g_frames, 1)
+
+ def test_runtime_error_rewriting_nested(self):
+
+ def test_fn(x):
+
+ def g(y):
+ return y**2 // 0
+
+ s = 0
+ for xi in x:
+ s += g(xi)
+ return s
+
+ compiled_fn = ag.to_graph(test_fn)
+
+ # TODO(b/111408261): Nested functions currently do not rewrite correctly,
+ # when they do we should change this test to check for the same traceback
+ # properties as the other tests. This should throw a runtime error with a
+ # frame with "g" as the function name but because we don't yet add
+ # try/except blocks to inner functions the name is "tf__g".
+ with self.assertRaises(ag.TfRuntimeError) as error:
+ with self.test_session() as sess:
+ x = compiled_fn(tf.constant([4, 8]))
+ with ag.improved_errors(compiled_fn):
+ sess.run(x)
+ expected = error.exception
+ custom_traceback = expected.custom_traceback
+ num_tf_g_frames = 0
+ for frame in custom_traceback:
+ _, _, fn_name, _ = frame
+ self.assertNotEqual('g', fn_name)
+ num_tf_g_frames += int('tf__g' == fn_name)
+ self.assertEqual(num_tf_g_frames, 1)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py
index 1aad2f47df..9f98e48a6a 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info.py
+++ b/tensorflow/contrib/autograph/pyct/origin_info.py
@@ -162,7 +162,7 @@ def resolve(nodes, source, function=None):
source_code_line = source_lines[lineno_in_body - 1]
if function:
- source_lineno = function_lineno + lineno_in_body
+ source_lineno = function_lineno + lineno_in_body - 1
function_name = function.__name__
else:
source_lineno = lineno_in_body