diff options
author | Shanqing Cai <cais@google.com> | 2017-03-01 13:47:02 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-01 14:16:15 -0800 |
commit | 1c5b5e8d8deb10ffc613aff016f4caaed7dc5b30 (patch) | |
tree | ab3e169c8b0d36a9ed14aed5a5ccff438e29e323 | |
parent | 4111527988c4aaf21791df420b58565e49da4e3b (diff) |
tfdbg: Add ability to inspect Python source against TF graphs
Change: 148926382
-rw-r--r-- | tensorflow/docs_src/programmers_guide/debugger.md | 12 | ||||
-rw-r--r-- | tensorflow/python/debug/BUILD | 27 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/analyzer_cli.py | 113 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/analyzer_cli_test.py | 159 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/debugger_cli_common.py | 3 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/debugger_cli_common_test.py | 6 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/debug_data.py | 11 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/source_utils.py | 93 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/source_utils_test.py | 203 |
9 files changed, 618 insertions, 9 deletions
diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md index e66873fa58..19820421d3 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/programmers_guide/debugger.md @@ -128,6 +128,9 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at | `lo -r hidden/Relu:0` | List the recipients of the output of the node `hidden/Relu`, recursively—i.e., the output recipient tree. | | `lt -n softmax.*` | List all dumped tensors whose names match the regular-expression pattern `softmax.*`. | | `lt -t MatMul` | List all dumped tensors whose node type is `MatMul`. | +| `ps /path/to/source.py` | Print the Python source file source.py, with the lines annotated with the ops created at each of them, respectively. | +| `ps -t /path/to/source.py` | Same as the command above, but perform annotation using dumped Tensors, instead of ops. | +| `ps -b 30 /path/to/source.py` | Annotate source.py beginning at line 30. | | `run_info` or `ri` | Display information about the current run, including fetches and feeds. | | `help` | Print general help information listing all available **tfdbg** commands and their flags. | | `help lt` | Print the help information for the `lt` command. | @@ -238,6 +241,9 @@ to show the traceback of the node's construction: tfdbg> ni -t cross_entropy/Log ``` +The `-t` flag is used by default, if you use the clickable "node_info" menu item +at the top of the screen. + From the traceback, you can see that the op is constructed at line 109 of [`debug_mnist.py`](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_mnist.py): @@ -245,6 +251,12 @@ From the traceback, you can see that the op is constructed at line 109 of diff = y_ * tf.log(y) ``` +TIP: tfdbg lets you view a Python source file with its lines annotated with +the ops or Tensors created by them. To use this feature, +simply click the underlined line numbers in the stack trace output of the +`ni -t <op_name>` commands, or use the `ps` (or `print_source`) command such as: +`ps /path/to/source.py` + Apply a value clipping on the input to @{tf.log} to resolve this problem: diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 389c6e8174..929c0dbf42 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -47,6 +47,12 @@ py_library( ) py_library( + name = "source_utils", + srcs = ["lib/source_utils.py"], + srcs_version = "PY2AND3", +) + +py_library( name = "stepper", srcs = ["lib/stepper.py"], srcs_version = "PY2AND3", @@ -121,7 +127,9 @@ py_library( ":command_parser", ":debug_data", ":debugger_cli_common", + ":source_utils", ":ui_factory", + "//third_party/py/numpy", "@six_archive//:six", ], ) @@ -323,6 +331,25 @@ py_test( ], ) +py_test( + name = "source_utils_test", + size = "small", + srcs = [ + "lib/source_utils_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":debug_data", + ":debug_utils", + ":source_utils", + "//tensorflow/python:client", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "stepper_test", size = "small", diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py index 59db992442..29830e86e4 100644 --- a/tensorflow/python/debug/cli/analyzer_cli.py +++ b/tensorflow/python/debug/cli/analyzer_cli.py @@ -27,6 +27,7 @@ import argparse import copy import re +import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.debug.cli import cli_shared @@ -34,7 +35,9 @@ from tensorflow.python.debug.cli import command_parser from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import ui_factory from tensorflow.python.debug.lib import debug_data +from tensorflow.python.debug.lib import source_utils +RL = debugger_cli_common.RichLine # String constants for the depth-dependent hanging indent at the beginning # of each line. @@ -89,7 +92,7 @@ def _add_main_menu(output, menu.append( debugger_cli_common.MenuItem( "node_info", - "node_info -a -d %s" % node_name, + "node_info -a -d -t %s" % node_name, enabled=enable_node_info)) menu.append( debugger_cli_common.MenuItem( @@ -303,7 +306,6 @@ class DebugAnalyzer(object): help="Numerical ranges to highlight tensor elements in. " "Examples: -r 0,1e-8, -r [-0.1,0.1], " "-r \"[[-inf, -0.1], [0.1, inf]]\"") - ap.add_argument( "-a", "--all", @@ -312,6 +314,37 @@ class DebugAnalyzer(object): help="Print the tensor in its entirety, i.e., do not use ellipses.") self._arg_parsers["print_tensor"] = ap + # Parser for print_source. + ap = argparse.ArgumentParser( + description="Print a Python source file with overlaid debug " + "information, including the nodes (ops) or Tensors created at the " + "source lines.", + usage=argparse.SUPPRESS) + ap.add_argument( + "source_file_path", + type=str, + help="Path to the source file.") + ap.add_argument( + "-t", + "--tensors", + dest="tensors", + action="store_true", + help="Label lines with dumped Tensors, instead of ops.") + ap.add_argument( + "-m", + "--max_elements_per_line", + type=int, + default=10, + help="Maximum number of elements (ops or Tensors) to show per source " + "line.") + ap.add_argument( + "-b", + "--line_begin", + type=int, + default=1, + help="Print source beginning at line number (1-based.)") + self._arg_parsers["print_source"] = ap + # TODO(cais): Implement list_nodes. def add_tensor_filter(self, filter_name, filter_callable): @@ -709,15 +742,20 @@ class DebugAnalyzer(object): construction. """ - lines = ["", "", "Traceback of node construction:"] - font_attr_segs = {len(lines) - 1: [(0, len(lines[-1]), "bold")]} + lines = [RL(""), RL(""), RL("Traceback of node construction:", "bold")] try: node_stack = self._debug_dump.node_traceback(node_name) for depth, (file_path, line, function_name, text) in enumerate( node_stack): lines.append("%d: %s" % (depth, file_path)) - lines.append(" Line: %d" % line) + + attribute = debugger_cli_common.MenuItem( + "", "ps %s -b %d" % (file_path, line)) if text else None + line_number_line = RL(" ") + line_number_line += RL("Line: %d" % line, attribute) + lines.append(line_number_line) + lines.append(" Function: %s" % function_name) lines.append(" Text: " + (("\"%s\"" % text) if text else "None")) lines.append("") @@ -726,8 +764,7 @@ class DebugAnalyzer(object): except LookupError: lines.append("(Unavailable because no Python graph has been loaded)") - return debugger_cli_common.RichTextLines(lines, - font_attr_segs=font_attr_segs) + return debugger_cli_common.rich_text_lines_from_rich_line_list(lines) def list_inputs(self, args, screen_info=None): """Command handler for inputs. @@ -942,6 +979,63 @@ class DebugAnalyzer(object): return output + def print_source(self, args, screen_info=None): + """Print the content of a source file.""" + del screen_info # Unused. + + parsed = self._arg_parsers["print_source"].parse_args(args) + + source_annotation = source_utils.annotate_source( + self._debug_dump, + parsed.source_file_path, + do_dumped_tensors=parsed.tensors, + min_line=parsed.line_begin) + + with open(parsed.source_file_path, "rU") as f: + source_text = f.read() + + source_lines = source_text.split("\n") + num_lines = len(source_lines) + line_num_width = int(np.ceil(np.log10(num_lines))) + 3 + + labeled_source_lines = [] + if parsed.line_begin > 1: + labeled_source_lines.append( + RL("(... Omitted %d source lines ...)" % (parsed.line_begin - 1), + "bold")) + + for i, line in enumerate(source_lines[parsed.line_begin - 1:]): + annotated_line = RL("L%d" % (i + parsed.line_begin), "yellow") + annotated_line += " " * (line_num_width - len(annotated_line)) + annotated_line += line + labeled_source_lines.append(annotated_line) + + if i + parsed.line_begin in source_annotation: + sorted_elements = sorted(source_annotation[i + parsed.line_begin]) + for k, element in enumerate(sorted_elements): + if k >= parsed.max_elements_per_line: + labeled_source_lines.append( + " (... Omitted %d of %d %s ...)" % ( + len(sorted_elements) - parsed.max_elements_per_line, + len(sorted_elements), + "tensor(s)" if parsed.tensors else "op(s)")) + break + + label = RL(" " * 4) + if self._debug_dump.debug_watch_keys( + debug_data.get_node_name(element)): + attribute = debugger_cli_common.MenuItem("", "pt %s" % element) + else: + attribute = "blue" + + label += RL(element, attribute) + labeled_source_lines.append(label) + + output = debugger_cli_common.rich_text_lines_from_rich_line_list( + labeled_source_lines) + _add_main_menu(output, node_name=None) + return output + def _list_inputs_or_outputs(self, recursive, node_name, @@ -1292,6 +1386,11 @@ def create_analyzer_ui(debug_dump, tensor_filters=None, ui_type="curses"): analyzer.print_tensor, analyzer.get_help("print_tensor"), prefix_aliases=["pt"]) + cli.register_command_handler( + "print_source", + analyzer.print_source, + analyzer.get_help("print_source"), + prefix_aliases=["ps"]) dumped_tensor_names = [] for datum in debug_dump.dumped_tensor_data: diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py index e981fe4f96..454b5e773f 100644 --- a/tensorflow/python/debug/cli/analyzer_cli_test.py +++ b/tensorflow/python/debug/cli/analyzer_cli_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import inspect +import os import shutil import tempfile @@ -39,6 +41,10 @@ from tensorflow.python.platform import googletest from tensorflow.python.platform import test +def line_number_above(): + return inspect.stack()[1][2] - 1 + + def parse_op_and_node(line): """Parse a line containing an op node followed by a node name. @@ -494,6 +500,9 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): else: cls._main_device = "/job:localhost/replica:0/task:0/cpu:0" + cls._curr_file_path = os.path.abspath( + inspect.getfile(inspect.currentframe())) + cls._sess = session.Session() with cls._sess as sess: u_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]]) @@ -502,14 +511,19 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): u_name = "simple_mul_add/u" v_name = "simple_mul_add/v" - u_init = constant_op.constant(u_init_val, shape=[2, 2]) + u_init = constant_op.constant(u_init_val, shape=[2, 2], name="u_init") u = variables.Variable(u_init, name=u_name) - v_init = constant_op.constant(v_init_val, shape=[2, 1]) + cls._u_line_number = line_number_above() + + v_init = constant_op.constant(v_init_val, shape=[2, 1], name="v_init") v = variables.Variable(v_init, name=v_name) + cls._v_line_number = line_number_above() w = math_ops.matmul(u, v, name="simple_mul_add/matmul") + cls._w_line_number = line_number_above() x = math_ops.add(w, w, name="simple_mul_add/add") + cls._x_line_number = line_number_above() u.initializer.run() v.initializer.run() @@ -550,6 +564,11 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): cls._analyzer.print_tensor, cls._analyzer.get_help("print_tensor"), prefix_aliases=["pt"]) + cls._registry.register_command_handler( + "print_source", + cls._analyzer.print_source, + cls._analyzer.get_help("print_source"), + prefix_aliases=["ps"]) @classmethod def tearDownClass(cls): @@ -1116,6 +1135,142 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): "There is no tensor filter named \"bar\""): analyzer.get_tensor_filter("bar") + def _findSourceLine(self, annotated_source, line_number): + """Find line of given line number in annotated source. + + Args: + annotated_source: (debugger_cli_common.RichTextLines) the annotated source + line_number: (int) 1-based line number + + Returns: + (int) If line_number is found, 0-based line index in + annotated_source.lines. Otherwise, None. + """ + + index = None + for i, line in enumerate(annotated_source.lines): + if line.startswith("L%d " % line_number): + index = i + break + return index + + def testPrintSourceForOpNamesWholeFileWorks(self): + self._debug_dump.set_python_graph(self._sess.graph) + out = self._registry.dispatch_command( + "print_source", [self._curr_file_path], screen_info={"cols": 80}) + + # Verify the annotation of the line that creates u. + index = self._findSourceLine(out, self._u_line_number) + self.assertEqual( + ["L%d u = variables.Variable(u_init, name=u_name)" % + self._u_line_number, + " simple_mul_add/u", + " simple_mul_add/u/Assign", + " simple_mul_add/u/read"], + out.lines[index : index + 4]) + self.assertEqual("pt simple_mul_add/u", + out.font_attr_segs[index + 1][0][2].content) + # simple_mul_add/u/Assign is not used in this run because the Variable has + # already been initialized. + self.assertEqual("blue", out.font_attr_segs[index + 2][0][2]) + self.assertEqual("pt simple_mul_add/u/read", + out.font_attr_segs[index + 3][0][2].content) + + # Verify the annotation of the line that creates v. + index = self._findSourceLine(out, self._v_line_number) + self.assertEqual( + ["L%d v = variables.Variable(v_init, name=v_name)" % + self._v_line_number, + " simple_mul_add/v"], + out.lines[index : index + 2]) + self.assertEqual("pt simple_mul_add/v", + out.font_attr_segs[index + 1][0][2].content) + + # Verify the annotation of the line that creates w. + index = self._findSourceLine(out, self._w_line_number) + self.assertEqual( + ["L%d " % self._w_line_number + + "w = math_ops.matmul(u, v, name=\"simple_mul_add/matmul\")", + " simple_mul_add/matmul"], + out.lines[index : index + 2]) + self.assertEqual("pt simple_mul_add/matmul", + out.font_attr_segs[index + 1][0][2].content) + + # Verify the annotation of the line that creates x. + index = self._findSourceLine(out, self._x_line_number) + self.assertEqual( + ["L%d " % self._x_line_number + + "x = math_ops.add(w, w, name=\"simple_mul_add/add\")", + " simple_mul_add/add"], + out.lines[index : index + 2]) + self.assertEqual("pt simple_mul_add/add", + out.font_attr_segs[index + 1][0][2].content) + + def testPrintSourceForTensorNamesWholeFileWorks(self): + self._debug_dump.set_python_graph(self._sess.graph) + out = self._registry.dispatch_command( + "print_source", + [self._curr_file_path, "--tensors"], + screen_info={"cols": 80}) + + # Verify the annotation of the line that creates u. + index = self._findSourceLine(out, self._u_line_number) + self.assertEqual( + ["L%d u = variables.Variable(u_init, name=u_name)" % + self._u_line_number, + " simple_mul_add/u/read:0", + " simple_mul_add/u:0"], + out.lines[index : index + 3]) + self.assertEqual("pt simple_mul_add/u/read:0", + out.font_attr_segs[index + 1][0][2].content) + self.assertEqual("pt simple_mul_add/u:0", + out.font_attr_segs[index + 2][0][2].content) + + def testPrintSourceForOpNamesStartingAtSpecifiedLineWorks(self): + self._debug_dump.set_python_graph(self._sess.graph) + out = self._registry.dispatch_command( + "print_source", + [self._curr_file_path, "-b", "3"], + screen_info={"cols": 80}) + + self.assertIn("Omitted 2 source lines", out.lines[0]) + self.assertIsNone(self._findSourceLine(out, 1)) + self.assertIsNone(self._findSourceLine(out, 2)) + self.assertIsNotNone(self._findSourceLine(out, 3)) + + index = self._findSourceLine(out, self._u_line_number) + self.assertEqual( + ["L%d u = variables.Variable(u_init, name=u_name)" % + self._u_line_number, + " simple_mul_add/u", + " simple_mul_add/u/Assign", + " simple_mul_add/u/read"], + out.lines[index : index + 4]) + self.assertEqual("pt simple_mul_add/u", + out.font_attr_segs[index + 1][0][2].content) + # simple_mul_add/u/Assign is not used in this run because the Variable has + # already been initialized. + self.assertEqual("blue", out.font_attr_segs[index + 2][0][2]) + self.assertEqual("pt simple_mul_add/u/read", + out.font_attr_segs[index + 3][0][2].content) + + def testPrintSourceForOpNameSettingMaximumElementCountWorks(self): + self._debug_dump.set_python_graph(self._sess.graph) + out = self._registry.dispatch_command( + "print_source", + [self._curr_file_path, "-m", "1"], + screen_info={"cols": 80}) + + index = self._findSourceLine(out, self._u_line_number) + self.assertEqual( + ["L%d u = variables.Variable(u_init, name=u_name)" % + self._u_line_number, + " simple_mul_add/u", + " (... Omitted 2 of 3 op(s) ...)"], + out.lines[index : index + 3]) + self.assertEqual("pt simple_mul_add/u", + out.font_attr_segs[index + 1][0][2].content) + class AnalyzerCLIPrintLargeTensorTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/debug/cli/debugger_cli_common.py b/tensorflow/python/debug/cli/debugger_cli_common.py index b9396f515c..ca896023d4 100644 --- a/tensorflow/python/debug/cli/debugger_cli_common.py +++ b/tensorflow/python/debug/cli/debugger_cli_common.py @@ -104,6 +104,9 @@ class RichLine(object): else: raise TypeError("%r cannot be concatenated with a RichLine" % other) + def __len__(self): + return len(self.text) + def rich_text_lines_from_rich_line_list(rich_text_list): """Convert a list of RichLine objects or strings to a RichTextLines object. diff --git a/tensorflow/python/debug/cli/debugger_cli_common_test.py b/tensorflow/python/debug/cli/debugger_cli_common_test.py index 8c2e5f04dc..1b7a5962fe 100644 --- a/tensorflow/python/debug/cli/debugger_cli_common_test.py +++ b/tensorflow/python/debug/cli/debugger_cli_common_test.py @@ -88,6 +88,12 @@ class RichTextLinesTest(test_util.TensorFlowTestCase): self.assertEqual(1, len(rtl.font_attr_segs[0])) self.assertEqual(1, len(rtl.font_attr_segs[1])) + def testRichLineLenMethodWorks(self): + self.assertEqual(0, len(debugger_cli_common.RichLine())) + self.assertEqual(0, len(debugger_cli_common.RichLine(""))) + self.assertEqual(1, len(debugger_cli_common.RichLine("x"))) + self.assertEqual(6, len(debugger_cli_common.RichLine("x y z ", "blue"))) + def testRichTextLinesConstructorIncomplete(self): # Test RichTextLines constructor, with incomplete keyword arguments. screen_output = debugger_cli_common.RichTextLines( diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index 4971810fb0..baaa15abca 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -634,6 +634,17 @@ class DebugDumpDir(object): self._node_traceback[op.name] = op.traceback @property + def python_graph(self): + """Get the Python graph. + + Returns: + If the Python graph has been set, returns a `tf.Graph` object. Otherwise, + returns None. + """ + + return self._python_graph + + @property def core_metadata(self): """Metadata about the `Session.run()` call from the core runtime. diff --git a/tensorflow/python/debug/lib/source_utils.py b/tensorflow/python/debug/lib/source_utils.py new file mode 100644 index 0000000000..9149c0b60b --- /dev/null +++ b/tensorflow/python/debug/lib/source_utils.py @@ -0,0 +1,93 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes and functions that help to inspect Python source w.r.t. TF graphs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def _convert_watch_key_to_tensor_name(watch_key): + return watch_key[:watch_key.rfind(":")] + + +def annotate_source(dump, + source_file_path, + do_dumped_tensors=False, + file_stack_top=False, + min_line=None, + max_line=None): + """Annotate a Python source file with a list of ops created at each line. + + (The annotation doesn't change the source file itself.) + + Args: + dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph + has been loaded. + source_file_path: (`str`) Path to the source file being annotated. + do_dumped_tensors: (`str`) Whether dumped Tensors, instead of ops are to be + used to annotate the source file. + file_stack_top: (`bool`) Whether only the top stack trace in the + specified source file is to be annotated. + min_line: (`None` or `int`) The 1-based line to start annotate the source + file from (inclusive). + max_line: (`None` or `int`) The 1-based line number to end the annotation + at (exclusive). + + Returns: + A `dict` mapping 1-based line number to a list of op name(s) created at + that line, or tensor names if `do_dumped_tensors` is True. + + Raises: + ValueError: If the dump object does not have a Python graph set. + """ + + py_graph = dump.python_graph + if not py_graph: + raise ValueError("Cannot perform source annotation due to a lack of set " + "Python graph in the dump object") + + line_to_op_names = {} + for op in py_graph.get_operations(): + try: + traceback = dump.node_traceback(op.name) + except KeyError: + pass + + for file_path, line_number, _, _ in reversed(traceback): + if (min_line is not None and line_number < min_line or + max_line is not None and line_number >= max_line): + continue + + if file_path != source_file_path: + continue + + if do_dumped_tensors: + watch_keys = dump.debug_watch_keys(op.name) + # Convert watch keys to unique Tensor names. + items_to_append = list( + set(map(_convert_watch_key_to_tensor_name, watch_keys))) + else: + items_to_append = [op.name] + + if line_number in line_to_op_names: + line_to_op_names[line_number].extend(items_to_append) + else: + line_to_op_names[line_number] = items_to_append + + if file_stack_top: + break + + return line_to_op_names diff --git a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py new file mode 100644 index 0000000000..5d28bff207 --- /dev/null +++ b/tensorflow/python/debug/lib/source_utils_test.py @@ -0,0 +1,203 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for source_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect +import os +import shutil +import tempfile + +import numpy as np + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.debug.lib import debug_data +from tensorflow.python.debug.lib import debug_utils +from tensorflow.python.debug.lib import source_utils +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +def line_number_above(): + return inspect.stack()[1][2] - 1 + + +class SourceHelperTest(test_util.TensorFlowTestCase): + + def createAndRunGraphHelper(self): + """Create and run a TensorFlow Graph to generate debug dumps. + + This is intentionally done in separate method, to make it easier to test + the stack-top mode of source annotation. + """ + + self.dump_root = self.get_temp_dir() + self.curr_file_path = os.path.abspath( + inspect.getfile(inspect.currentframe())) + + # Run a simple TF graph to generate some debug dumps that can be used in + # source annotation. + with session.Session() as sess: + self.u_init = constant_op.constant( + np.array([[5.0, 3.0], [-1.0, 0.0]]), shape=[2, 2], name="u_init") + self.u_init_line_number = line_number_above() + + self.u = variables.Variable(self.u_init, name="u") + self.u_line_number = line_number_above() + + self.v_init = constant_op.constant( + np.array([[2.0], [-1.0]]), shape=[2, 1], name="v_init") + self.v_init_line_number = line_number_above() + + self.v = variables.Variable(self.v_init, name="v") + self.v_line_number = line_number_above() + + self.w = math_ops.matmul(self.u, self.v, name="w") + self.w_line_number = line_number_above() + + sess.run(self.u.initializer) + sess.run(self.v.initializer) + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + debug_utils.watch_graph( + run_options, sess.graph, debug_urls=["file://%s" % self.dump_root]) + run_metadata = config_pb2.RunMetadata() + sess.run(self.w, options=run_options, run_metadata=run_metadata) + + self.dump = debug_data.DebugDumpDir( + self.dump_root, partition_graphs=run_metadata.partition_graphs) + self.dump.set_python_graph(sess.graph) + + def setUp(self): + self.createAndRunGraphHelper() + self.helper_line_number = line_number_above() + + def tearDown(self): + if os.path.isdir(self.dump_root): + shutil.rmtree(self.dump_root) + ops.reset_default_graph() + + def testAnnotateWholeValidSourceFileGivesCorrectResult(self): + source_annotation = source_utils.annotate_source(self.dump, + self.curr_file_path) + + self.assertIn(self.u_init.op.name, + source_annotation[self.u_init_line_number]) + self.assertIn(self.u.op.name, + source_annotation[self.u_line_number]) + self.assertIn(self.v_init.op.name, + source_annotation[self.v_init_line_number]) + self.assertIn(self.v.op.name, + source_annotation[self.v_line_number]) + self.assertIn(self.w.op.name, + source_annotation[self.w_line_number]) + + # In the non-stack-top (default) mode, the helper line should be annotated + # with all the ops as well. + self.assertIn(self.u_init.op.name, + source_annotation[self.helper_line_number]) + self.assertIn(self.u.op.name, + source_annotation[self.helper_line_number]) + self.assertIn(self.v_init.op.name, + source_annotation[self.helper_line_number]) + self.assertIn(self.v.op.name, + source_annotation[self.helper_line_number]) + self.assertIn(self.w.op.name, + source_annotation[self.helper_line_number]) + + def testAnnotateWithStackTopGivesCorrectResult(self): + source_annotation = source_utils.annotate_source( + self.dump, self.curr_file_path, file_stack_top=True) + + self.assertIn(self.u_init.op.name, + source_annotation[self.u_init_line_number]) + self.assertIn(self.u.op.name, + source_annotation[self.u_line_number]) + self.assertIn(self.v_init.op.name, + source_annotation[self.v_init_line_number]) + self.assertIn(self.v.op.name, + source_annotation[self.v_line_number]) + self.assertIn(self.w.op.name, + source_annotation[self.w_line_number]) + + # In the stack-top mode, the helper line should not have been annotated. + self.assertNotIn(self.helper_line_number, source_annotation) + + def testAnnotateSubsetOfLinesGivesCorrectResult(self): + source_annotation = source_utils.annotate_source( + self.dump, + self.curr_file_path, + min_line=self.u_line_number, + max_line=self.u_line_number + 1) + + self.assertIn(self.u.op.name, + source_annotation[self.u_line_number]) + self.assertNotIn(self.v_line_number, source_annotation) + + def testAnnotateDumpedTensorsGivesCorrectResult(self): + source_annotation = source_utils.annotate_source( + self.dump, self.curr_file_path, do_dumped_tensors=True) + + # Note: Constant Tensors u_init and v_init may not get dumped due to + # constant-folding. + self.assertIn(self.u.name, + source_annotation[self.u_line_number]) + self.assertIn(self.v.name, + source_annotation[self.v_line_number]) + self.assertIn(self.w.name, + source_annotation[self.w_line_number]) + + self.assertNotIn(self.u.op.name, + source_annotation[self.u_line_number]) + self.assertNotIn(self.v.op.name, + source_annotation[self.v_line_number]) + self.assertNotIn(self.w.op.name, + source_annotation[self.w_line_number]) + + self.assertIn(self.u.name, + source_annotation[self.helper_line_number]) + self.assertIn(self.v.name, + source_annotation[self.helper_line_number]) + self.assertIn(self.w.name, + source_annotation[self.helper_line_number]) + + def testCallingAnnotateSourceWithoutPythonGraphRaisesException(self): + self.dump.set_python_graph(None) + with self.assertRaises(ValueError): + source_utils.annotate_source(self.dump, self.curr_file_path) + + def testCallingAnnotateSourceOnUnrelatedSourceFileDoesNotError(self): + # Create an unrelated source file. + unrelated_source_path = tempfile.mktemp() + with open(unrelated_source_path, "wt") as source_file: + source_file.write("print('hello, world')\n") + + self.assertEqual( + {}, source_utils.annotate_source(self.dump, unrelated_source_path)) + + # Clean up unrelated source file. + os.remove(unrelated_source_path) + + +if __name__ == "__main__": + googletest.main() |