diff options
author | 2017-04-25 13:22:44 -0800 | |
---|---|---|
committer | 2017-04-25 14:48:08 -0700 | |
commit | 9509cda97d3f4af8ffc1b3aff5fbddc158ff7440 (patch) | |
tree | 9388740662b31183e3ce7285261aff1ed92a2b78 | |
parent | a5cc879abb1209f40b9079c49810a95eff326593 (diff) |
Add ProfileAnalyzer class to tfdbg.
Change: 154220704
-rw-r--r-- | tensorflow/python/debug/BUILD | 35 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/cli_shared.py | 10 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/cli_shared_test.py | 15 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/command_parser.py | 25 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/command_parser_test.py | 19 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/profile_analyzer_cli.py | 459 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/profile_analyzer_cli_test.py | 264 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/source_utils.py | 4 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/source_utils_test.py | 8 |
9 files changed, 830 insertions, 9 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index f7e17f1c53..e738c86a1f 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -148,6 +148,22 @@ py_library( ) py_library( + name = "profile_analyzer_cli", + srcs = ["cli/profile_analyzer_cli.py"], + srcs_version = "PY2AND3", + deps = [ + ":cli_shared", + ":command_parser", + ":debug_data", + ":debugger_cli_common", + ":source_utils", + ":ui_factory", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( name = "stepper_cli", srcs = ["cli/stepper_cli.py"], srcs_version = "PY2AND3", @@ -241,6 +257,7 @@ py_library( ":debug_data", ":debugger_cli_common", ":framework", + ":profile_analyzer_cli", ":stepper_cli", ":ui_factory", ], @@ -606,6 +623,24 @@ cuda_py_test( ], ) +py_test( + name = "profile_analyzer_cli_test", + size = "small", + srcs = [ + "cli/profile_analyzer_cli_test.py", + ], + deps = [ + ":command_parser", + ":profile_analyzer_cli", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "stepper_cli_test", size = "small", diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py index 8ff0916761..3deb6dbad6 100644 --- a/tensorflow/python/debug/cli/cli_shared.py +++ b/tensorflow/python/debug/cli/cli_shared.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math + import numpy as np import six @@ -73,6 +75,14 @@ def bytes_to_readable_str(num_bytes, include_b=False): return result +def time_to_readable_str(value): + if not value: + return "0" + suffixes = ["us", "ms", "s"] + order = min(len(suffixes) - 1, int(math.log(value, 10) / 3)) + return "{:.3g}{}".format(value / 
math.pow(10.0, 3*order), suffixes[order]) + + def parse_ranges_highlight(ranges_string): """Process ranges highlight string. diff --git a/tensorflow/python/debug/cli/cli_shared_test.py b/tensorflow/python/debug/cli/cli_shared_test.py index 1ef3c34254..fde1d66998 100644 --- a/tensorflow/python/debug/cli/cli_shared_test.py +++ b/tensorflow/python/debug/cli/cli_shared_test.py @@ -70,6 +70,21 @@ class BytesToReadableStrTest(test_util.TensorFlowTestCase): 1024**3, include_b=True)) +class TimeToReadableStrTest(test_util.TensorFlowTestCase): + + def testNoneTimeWorks(self): + self.assertEqual("0", cli_shared.time_to_readable_str(None)) + + def testMicrosecondsTime(self): + self.assertEqual("40us", cli_shared.time_to_readable_str(40)) + + def testMillisecondTime(self): + self.assertEqual("40ms", cli_shared.time_to_readable_str(40e3)) + + def testSecondTime(self): + self.assertEqual("40s", cli_shared.time_to_readable_str(40e6)) + + class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase): def setUp(self): diff --git a/tensorflow/python/debug/cli/command_parser.py b/tensorflow/python/debug/cli/command_parser.py index a71982f86a..143c104519 100644 --- a/tensorflow/python/debug/cli/command_parser.py +++ b/tensorflow/python/debug/cli/command_parser.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function import ast -from collections import namedtuple import re import sys @@ -29,8 +28,28 @@ _WHITESPACE_PATTERN = re.compile(r"\s+") _NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?") -Interval = namedtuple("Interval", - ["start", "start_included", "end", "end_included"]) + +class Interval(object): + """Represents an interval between a start and end value.""" + + def __init__(self, start, start_included, end, end_included): + self.start = start + self.start_included = start_included + self.end = end + self.end_included = end_included + + def contains(self, value): + if value < self.start or value == self.start and 
not self.start_included: + return False + if value > self.end or value == self.end and not self.end_included: + return False + return True + + def __eq__(self, other): + return (self.start == other.start and + self.start_included == other.start_included and + self.end == other.end and + self.end_included == other.end_included) def parse_command(command): diff --git a/tensorflow/python/debug/cli/command_parser_test.py b/tensorflow/python/debug/cli/command_parser_test.py index ab9b3245dc..1ea890be8c 100644 --- a/tensorflow/python/debug/cli/command_parser_test.py +++ b/tensorflow/python/debug/cli/command_parser_test.py @@ -490,6 +490,25 @@ class ParseInterval(test_util.TensorFlowTestCase): "equal to end of interval."): command_parser.parse_memory_interval("[5k, 3k]") + def testIntervalContains(self): + interval = command_parser.Interval( + start=1, start_included=True, end=10, end_included=True) + self.assertTrue(interval.contains(1)) + self.assertTrue(interval.contains(10)) + self.assertTrue(interval.contains(5)) + + interval.start_included = False + self.assertFalse(interval.contains(1)) + self.assertTrue(interval.contains(10)) + + interval.end_included = False + self.assertFalse(interval.contains(1)) + self.assertFalse(interval.contains(10)) + + interval.start_included = True + self.assertTrue(interval.contains(1)) + self.assertFalse(interval.contains(10)) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/cli/profile_analyzer_cli.py b/tensorflow/python/debug/cli/profile_analyzer_cli.py new file mode 100644 index 0000000000..42440521eb --- /dev/null +++ b/tensorflow/python/debug/cli/profile_analyzer_cli.py @@ -0,0 +1,459 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Formats and displays profiling information.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import re + +from tensorflow.python.debug.cli import cli_shared +from tensorflow.python.debug.cli import command_parser +from tensorflow.python.debug.cli import debugger_cli_common +from tensorflow.python.debug.cli import ui_factory +from tensorflow.python.debug.lib import source_utils + + +SORT_OPS_BY_OP_NAME = "node" +SORT_OPS_BY_OP_TIME = "op_time" +SORT_OPS_BY_EXEC_TIME = "exec_time" +SORT_OPS_BY_START_TIME = "start_time" +SORT_OPS_BY_LINE = "line" + + +class ProfileDatum(object): + """Profile data point.""" + + def __init__(self, node_exec_stats, file_line, op_type): + """Constructor. + + Args: + node_exec_stats: `NodeExecStats` proto. + file_line: A `string` formatted as <file_name>:<line_number>. + op_type: (string) Operation type. + """ + self.node_exec_stats = node_exec_stats + self.file_line = file_line + self.op_type = op_type + self.op_time = (self.node_exec_stats.op_end_rel_micros - + self.node_exec_stats.op_start_rel_micros) + + @property + def exec_time(self): + """Measures compute function execution time plus pre- and post-processing.""" + return self.node_exec_stats.all_end_rel_micros + + +class ProfileDataTableView(object): + """Table View of profiling data.""" + + def __init__(self, profile_datum_list): + """Constructor. 
+ + Args: + profile_datum_list: List of `ProfileDatum` objects. + """ + self._profile_datum_list = profile_datum_list + self.formatted_op_time = [ + cli_shared.time_to_readable_str(datum.op_time) + for datum in profile_datum_list] + self.formatted_exec_time = [ + cli_shared.time_to_readable_str( + datum.node_exec_stats.all_end_rel_micros) + for datum in profile_datum_list] + self._column_sort_ids = [SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TIME, + SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE] + + def value(self, row, col): + if col == 0: + return self._profile_datum_list[row].node_exec_stats.node_name + elif col == 1: + return self.formatted_op_time[row] + elif col == 2: + return self.formatted_exec_time[row] + elif col == 3: + return self._profile_datum_list[row].file_line + else: + raise IndexError("Invalid column index %d." % col) + + def row_count(self): + return len(self._profile_datum_list) + + def column_count(self): + return 4 + + def column_names(self): + return ["Node", "Op Time", "Exec Time", "Filename:Lineno(function)"] + + def column_sort_id(self, col): + return self._column_sort_ids[col] + + +def _list_profile_filter( + profile_datum, node_name_regex, file_name_regex, op_type_regex, + op_time_interval, exec_time_interval): + """Filter function for list_profile command. + + Args: + profile_datum: A `ProfileDatum` object. + node_name_regex: Regular expression pattern object to filter by name. + file_name_regex: Regular expression pattern object to filter by file. + op_type_regex: Regular expression pattern object to filter by op type. + op_time_interval: `Interval` for filtering op time. + exec_time_interval: `Interval` for filtering exec time. + + Returns: + True if profile_datum should be included. 
+ """ + if not node_name_regex.match( + profile_datum.node_exec_stats.node_name): + return False + if profile_datum.file_line is not None and not file_name_regex.match( + profile_datum.file_line): + return False + if profile_datum.op_type is not None and not op_type_regex.match( + profile_datum.op_type): + return False + if op_time_interval is not None and not op_time_interval.contains( + profile_datum.op_time): + return False + if exec_time_interval and not exec_time_interval.contains( + profile_datum.node_exec_stats.all_end_rel_micros): + return False + return True + + +def _list_profile_sort_key(profile_datum, sort_by): + """Get a profile_datum property to sort by in list_profile command. + + Args: + profile_datum: A `ProfileDatum` object. + sort_by: (string) indicates a value to sort by. + Must be one of SORT_BY* constants. + + Returns: + profile_datum property to sort by. + """ + if sort_by == SORT_OPS_BY_OP_NAME: + return profile_datum.node_exec_stats.node_name + elif sort_by == SORT_OPS_BY_LINE: + return profile_datum.file_line + elif sort_by == SORT_OPS_BY_OP_TIME: + return profile_datum.op_time + elif sort_by == SORT_OPS_BY_EXEC_TIME: + return profile_datum.node_exec_stats.all_end_rel_micros + else: # sort by start time + return profile_datum.node_exec_stats.all_start_micros + + +class ProfileAnalyzer(object): + """Analyzer for profiling data.""" + + def __init__(self, graph, run_metadata): + """ProfileAnalyzer constructor. + + Args: + graph: (tf.Graph) Python graph object. + run_metadata: A `RunMetadata` protobuf object. + + Raises: + ValueError: If run_metadata is None. 
+ """ + self._graph = graph + if not run_metadata: + raise ValueError("No RunMetadata passed for profile analysis.") + self._run_metadata = run_metadata + self._arg_parsers = {} + ap = argparse.ArgumentParser( + description="List nodes profile information.", + usage=argparse.SUPPRESS) + ap.add_argument( + "-d", + "--device_name_filter", + dest="device_name_filter", + type=str, + default="", + help="filter device name by regex.") + ap.add_argument( + "-n", + "--node_name_filter", + dest="node_name_filter", + type=str, + default="", + help="filter node name by regex.") + ap.add_argument( + "-t", + "--op_type_filter", + dest="op_type_filter", + type=str, + default="", + help="filter op type by regex.") + # TODO(annarev): allow file filtering at non-stack top position. + ap.add_argument( + "-f", + "--file_name_filter", + dest="file_name_filter", + type=str, + default="", + help="filter by file name at the top position of node's creation " + "stack that does not belong to TensorFlow library.") + ap.add_argument( + "-e", + "--execution_time", + dest="execution_time", + type=str, + default="", + help="Filter by execution time interval " + "(includes compute plus pre- and post -processing time). " + "Supported units are s, ms and us (default). " + "E.g. -e >100s, -e <100, -e [100us,1000ms]") + ap.add_argument( + "-o", + "--op_time", + dest="op_time", + type=str, + default="", + help="Filter by op time interval (only includes compute time). " + "Supported units are s, ms and us (default). " + "E.g. 
-e >100s, -e <100, -e [100us,1000ms]") + ap.add_argument( + "-s", + "--sort_by", + dest="sort_by", + type=str, + default=SORT_OPS_BY_START_TIME, + help=("the field to sort the data by: (%s | %s | %s | %s | %s)" % + (SORT_OPS_BY_OP_NAME, SORT_OPS_BY_START_TIME, + SORT_OPS_BY_OP_TIME, SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE))) + ap.add_argument( + "-r", + "--reverse", + dest="reverse", + action="store_true", + help="sort the data in reverse (descending) order") + + self._arg_parsers["list_profile"] = ap + + def list_profile(self, args, screen_info=None): + """Command handler for list_profile. + + List per-operation profile information. + + Args: + args: Command-line arguments, excluding the command prefix, as a list of + str. + screen_info: Optional dict input containing screen information such as + cols. + + Returns: + Output text lines as a RichTextLines object. + """ + del screen_info + + parsed = self._arg_parsers["list_profile"].parse_args(args) + op_time_interval = (command_parser.parse_time_interval(parsed.op_time) + if parsed.op_time else None) + exec_time_interval = ( + command_parser.parse_time_interval(parsed.execution_time) + if parsed.execution_time else None) + node_name_regex = re.compile(parsed.node_name_filter) + file_name_regex = re.compile(parsed.file_name_filter) + op_type_regex = re.compile(parsed.op_type_filter) + + output = debugger_cli_common.RichTextLines([""]) + device_name_regex = re.compile(parsed.device_name_filter) + data_generator = self._get_profile_data_generator() + device_count = len(self._run_metadata.step_stats.dev_stats) + for index in range(device_count): + device_stats = self._run_metadata.step_stats.dev_stats[index] + if device_name_regex.match(device_stats.device): + profile_data = [ + datum for datum in data_generator(device_stats) + if _list_profile_filter( + datum, node_name_regex, file_name_regex, op_type_regex, + op_time_interval, exec_time_interval)] + profile_data = sorted( + profile_data, + key=lambda datum: 
_list_profile_sort_key(datum, parsed.sort_by), + reverse=parsed.reverse) + output.extend( + self._get_list_profile_lines( + device_stats.device, index, device_count, + profile_data, parsed.sort_by, parsed.reverse)) + return output + + def _get_profile_data_generator(self): + """Get function that generates `ProfileDatum` objects. + + Returns: + A function that generates `ProfileDatum` objects. + """ + node_to_file_line = {} + node_to_op_type = {} + for op in self._graph.get_operations(): + file_line = "" + for trace_entry in reversed(op.traceback): + filepath = trace_entry[0] + file_line = "%s:%d(%s)" % ( + os.path.basename(filepath), trace_entry[1], trace_entry[2]) + if not source_utils.guess_is_tensorflow_py_library(filepath): + break + node_to_file_line[op.name] = file_line + node_to_op_type[op.name] = op.type + + def profile_data_generator(device_step_stats): + for node_stats in device_step_stats.node_stats: + if node_stats.node_name == "_SOURCE" or node_stats.node_name == "_SINK": + continue + yield ProfileDatum( + node_stats, + node_to_file_line.get(node_stats.node_name, ""), + node_to_op_type.get(node_stats.node_name, "")) + return profile_data_generator + + def _get_list_profile_lines( + self, device_name, device_index, device_count, + profile_datum_list, sort_by, sort_reverse): + """Get `RichTextLines` object for list_profile command for a given device. + + Args: + device_name: (string) Device name. + device_index: (int) Device index. + device_count: (int) Number of devices. + profile_datum_list: List of `ProfileDatum` objects. + sort_by: (string) Identifier of column to sort. Sort identifier + must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_EXEC_TIME, + SORT_OPS_BY_OP_TIME or SORT_OPS_BY_LINE. + sort_reverse: (bool) Whether to sort in descending instead of default + (ascending) order. + + Returns: + `RichTextLines` object containing a table that displays profiling + information for each op. 
+ """ + profile_data = ProfileDataTableView(profile_datum_list) + + # Calculate total time early to calculate column widths. + total_op_time = sum(datum.op_time for datum in profile_datum_list) + total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros + for datum in profile_datum_list) + device_total_row = [ + "Device Total", cli_shared.time_to_readable_str(total_op_time), + cli_shared.time_to_readable_str(total_exec_time)] + + # Calculate column widths. + column_widths = [ + len(column_name) for column_name in profile_data.column_names()] + for col in range(len(device_total_row)): + column_widths[col] = max(column_widths[col], len(device_total_row[col])) + for col in range(len(column_widths)): + for row in range(profile_data.row_count()): + column_widths[col] = max( + column_widths[col], len(str(profile_data.value(row, col)))) + column_widths[col] += 2 # add margin between columns + + # Add device name. + output = debugger_cli_common.RichTextLines(["-"*80]) + device_row = "Device %d of %d: %s" % ( + device_index + 1, device_count, device_name) + output.extend(debugger_cli_common.RichTextLines([device_row, ""])) + + # Add headers. + base_command = "list_profile" + attr_segs = {0: []} + row = "" + for col in range(profile_data.column_count()): + column_name = profile_data.column_names()[col] + sort_id = profile_data.column_sort_id(col) + command = "%s -s %s" % (base_command, sort_id) + if sort_by == sort_id and not sort_reverse: + command += " -r" + curr_row = ("{:<%d}" % column_widths[col]).format(column_name) + prev_len = len(row) + row += curr_row + attr_segs[0].append( + (prev_len, prev_len + len(column_name), + [debugger_cli_common.MenuItem(None, command), "bold"])) + + output.extend( + debugger_cli_common.RichTextLines([row], font_attr_segs=attr_segs)) + + # Add data rows. 
+ for row in range(profile_data.row_count()): + row_str = "" + for col in range(profile_data.column_count()): + row_str += ("{:<%d}" % column_widths[col]).format( + profile_data.value(row, col)) + output.extend(debugger_cli_common.RichTextLines([row_str])) + + # Add stat totals. + row_str = "" + for col in range(len(device_total_row)): + row_str += ("{:<%d}" % column_widths[col]).format(device_total_row[col]) + output.extend(debugger_cli_common.RichTextLines("")) + output.extend(debugger_cli_common.RichTextLines(row_str)) + return output + + def _measure_list_profile_column_widths(self, profile_data): + """Determine the maximum column widths for each data list. + + Args: + profile_data: list of ProfileDatum objects. + + Returns: + List of column widths in the same order as columns in data. + """ + num_columns = len(profile_data.column_names()) + widths = [len(column_name) for column_name in profile_data.column_names()] + for row in range(profile_data.row_count()): + for col in range(num_columns): + widths[col] = max( + widths[col], len(str(profile_data.row_values(row)[col])) + 2) + return widths + + def get_help(self, handler_name): + return self._arg_parsers[handler_name].format_help() + + +def create_profiler_ui(graph, + run_metadata, + ui_type="curses", + on_ui_exit=None): + """Create an instance of CursesUI based on a `tf.Graph` and `RunMetadata`. + + Args: + graph: Python `Graph` object. + run_metadata: A `RunMetadata` protobuf object. + ui_type: (str) requested UI type, e.g., "curses", "readline". + on_ui_exit: (`Callable`) the callback to be called when the UI exits. + + Returns: + (base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer + commands and tab-completions registered. 
+ """ + + analyzer = ProfileAnalyzer(graph, run_metadata) + + cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit) + cli.register_command_handler( + "list_profile", + analyzer.list_profile, + analyzer.get_help("list_profile"), + prefix_aliases=["lp"]) + + return cli diff --git a/tensorflow/python/debug/cli/profile_analyzer_cli_test.py b/tensorflow/python/debug/cli/profile_analyzer_cli_test.py new file mode 100644 index 0000000000..7b34d87c99 --- /dev/null +++ b/tensorflow/python/debug/cli/profile_analyzer_cli_test.py @@ -0,0 +1,264 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for profile_analyzer_cli.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tensorflow.core.framework import step_stats_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.debug.cli import profile_analyzer_cli +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + + +class ProfileAnalyzerTest(test_util.TensorFlowTestCase): + + def testNodeInfoEmpty(self): + graph = ops.Graph() + run_metadata = config_pb2.RunMetadata() + + prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata) + prof_output = prof_analyzer.list_profile([]).lines + self.assertEquals([""], prof_output) + + def testSingleDevice(self): + node1 = step_stats_pb2.NodeExecStats( + node_name="Add/123", + op_start_rel_micros=3, + op_end_rel_micros=5, + all_end_rel_micros=4) + + node2 = step_stats_pb2.NodeExecStats( + node_name="Mul/456", + op_start_rel_micros=1, + op_end_rel_micros=2, + all_end_rel_micros=3) + + run_metadata = config_pb2.RunMetadata() + device1 = run_metadata.step_stats.dev_stats.add() + device1.device = "deviceA" + device1.node_stats.extend([node1, node2]) + + graph = test.mock.MagicMock() + op1 = test.mock.MagicMock() + op1.name = "Add/123" + op1.traceback = [("a/b/file1", 10, "some_var")] + op1.type = "add" + op2 = test.mock.MagicMock() + op2.name = "Mul/456" + op2.traceback = [("a/b/file1", 11, "some_var")] + op2.type = "mul" + graph.get_operations.return_value = [op1, op2] + + prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata) + prof_output = 
prof_analyzer.list_profile([]).lines + + self._assertAtLeastOneLineMatches(r"Device 1 of 1: deviceA", prof_output) + self._assertAtLeastOneLineMatches(r"^Add/123.*2us.*4us", prof_output) + self._assertAtLeastOneLineMatches(r"^Mul/456.*1us.*3us", prof_output) + + def testMultipleDevices(self): + node1 = step_stats_pb2.NodeExecStats( + node_name="Add/123", + op_start_rel_micros=3, + op_end_rel_micros=5, + all_end_rel_micros=3) + + run_metadata = config_pb2.RunMetadata() + device1 = run_metadata.step_stats.dev_stats.add() + device1.device = "deviceA" + device1.node_stats.extend([node1]) + + device2 = run_metadata.step_stats.dev_stats.add() + device2.device = "deviceB" + device2.node_stats.extend([node1]) + + graph = test.mock.MagicMock() + op = test.mock.MagicMock() + op.name = "Add/123" + op.traceback = [("a/b/file1", 10, "some_var")] + op.type = "abc" + graph.get_operations.return_value = [op] + + prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata) + prof_output = prof_analyzer.list_profile([]).lines + + self._assertAtLeastOneLineMatches(r"Device 1 of 2: deviceA", prof_output) + self._assertAtLeastOneLineMatches(r"Device 2 of 2: deviceB", prof_output) + + # Try filtering by device. 
+ prof_output = prof_analyzer.list_profile(["-d", "deviceB"]).lines + self._assertAtLeastOneLineMatches(r"Device 2 of 2: deviceB", prof_output) + self._assertNoLinesMatch(r"Device 1 of 2: deviceA", prof_output) + + def testWithSession(self): + options = config_pb2.RunOptions() + options.trace_level = config_pb2.RunOptions.FULL_TRACE + run_metadata = config_pb2.RunMetadata() + + with session.Session() as sess: + a = constant_op.constant([1, 2, 3]) + b = constant_op.constant([2, 2, 1]) + result = math_ops.add(a, b) + + sess.run(result, options=options, run_metadata=run_metadata) + + prof_analyzer = profile_analyzer_cli.ProfileAnalyzer( + sess.graph, run_metadata) + prof_output = prof_analyzer.list_profile([]).lines + + self._assertAtLeastOneLineMatches("Device 1 of 1:", prof_output) + expected_headers = [ + "Node", "Op Time", "Exec Time", r"Filename:Lineno\(function\)"] + self._assertAtLeastOneLineMatches( + ".*".join(expected_headers), prof_output) + self._assertAtLeastOneLineMatches(r"^Add/", prof_output) + self._assertAtLeastOneLineMatches(r"Device Total", prof_output) + + def testSorting(self): + node1 = step_stats_pb2.NodeExecStats( + node_name="Add/123", + all_start_micros=123, + op_start_rel_micros=3, + op_end_rel_micros=5, + all_end_rel_micros=4) + + node2 = step_stats_pb2.NodeExecStats( + node_name="Mul/456", + all_start_micros=122, + op_start_rel_micros=1, + op_end_rel_micros=2, + all_end_rel_micros=5) + + run_metadata = config_pb2.RunMetadata() + device1 = run_metadata.step_stats.dev_stats.add() + device1.device = "deviceA" + device1.node_stats.extend([node1, node2]) + + graph = test.mock.MagicMock() + op1 = test.mock.MagicMock() + op1.name = "Add/123" + op1.traceback = [("a/b/file2", 10, "some_var")] + op1.type = "add" + op2 = test.mock.MagicMock() + op2.name = "Mul/456" + op2.traceback = [("a/b/file1", 11, "some_var")] + op2.type = "mul" + graph.get_operations.return_value = [op1, op2] + + prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, 
run_metadata) + + # Default sort by start time (i.e. all_start_micros). + prof_output = prof_analyzer.list_profile([]).lines + self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123") + # Default sort in reverse. + prof_output = prof_analyzer.list_profile(["-r"]).lines + self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456") + # Sort by name. + prof_output = prof_analyzer.list_profile(["-s", "node"]).lines + self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456") + # Sort by op time (i.e. op_end_rel_micros - op_start_rel_micros). + prof_output = prof_analyzer.list_profile(["-s", "op_time"]).lines + self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123") + # Sort by exec time (i.e. all_end_rel_micros). + prof_output = prof_analyzer.list_profile(["-s", "exec_time"]).lines + self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456") + # Sort by line number. + prof_output = prof_analyzer.list_profile(["-s", "line"]).lines + self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123") + + def testFiltering(self): + node1 = step_stats_pb2.NodeExecStats( + node_name="Add/123", + all_start_micros=123, + op_start_rel_micros=3, + op_end_rel_micros=5, + all_end_rel_micros=4) + + node2 = step_stats_pb2.NodeExecStats( + node_name="Mul/456", + all_start_micros=122, + op_start_rel_micros=1, + op_end_rel_micros=2, + all_end_rel_micros=5) + + run_metadata = config_pb2.RunMetadata() + device1 = run_metadata.step_stats.dev_stats.add() + device1.device = "deviceA" + device1.node_stats.extend([node1, node2]) + + graph = test.mock.MagicMock() + op1 = test.mock.MagicMock() + op1.name = "Add/123" + op1.traceback = [("a/b/file2", 10, "some_var")] + op1.type = "add" + op2 = test.mock.MagicMock() + op2.name = "Mul/456" + op2.traceback = [("a/b/file1", 11, "some_var")] + op2.type = "mul" + graph.get_operations.return_value = [op1, op2] + + prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata) + + # Filter by 
 name + prof_output = prof_analyzer.list_profile(["-n", "Add"]).lines + self._assertAtLeastOneLineMatches(r"Add/123", prof_output) + self._assertNoLinesMatch(r"Mul/456", prof_output) + # Filter by op_type + prof_output = prof_analyzer.list_profile(["-t", "mul"]).lines + self._assertAtLeastOneLineMatches(r"Mul/456", prof_output) + self._assertNoLinesMatch(r"Add/123", prof_output) + # Filter by file name. + prof_output = prof_analyzer.list_profile(["-f", "file2"]).lines + self._assertAtLeastOneLineMatches(r"Add/123", prof_output) + self._assertNoLinesMatch(r"Mul/456", prof_output) + # Filter by execution time. + prof_output = prof_analyzer.list_profile(["-e", "[5, 10]"]).lines + self._assertAtLeastOneLineMatches(r"Mul/456", prof_output) + self._assertNoLinesMatch(r"Add/123", prof_output) + # Filter by op time. + prof_output = prof_analyzer.list_profile(["-o", ">=2"]).lines + self._assertAtLeastOneLineMatches(r"Add/123", prof_output) + self._assertNoLinesMatch(r"Mul/456", prof_output) + + def _atLeastOneLineMatches(self, pattern, lines): + pattern_re = re.compile(pattern) + for line in lines: + if pattern_re.match(line): + return True + return False + + def _assertAtLeastOneLineMatches(self, pattern, lines): + if not self._atLeastOneLineMatches(pattern, lines): + raise AssertionError( + "%s does not match any line in %s." % (pattern, str(lines))) + + def _assertNoLinesMatch(self, pattern, lines): + if self._atLeastOneLineMatches(pattern, lines): + raise AssertionError( + "%s matched at least one line in %s." 
% (pattern, str(lines))) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/debug/lib/source_utils.py b/tensorflow/python/debug/lib/source_utils.py index 580bdc054b..f610d05b83 100644 --- a/tensorflow/python/debug/lib/source_utils.py +++ b/tensorflow/python/debug/lib/source_utils.py @@ -44,7 +44,7 @@ def _convert_watch_key_to_tensor_name(watch_key): return watch_key[:watch_key.rfind(":")] -def _guess_is_tensorflow_py_library(py_file_path): +def guess_is_tensorflow_py_library(py_file_path): """Guess whether a Python source file is a part of the tensorflow library. Special cases: @@ -231,7 +231,7 @@ def list_source_files_against_dump(dump, for file_path in path_to_node_names: output.append(( file_path, - _guess_is_tensorflow_py_library(file_path), + guess_is_tensorflow_py_library(file_path), len(path_to_node_names.get(file_path, {})), len(path_to_tensor_names.get(file_path, {})), path_to_num_dumps.get(file_path, 0), diff --git a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py index a4fb0d9910..f6195b6a5d 100644 --- a/tensorflow/python/debug/lib/source_utils_test.py +++ b/tensorflow/python/debug/lib/source_utils_test.py @@ -57,20 +57,20 @@ class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase): def testUnitTestFileReturnsFalse(self): self.assertFalse( - source_utils._guess_is_tensorflow_py_library(self.curr_file_path)) + source_utils.guess_is_tensorflow_py_library(self.curr_file_path)) def testSourceUtilModuleReturnsTrue(self): self.assertTrue( - source_utils._guess_is_tensorflow_py_library(source_utils.__file__)) + source_utils.guess_is_tensorflow_py_library(source_utils.__file__)) def testFileInPythonKernelsPathReturnsTrue(self): x = constant_op.constant(42.0, name="x") self.assertTrue( - source_utils._guess_is_tensorflow_py_library(x.op.traceback[-1][0])) + source_utils.guess_is_tensorflow_py_library(x.op.traceback[-1][0])) def testNonPythonFileRaisesException(self): with 
self.assertRaisesRegexp(ValueError, r"is not a Python source file"): - source_utils._guess_is_tensorflow_py_library( + source_utils.guess_is_tensorflow_py_library( os.path.join(os.path.dirname(self.curr_file_path), "foo.cc")) |