diff options
Diffstat (limited to 'tensorflow/python/framework/error_interpolation_test.py')
-rw-r--r-- | tensorflow/python/framework/error_interpolation_test.py | 203 |
1 files changed, 193 insertions, 10 deletions
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py index ad448deb62..1e5cb73854 100644 --- a/tensorflow/python/framework/error_interpolation_test.py +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -18,31 +18,214 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + +from tensorflow.python.framework import constant_op from tensorflow.python.framework import error_interpolation +from tensorflow.python.framework import ops +from tensorflow.python.framework import traceable_stack from tensorflow.python.platform import test +from tensorflow.python.util import tf_stack + + +def _make_frame_with_filename(op, idx, filename): + """Return a copy of an existing stack frame with a new filename.""" + stack_frame = list(op._traceback[idx]) + stack_frame[tf_stack.TB_FILENAME] = filename + return tuple(stack_frame) + + +def _modify_op_stack_with_filenames(op, num_user_frames, user_filename, + num_inner_tf_frames): + """Replace op._traceback with a new traceback using special filenames.""" + tf_filename = "%d" + error_interpolation._BAD_FILE_SUBSTRINGS[0] + user_filename = os.path.join("%d", "my_favorite_file.py") + + num_requested_frames = num_user_frames + num_inner_tf_frames + num_actual_frames = len(op._traceback) + num_outer_frames = num_actual_frames - num_requested_frames + assert num_requested_frames <= num_actual_frames, "Too few real frames." + + # The op's traceback has outermost frame at index 0. + stack = [] + for idx in range(0, num_outer_frames): + stack.append(op._traceback[idx]) + for idx in range(len(stack), len(stack)+num_user_frames): + stack.append(_make_frame_with_filename(op, idx, user_filename % idx)) + for idx in range(len(stack), len(stack)+num_inner_tf_frames): + stack.append(_make_frame_with_filename(op, idx, tf_filename % idx)) + op._traceback = stack + + +def assert_node_in_colocation_summary(test_obj, colocation_summary_string, + name, filename="", lineno=""): + lineno = str(lineno) + name_phrase = "colocate_with(%s)" % name + for term in [name_phrase, filename, lineno]: + test_obj.assertIn(term, colocation_summary_string) + test_obj.assertNotIn("loc:@", colocation_summary_string) + + +class ComputeColocationSummaryFromOpTest(test.TestCase): + + def testCorrectFormatWithActiveColocations(self): + t_obj_1 = traceable_stack.TraceableObject(None, + filename="test_1.py", + lineno=27) + t_obj_2 = traceable_stack.TraceableObject(None, + filename="test_2.py", + lineno=38) + colocation_dict = { + "test_node_1": t_obj_1, + "test_node_2": t_obj_2, + } + summary = error_interpolation._compute_colocation_summary_from_dict( + colocation_dict, prefix=" ") + assert_node_in_colocation_summary(self, + summary, + name="test_node_1", + filename="test_1.py", + lineno=27) + assert_node_in_colocation_summary(self, summary, + name="test_node_2", + filename="test_2.py", + lineno=38) + + def testCorrectFormatWhenNoColocationsWereActive(self): + colocation_dict = {} + summary = error_interpolation._compute_colocation_summary_from_dict( + colocation_dict, prefix=" ") + self.assertIn("No node-device colocations", summary) class InterpolateTest(test.TestCase): + def setUp(self): + # Add nodes to the graph for retrieval by name later. + constant_op.constant(1, name="One") + constant_op.constant(2, name="Two") + three = constant_op.constant(3, name="Three") + self.graph = three.graph + + # Change the list of bad file substrings so that constant_op.py is chosen + # as the defining stack frame for constant_op.constant ops. + self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS + error_interpolation._BAD_FILE_SUBSTRINGS = [ + "%sops.py" % os.sep, + "%sutil" % os.sep, + ] + + def tearDown(self): + error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings + + def testFindIndexOfDefiningFrameForOp(self): + local_op = constant_op.constant(42).op + user_filename = "hope.py" + _modify_op_stack_with_filenames(local_op, + num_user_frames=3, + user_filename=user_filename, + num_inner_tf_frames=5) + idx = error_interpolation._find_index_of_defining_frame_for_op(local_op) + # Expected frame is 6th from the end because there are 5 inner frames witih + # TF filenames. + expected_frame = len(local_op._traceback) - 6 + self.assertEqual(expected_frame, idx) + + def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self): + local_op = constant_op.constant(43).op + # Truncate stack to known length. + local_op._traceback = local_op._traceback[:7] + # Ensure all frames look like TF frames. + _modify_op_stack_with_filenames(local_op, + num_user_frames=0, + user_filename="user_file.py", + num_inner_tf_frames=7) + idx = error_interpolation._find_index_of_defining_frame_for_op(local_op) + self.assertEqual(0, idx) + def testNothingToDo(self): normal_string = "This is just a normal string" - interpolated_string = error_interpolation.interpolate(normal_string) + interpolated_string = error_interpolation.interpolate(normal_string, + self.graph) self.assertEqual(interpolated_string, normal_string) def testOneTag(self): - one_tag_string = "^^node:Foo:${file}^^" - interpolated_string = error_interpolation.interpolate(one_tag_string) - self.assertEqual(interpolated_string, "${file}") + one_tag_string = "^^node:Two:${file}^^" + interpolated_string = error_interpolation.interpolate(one_tag_string, + self.graph) + self.assertTrue(interpolated_string.endswith("constant_op.py"), + "interpolated_string '%s' did not end with constant_op.py" + % interpolated_string) + + def testOneTagWithAFakeNameResultsInPlaceholders(self): + one_tag_string = "^^node:MinusOne:${file}^^" + interpolated_string = error_interpolation.interpolate(one_tag_string, + self.graph) + self.assertEqual(interpolated_string, "<NA>") def testTwoTagsNoSeps(self): - two_tags_no_seps = "^^node:Foo:${file}^^^^node:Bar:${line}^^" - interpolated_string = error_interpolation.interpolate(two_tags_no_seps) - self.assertEqual(interpolated_string, "${file}${line}") + two_tags_no_seps = "^^node:One:${file}^^^^node:Three:${line}^^" + interpolated_string = error_interpolation.interpolate(two_tags_no_seps, + self.graph) + self.assertRegexpMatches(interpolated_string, "constant_op.py[0-9]+") def testTwoTagsWithSeps(self): - two_tags_with_seps = "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789" - interpolated_string = error_interpolation.interpolate(two_tags_with_seps) - self.assertEqual(interpolated_string, "123${file}456${line}789") + two_tags_with_seps = ";;;^^node:Two:${file}^^,,,^^node:Three:${line}^^;;;" + interpolated_string = error_interpolation.interpolate(two_tags_with_seps, + self.graph) + expected_regex = "^;;;.*constant_op.py,,,[0-9]*;;;$" + self.assertRegexpMatches(interpolated_string, expected_regex) + + +class InterpolateColocationSummaryTest(test.TestCase): + + def setUp(self): + # Add nodes to the graph for retrieval by name later. + node_one = constant_op.constant(1, name="One") + node_two = constant_op.constant(2, name="Two") + + # node_three has one colocation group, obviously. + with ops.colocate_with(node_one): + node_three = constant_op.constant(3, name="Three_with_one") + + # node_four has one colocation group even though three is (transitively) + # colocated with one. + with ops.colocate_with(node_three): + constant_op.constant(4, name="Four_with_three") + + # node_five has two colocation groups because one and two are not colocated. + with ops.colocate_with(node_two): + with ops.colocate_with(node_one): + constant_op.constant(5, name="Five_with_one_with_two") + + self.graph = node_three.graph + + def testNodeThreeHasColocationInterpolation(self): + message = "^^node:Three_with_one:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + assert_node_in_colocation_summary(self, result, name="One") + + def testNodeFourHasColocationInterpolationForNodeThreeOnly(self): + message = "^^node:Four_with_three:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + assert_node_in_colocation_summary(self, result, name="Three_with_one") + self.assertNotIn( + "One", result, + "Node One should not appear in Four_with_three's summary:\n%s" + % result) + + def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self): + message = "^^node:Five_with_one_with_two:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + assert_node_in_colocation_summary(self, result, name="One") + assert_node_in_colocation_summary(self, result, name="Two") + + def testColocationInterpolationForNodeLackingColocation(self): + message = "^^node:One:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + self.assertIn("No node-device colocations", result) + self.assertNotIn("One", result) + self.assertNotIn("Two", result) if __name__ == "__main__": |