diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-31 10:36:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-31 10:39:25 -0700 |
commit | 1c498283e01137923b526d12ac77d3788f4fa912 (patch) | |
tree | 1e1138e306e7ce029f0889aaafea9a0a316601d9 /tensorflow/contrib/autograph | |
parent | 7acebcce1839dc3a22bf27d0814e80ea68cedde3 (diff) |
Remove an unneeded check and fix an off by one error in the rewritten line numbers. Adds tests to check for this and other basic error rewriting cases.
PiperOrigin-RevId: 206786091
Diffstat (limited to 'tensorflow/contrib/autograph')
4 files changed, 240 insertions, 82 deletions
diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/contrib/autograph/core/errors.py index c219b372c1..5a57d57e7d 100644 --- a/tensorflow/contrib/autograph/core/errors.py +++ b/tensorflow/contrib/autograph/core/errors.py @@ -33,8 +33,6 @@ import traceback from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.python.framework import errors_impl -from tensorflow.python.util import tf_inspect - # TODO(mdan): Add a superclass common to all errors. @@ -68,47 +66,29 @@ class TfRuntimeError(Exception): return message + ''.join(traceback.format_list(self.custom_traceback)) -def _rewrite_tb(source_map, tb, filter_function_name=None): +def _rewrite_tb(source_map, tb): """Rewrites code references in a traceback. Args: source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping locations to their origin tb: List[Tuple[Text, Text, Text, Text]], consistent with - traceback.extract_tb - filter_function_name: Optional[Text], allows restricting restricts the - frames to rewrite to a particular function name + traceback.extract_tb. Returns: List[Tuple[Text, Text, Text, Text]], the rewritten traceback """ new_tb = [] for frame in tb: - filename, lineno, function_name, _ = frame + filename, lineno, _, _ = frame loc = origin_info.LineLocation(filename, lineno) origin = source_map.get(loc) - # TODO(mdan): We shouldn't need the function name at all. - # filename + lineno should be sufficient, even if there are multiple source - # maps. if origin is not None: - if filter_function_name == function_name or filter_function_name is None: - new_tb.append(origin.as_frame()) - else: - new_tb.append(frame) + new_tb.append(origin.as_frame()) else: new_tb.append(frame) return new_tb -# TODO(znado): Make more robust to name changes in the rewriting logic. -def _remove_rewrite_frames(tb): - """Remove stack frames containing the error rewriting logic.""" - cleaned_tb = [] - for f in tb: - if 'ag__.rewrite_graph_construction_error' not in f[3]: - cleaned_tb.append(f) - return cleaned_tb - - # TODO(mdan): rename to raise_* def rewrite_graph_construction_error(source_map): """Rewrites errors raised by non-AG APIs inside AG generated code. @@ -132,20 +112,17 @@ def rewrite_graph_construction_error(source_map): _, original_error, e_traceback = error_info assert original_error is not None try: - _, _, _, func_name, _, _ = tf_inspect.stack()[1] + current_traceback = _cut_traceback_loops(source_map, + traceback.extract_tb(e_traceback)) if isinstance(original_error, GraphConstructionError): # TODO(mdan): This is incomplete. # The error might have bubbled through a non-converted function. - cleaned_traceback = traceback.extract_tb(e_traceback) previous_traceback = original_error.custom_traceback - cleaned_traceback = [cleaned_traceback[0]] + previous_traceback + cleaned_traceback = [current_traceback[0]] + previous_traceback else: - cleaned_traceback = traceback.extract_tb(e_traceback) + cleaned_traceback = current_traceback - # Remove the frame corresponding to this function call. - cleaned_traceback = cleaned_traceback[1:] - - cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback, func_name) + cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) if isinstance(original_error, GraphConstructionError): original_error.custom_traceback = cleaned_traceback @@ -163,6 +140,60 @@ def rewrite_graph_construction_error(source_map): del e_traceback +def _cut_traceback_loops(source_map, original_traceback): + """Check for cases where we leave a user method and re-enter it. + + This is done by looking at the function names when the filenames are from any + files the user code is in. If we find a case where we return to a user method + after leaving it then we cut out the frames in between because we assume this + means these in between frames are from internal AutoGraph code that shouldn't + be included. + + An example of this is: + + File "file1.py", line 57, in my_func + ... + File "control_flow_ops.py", line 231, in cond + ... + File "control_flow_ops.py", line 1039, in inner_cond + ... + File "file1.py", line 68, in my_func + ... + + Where we would remove the control_flow_ops.py frames because we re-enter + my_func in file1.py. + + The source map keys are (file_path, line_number) so get the set of all user + file_paths. + + Args: + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping + locations to their origin + original_traceback: List[Tuple[Text, Text, Text, Text]], consistent with + traceback.extract_tb. + + Returns: + List[Tuple[Text, Text, Text, Text]], the traceback with any loops removed. + """ + all_user_files = set(loc.filename for loc in source_map) + cleaned_traceback = [] + last_user_frame_index = None + last_user_user_file_path = None + # TODO(mdan): Simplify this logic. + for fi, frame in enumerate(original_traceback): + frame_file_path, lineno, _, _ = frame + src_map_key = origin_info.LineLocation(frame_file_path, lineno) + if frame_file_path in all_user_files: + if src_map_key in source_map: + if (last_user_frame_index is not None and + last_user_user_file_path == frame_file_path): + cleaned_traceback = cleaned_traceback[:last_user_frame_index] + last_user_frame_index = fi + last_user_user_file_path = frame_file_path + cleaned_traceback.append(frame) + return cleaned_traceback + + # TODO(mdan): This should be consistent with rewrite_graph_construction_error # Both should either raise or return. def rewrite_tf_runtime_error(error, source_map): @@ -175,56 +206,9 @@ def rewrite_tf_runtime_error(error, source_map): Returns: TfRuntimeError, the rewritten underlying error. """ - # Check for cases where we leave a user method and re-enter it in the - # traceback. This is done by looking at the function names when the - # filenames are from any files the user code is in. If we find a case where - # we return to a user method after leaving it then we cut out the frames in - # between because we assume this means these in between frames are from - # internal AutoGraph code that shouldn't be included. - # - # An example of this is: - # - # File "file1.py", line 57, in my_func - # ... - # File "control_flow_ops.py", line 231, in cond - # ... - # File "control_flow_ops.py", line 1039, in inner_cond - # ... - # File "file1.py", line 68, in my_func - # ... - # - # Where we would remove the control_flow_ops.py frames because we re-enter - # my_func in file1.py. - # - # The source map keys are (file_path, line_number) so get the set of all user - # file_paths. try: - all_user_files = set(loc.filename for loc in source_map) - cleaned_traceback = [] - last_user_frame_index = None - last_user_user_file_path = None - last_user_user_fn_name = None - # TODO(mdan): Simplify this logic. - for fi, frame in enumerate(error.op.traceback): - frame_file_path, lineno, _, _ = frame - lineno -= 1 # Frame line numbers are 1-based. - src_map_key = origin_info.LineLocation(frame_file_path, lineno) - if frame_file_path in all_user_files: - if src_map_key in source_map: - original_fn_name = source_map[src_map_key].function_name - if (last_user_frame_index is not None and - last_user_user_file_path == frame_file_path): - if last_user_user_fn_name == original_fn_name: - cleaned_traceback = cleaned_traceback[:last_user_frame_index] - else: - cleaned_traceback = cleaned_traceback[:last_user_frame_index + 1] - last_user_user_fn_name = original_fn_name - else: - last_user_user_fn_name = None - last_user_frame_index = fi - last_user_user_file_path = frame_file_path - cleaned_traceback.append(frame) - + cleaned_traceback = _cut_traceback_loops(source_map, error.op.traceback) + # cleaned_traceback = error.op.traceback cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) op_name = error.op.name diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD index d20c17b63b..0ab4e2eb5e 100644 --- a/tensorflow/contrib/autograph/examples/integration_tests/BUILD +++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD @@ -17,6 +17,18 @@ filegroup( ) py_test( + name = "errors_test", + srcs = [ + "errors_test.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( name = "keras_test", srcs = [ "keras_test.py", diff --git a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py new file mode 100644 index 0000000000..f4b9159942 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py @@ -0,0 +1,162 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Error traceback rewriting integration tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib import autograph as ag +from tensorflow.python.util import tf_inspect + + +class ErrorsTest(tf.test.TestCase): + + def test_graph_construction_error_rewriting_call_tree(self): + + def innermost(x): + if x > 0: + return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32) + return tf.zeros((2, 3)) + + def inner_caller(): + return innermost(1.0) + + def caller(): + return inner_caller() + + with self.assertRaises(ag.GraphConstructionError) as error: + graph = ag.to_graph(caller) + graph() + expected = error.exception + custom_traceback = expected.custom_traceback + found_correct_filename = False + num_innermost_names = 0 + num_inner_caller_names = 0 + num_caller_names = 0 + ag_output_filename = tf_inspect.getsourcefile(graph) + for frame in custom_traceback: + filename, _, fn_name, _ = frame + self.assertFalse('control_flow_ops.py' in filename) + self.assertFalse(ag_output_filename in filename) + found_correct_filename |= __file__ in filename + self.assertNotEqual('tf__test_fn', fn_name) + num_innermost_names += int('innermost' == fn_name) + self.assertNotEqual('tf__inner_caller', fn_name) + num_inner_caller_names += int('inner_caller' == fn_name) + self.assertNotEqual('tf__caller', fn_name) + num_caller_names += int('caller' == fn_name) + self.assertTrue(found_correct_filename) + self.assertEqual(num_innermost_names, 1) + self.assertEqual(num_inner_caller_names, 1) + self.assertEqual(num_caller_names, 1) + + def test_graph_construction_error_rewriting_class(self): + + class TestClass(object): + + def test_fn(self): + return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32) + + def inner_caller(self): + return self.test_fn() + + def caller(self): + return self.inner_caller() + + # Note we expect a TypeError here because the traceback will not be + # rewritten for classes. + with self.assertRaises(TypeError): + graph = ag.to_graph(TestClass) + graph().caller() + + def test_runtime_error_rewriting(self): + + def g(x, s): + while tf.reduce_sum(x) > s: + x //= 0 + return x + + def test_fn(x): + return g(x, 10) + + compiled_fn = ag.to_graph(test_fn) + + with self.assertRaises(ag.TfRuntimeError) as error: + with self.test_session() as sess: + x = compiled_fn(tf.constant([4, 8])) + with ag.improved_errors(compiled_fn): + sess.run(x) + expected = error.exception + custom_traceback = expected.custom_traceback + found_correct_filename = False + num_test_fn_frames = 0 + num_g_frames = 0 + ag_output_filename = tf_inspect.getsourcefile(compiled_fn) + for frame in custom_traceback: + filename, _, fn_name, source_code = frame + self.assertFalse(ag_output_filename in filename) + self.assertFalse('control_flow_ops.py' in filename) + self.assertFalse('ag__.' in fn_name) + self.assertFalse('tf__g' in fn_name) + self.assertFalse('tf__test_fn' in fn_name) + found_correct_filename |= __file__ in filename + num_test_fn_frames += int('test_fn' == fn_name and + 'return g(x, 10)' in source_code) + # This makes sure that the code is correctly rewritten from "x_1 //= 0" to + # "x //= 0". + num_g_frames += int('g' == fn_name and 'x //= 0' in source_code) + self.assertTrue(found_correct_filename) + self.assertEqual(num_test_fn_frames, 1) + self.assertEqual(num_g_frames, 1) + + def test_runtime_error_rewriting_nested(self): + + def test_fn(x): + + def g(y): + return y**2 // 0 + + s = 0 + for xi in x: + s += g(xi) + return s + + compiled_fn = ag.to_graph(test_fn) + + # TODO(b/111408261): Nested functions currently do not rewrite correctly, + # when they do we should change this test to check for the same traceback + # properties as the other tests. This should throw a runtime error with a + # frame with "g" as the function name but because we don't yet add + # try/except blocks to inner functions the name is "tf__g". + with self.assertRaises(ag.TfRuntimeError) as error: + with self.test_session() as sess: + x = compiled_fn(tf.constant([4, 8])) + with ag.improved_errors(compiled_fn): + sess.run(x) + expected = error.exception + custom_traceback = expected.custom_traceback + num_tf_g_frames = 0 + for frame in custom_traceback: + _, _, fn_name, _ = frame + self.assertNotEqual('g', fn_name) + num_tf_g_frames += int('tf__g' == fn_name) + self.assertEqual(num_tf_g_frames, 1) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py index 1aad2f47df..9f98e48a6a 100644 --- a/tensorflow/contrib/autograph/pyct/origin_info.py +++ b/tensorflow/contrib/autograph/pyct/origin_info.py @@ -162,7 +162,7 @@ def resolve(nodes, source, function=None): source_code_line = source_lines[lineno_in_body - 1] if function: - source_lineno = function_lineno + lineno_in_body + source_lineno = function_lineno + lineno_in_body - 1 function_name = function.__name__ else: source_lineno = lineno_in_body |