author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-04-25 17:01:59 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-04-25 18:28:30 -0700
commit     7c4549bc393c0b8be2390e49528a8e207617d66e (patch)
tree       a8e9b27bbd85cb7aa60257ed0d3522de5427522f /tensorflow/python/debug/cli
parent     84fe4b527ea51fb674f99c3c418bfe9d74dade98 (diff)
Automated rollback of change 154225030
Change: 154247352
Diffstat (limited to 'tensorflow/python/debug/cli')
-rw-r--r--  tensorflow/python/debug/cli/cli_shared.py                  10
-rw-r--r--  tensorflow/python/debug/cli/cli_shared_test.py             15
-rw-r--r--  tensorflow/python/debug/cli/command_parser.py              25
-rw-r--r--  tensorflow/python/debug/cli/command_parser_test.py         19
-rw-r--r--  tensorflow/python/debug/cli/profile_analyzer_cli.py       459
-rw-r--r--  tensorflow/python/debug/cli/profile_analyzer_cli_test.py  264
6 files changed, 789 insertions, 3 deletions
diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py
index 8ff0916761..3deb6dbad6 100644
--- a/tensorflow/python/debug/cli/cli_shared.py
+++ b/tensorflow/python/debug/cli/cli_shared.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import math
+
import numpy as np
import six
@@ -73,6 +75,14 @@ def bytes_to_readable_str(num_bytes, include_b=False):
return result
+def time_to_readable_str(value):
+ if not value:
+ return "0"
+ suffixes = ["us", "ms", "s"]
+ order = min(len(suffixes) - 1, int(math.log(value, 10) / 3))
+ return "{:.3g}{}".format(value / math.pow(10.0, 3*order), suffixes[order])
+
+
def parse_ranges_highlight(ranges_string):
"""Process ranges highlight string.
diff --git a/tensorflow/python/debug/cli/cli_shared_test.py b/tensorflow/python/debug/cli/cli_shared_test.py
index 1ef3c34254..fde1d66998 100644
--- a/tensorflow/python/debug/cli/cli_shared_test.py
+++ b/tensorflow/python/debug/cli/cli_shared_test.py
@@ -70,6 +70,21 @@ class BytesToReadableStrTest(test_util.TensorFlowTestCase):
1024**3, include_b=True))
+class TimeToReadableStrTest(test_util.TensorFlowTestCase):
+
+ def testNoneTimeWorks(self):
+ self.assertEqual("0", cli_shared.time_to_readable_str(None))
+
+ def testMicrosecondsTime(self):
+ self.assertEqual("40us", cli_shared.time_to_readable_str(40))
+
+ def testMillisecondTime(self):
+ self.assertEqual("40ms", cli_shared.time_to_readable_str(40e3))
+
+ def testSecondTime(self):
+ self.assertEqual("40s", cli_shared.time_to_readable_str(40e6))
+
+
class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
def setUp(self):
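
The new TimeToReadableStrTest cases above pin down the behavior of cli_shared.time_to_readable_str: the input is a count of microseconds, and the helper scales it into the largest fitting unit among us, ms and s, formatted to three significant digits. A minimal usage sketch (illustrative calls only, not part of the patch):

    from tensorflow.python.debug.cli import cli_shared

    # The suffix index is floor(log10(value) / 3), capped at seconds.
    print(cli_shared.time_to_readable_str(40))    # "40us"
    print(cli_shared.time_to_readable_str(40e3))  # "40ms"
    print(cli_shared.time_to_readable_str(40e6))  # "40s"
    print(cli_shared.time_to_readable_str(None))  # "0" (falsy input short-circuits)
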
diff --git a/tensorflow/python/debug/cli/command_parser.py b/tensorflow/python/debug/cli/command_parser.py
index a71982f86a..143c104519 100644
--- a/tensorflow/python/debug/cli/command_parser.py
+++ b/tensorflow/python/debug/cli/command_parser.py
@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
import ast
-from collections import namedtuple
import re
import sys
@@ -29,8 +28,28 @@ _WHITESPACE_PATTERN = re.compile(r"\s+")
_NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?")
-Interval = namedtuple("Interval",
- ["start", "start_included", "end", "end_included"])
+
+class Interval(object):
+ """Represents an interval between a start and end value."""
+
+ def __init__(self, start, start_included, end, end_included):
+ self.start = start
+ self.start_included = start_included
+ self.end = end
+ self.end_included = end_included
+
+ def contains(self, value):
+ if value < self.start or value == self.start and not self.start_included:
+ return False
+ if value > self.end or value == self.end and not self.end_included:
+ return False
+ return True
+
+ def __eq__(self, other):
+ return (self.start == other.start and
+ self.start_included == other.start_included and
+ self.end == other.end and
+ self.end_included == other.end_included)
def parse_command(command):
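
Replacing the Interval namedtuple with a class lets callers ask whether a value falls inside the interval. A brief sketch of the intended semantics (it mirrors the new testIntervalContains case added to command_parser_test.py below; the calls themselves are illustrative):

    from tensorflow.python.debug.cli import command_parser

    # start_included / end_included control whether each boundary value
    # itself counts as inside the interval.
    interval = command_parser.Interval(
        start=1, start_included=True, end=10, end_included=True)
    assert interval.contains(1) and interval.contains(5) and interval.contains(10)

    interval.start_included = False
    assert not interval.contains(1)  # the left boundary is now excluded
    assert interval.contains(10)
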
diff --git a/tensorflow/python/debug/cli/command_parser_test.py b/tensorflow/python/debug/cli/command_parser_test.py
index ab9b3245dc..1ea890be8c 100644
--- a/tensorflow/python/debug/cli/command_parser_test.py
+++ b/tensorflow/python/debug/cli/command_parser_test.py
@@ -490,6 +490,25 @@ class ParseInterval(test_util.TensorFlowTestCase):
"equal to end of interval."):
command_parser.parse_memory_interval("[5k, 3k]")
+ def testIntervalContains(self):
+ interval = command_parser.Interval(
+ start=1, start_included=True, end=10, end_included=True)
+ self.assertTrue(interval.contains(1))
+ self.assertTrue(interval.contains(10))
+ self.assertTrue(interval.contains(5))
+
+ interval.start_included = False
+ self.assertFalse(interval.contains(1))
+ self.assertTrue(interval.contains(10))
+
+ interval.end_included = False
+ self.assertFalse(interval.contains(1))
+ self.assertFalse(interval.contains(10))
+
+ interval.start_included = True
+ self.assertTrue(interval.contains(1))
+ self.assertFalse(interval.contains(10))
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/debug/cli/profile_analyzer_cli.py b/tensorflow/python/debug/cli/profile_analyzer_cli.py
new file mode 100644
index 0000000000..42440521eb
--- /dev/null
+++ b/tensorflow/python/debug/cli/profile_analyzer_cli.py
@@ -0,0 +1,459 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Formats and displays profiling information."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import re
+
+from tensorflow.python.debug.cli import cli_shared
+from tensorflow.python.debug.cli import command_parser
+from tensorflow.python.debug.cli import debugger_cli_common
+from tensorflow.python.debug.cli import ui_factory
+from tensorflow.python.debug.lib import source_utils
+
+
+SORT_OPS_BY_OP_NAME = "node"
+SORT_OPS_BY_OP_TIME = "op_time"
+SORT_OPS_BY_EXEC_TIME = "exec_time"
+SORT_OPS_BY_START_TIME = "start_time"
+SORT_OPS_BY_LINE = "line"
+
+
+class ProfileDatum(object):
+ """Profile data point."""
+
+ def __init__(self, node_exec_stats, file_line, op_type):
+ """Constructor.
+
+ Args:
+ node_exec_stats: `NodeExecStats` proto.
+ file_line: A `string` formatted as <file_name>:<line_number>.
+ op_type: (string) Operation type.
+ """
+ self.node_exec_stats = node_exec_stats
+ self.file_line = file_line
+ self.op_type = op_type
+ self.op_time = (self.node_exec_stats.op_end_rel_micros -
+ self.node_exec_stats.op_start_rel_micros)
+
+ @property
+ def exec_time(self):
+    """Measures compute function execution time plus pre- and post-processing."""
+ return self.node_exec_stats.all_end_rel_micros
+
+
+class ProfileDataTableView(object):
+ """Table View of profiling data."""
+
+ def __init__(self, profile_datum_list):
+ """Constructor.
+
+ Args:
+ profile_datum_list: List of `ProfileDatum` objects.
+ """
+ self._profile_datum_list = profile_datum_list
+ self.formatted_op_time = [
+ cli_shared.time_to_readable_str(datum.op_time)
+ for datum in profile_datum_list]
+ self.formatted_exec_time = [
+ cli_shared.time_to_readable_str(
+ datum.node_exec_stats.all_end_rel_micros)
+ for datum in profile_datum_list]
+ self._column_sort_ids = [SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TIME,
+ SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE]
+
+ def value(self, row, col):
+ if col == 0:
+ return self._profile_datum_list[row].node_exec_stats.node_name
+ elif col == 1:
+ return self.formatted_op_time[row]
+ elif col == 2:
+ return self.formatted_exec_time[row]
+ elif col == 3:
+ return self._profile_datum_list[row].file_line
+ else:
+ raise IndexError("Invalid column index %d." % col)
+
+ def row_count(self):
+ return len(self._profile_datum_list)
+
+ def column_count(self):
+ return 4
+
+ def column_names(self):
+ return ["Node", "Op Time", "Exec Time", "Filename:Lineno(function)"]
+
+ def column_sort_id(self, col):
+ return self._column_sort_ids[col]
+
+
+def _list_profile_filter(
+ profile_datum, node_name_regex, file_name_regex, op_type_regex,
+ op_time_interval, exec_time_interval):
+ """Filter function for list_profile command.
+
+ Args:
+ profile_datum: A `ProfileDatum` object.
+ node_name_regex: Regular expression pattern object to filter by name.
+ file_name_regex: Regular expression pattern object to filter by file.
+ op_type_regex: Regular expression pattern object to filter by op type.
+ op_time_interval: `Interval` for filtering op time.
+ exec_time_interval: `Interval` for filtering exec time.
+
+ Returns:
+ True if profile_datum should be included.
+ """
+ if not node_name_regex.match(
+ profile_datum.node_exec_stats.node_name):
+ return False
+ if profile_datum.file_line is not None and not file_name_regex.match(
+ profile_datum.file_line):
+ return False
+ if profile_datum.op_type is not None and not op_type_regex.match(
+ profile_datum.op_type):
+ return False
+ if op_time_interval is not None and not op_time_interval.contains(
+ profile_datum.op_time):
+ return False
+ if exec_time_interval and not exec_time_interval.contains(
+ profile_datum.node_exec_stats.all_end_rel_micros):
+ return False
+ return True
+
+
+def _list_profile_sort_key(profile_datum, sort_by):
+ """Get a profile_datum property to sort by in list_profile command.
+
+ Args:
+ profile_datum: A `ProfileDatum` object.
+ sort_by: (string) indicates a value to sort by.
+ Must be one of SORT_BY* constants.
+
+ Returns:
+ profile_datum property to sort by.
+ """
+ if sort_by == SORT_OPS_BY_OP_NAME:
+ return profile_datum.node_exec_stats.node_name
+ elif sort_by == SORT_OPS_BY_LINE:
+ return profile_datum.file_line
+ elif sort_by == SORT_OPS_BY_OP_TIME:
+ return profile_datum.op_time
+ elif sort_by == SORT_OPS_BY_EXEC_TIME:
+ return profile_datum.node_exec_stats.all_end_rel_micros
+ else: # sort by start time
+ return profile_datum.node_exec_stats.all_start_micros
+
+
+class ProfileAnalyzer(object):
+ """Analyzer for profiling data."""
+
+ def __init__(self, graph, run_metadata):
+ """ProfileAnalyzer constructor.
+
+ Args:
+ graph: (tf.Graph) Python graph object.
+ run_metadata: A `RunMetadata` protobuf object.
+
+ Raises:
+ ValueError: If run_metadata is None.
+ """
+ self._graph = graph
+ if not run_metadata:
+ raise ValueError("No RunMetadata passed for profile analysis.")
+ self._run_metadata = run_metadata
+ self._arg_parsers = {}
+ ap = argparse.ArgumentParser(
+ description="List nodes profile information.",
+ usage=argparse.SUPPRESS)
+ ap.add_argument(
+ "-d",
+ "--device_name_filter",
+ dest="device_name_filter",
+ type=str,
+ default="",
+ help="filter device name by regex.")
+ ap.add_argument(
+ "-n",
+ "--node_name_filter",
+ dest="node_name_filter",
+ type=str,
+ default="",
+ help="filter node name by regex.")
+ ap.add_argument(
+ "-t",
+ "--op_type_filter",
+ dest="op_type_filter",
+ type=str,
+ default="",
+ help="filter op type by regex.")
+ # TODO(annarev): allow file filtering at non-stack top position.
+ ap.add_argument(
+ "-f",
+ "--file_name_filter",
+ dest="file_name_filter",
+ type=str,
+ default="",
+ help="filter by file name at the top position of node's creation "
+ "stack that does not belong to TensorFlow library.")
+ ap.add_argument(
+ "-e",
+ "--execution_time",
+ dest="execution_time",
+ type=str,
+ default="",
+ help="Filter by execution time interval "
+        "(includes compute plus pre- and post-processing time). "
+ "Supported units are s, ms and us (default). "
+ "E.g. -e >100s, -e <100, -e [100us,1000ms]")
+ ap.add_argument(
+ "-o",
+ "--op_time",
+ dest="op_time",
+ type=str,
+ default="",
+ help="Filter by op time interval (only includes compute time). "
+ "Supported units are s, ms and us (default). "
+        "E.g. -o >100s, -o <100, -o [100us,1000ms]")
+ ap.add_argument(
+ "-s",
+ "--sort_by",
+ dest="sort_by",
+ type=str,
+ default=SORT_OPS_BY_START_TIME,
+ help=("the field to sort the data by: (%s | %s | %s | %s | %s)" %
+ (SORT_OPS_BY_OP_NAME, SORT_OPS_BY_START_TIME,
+ SORT_OPS_BY_OP_TIME, SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE)))
+ ap.add_argument(
+ "-r",
+ "--reverse",
+ dest="reverse",
+ action="store_true",
+ help="sort the data in reverse (descending) order")
+
+ self._arg_parsers["list_profile"] = ap
+
+ def list_profile(self, args, screen_info=None):
+ """Command handler for list_profile.
+
+ List per-operation profile information.
+
+ Args:
+ args: Command-line arguments, excluding the command prefix, as a list of
+ str.
+ screen_info: Optional dict input containing screen information such as
+ cols.
+
+ Returns:
+ Output text lines as a RichTextLines object.
+ """
+ del screen_info
+
+ parsed = self._arg_parsers["list_profile"].parse_args(args)
+ op_time_interval = (command_parser.parse_time_interval(parsed.op_time)
+ if parsed.op_time else None)
+ exec_time_interval = (
+ command_parser.parse_time_interval(parsed.execution_time)
+ if parsed.execution_time else None)
+ node_name_regex = re.compile(parsed.node_name_filter)
+ file_name_regex = re.compile(parsed.file_name_filter)
+ op_type_regex = re.compile(parsed.op_type_filter)
+
+ output = debugger_cli_common.RichTextLines([""])
+ device_name_regex = re.compile(parsed.device_name_filter)
+ data_generator = self._get_profile_data_generator()
+ device_count = len(self._run_metadata.step_stats.dev_stats)
+ for index in range(device_count):
+ device_stats = self._run_metadata.step_stats.dev_stats[index]
+ if device_name_regex.match(device_stats.device):
+ profile_data = [
+ datum for datum in data_generator(device_stats)
+ if _list_profile_filter(
+ datum, node_name_regex, file_name_regex, op_type_regex,
+ op_time_interval, exec_time_interval)]
+ profile_data = sorted(
+ profile_data,
+ key=lambda datum: _list_profile_sort_key(datum, parsed.sort_by),
+ reverse=parsed.reverse)
+ output.extend(
+ self._get_list_profile_lines(
+ device_stats.device, index, device_count,
+ profile_data, parsed.sort_by, parsed.reverse))
+ return output
+
+ def _get_profile_data_generator(self):
+ """Get function that generates `ProfileDatum` objects.
+
+ Returns:
+ A function that generates `ProfileDatum` objects.
+ """
+ node_to_file_line = {}
+ node_to_op_type = {}
+ for op in self._graph.get_operations():
+ file_line = ""
+ for trace_entry in reversed(op.traceback):
+ filepath = trace_entry[0]
+ file_line = "%s:%d(%s)" % (
+ os.path.basename(filepath), trace_entry[1], trace_entry[2])
+ if not source_utils.guess_is_tensorflow_py_library(filepath):
+ break
+ node_to_file_line[op.name] = file_line
+ node_to_op_type[op.name] = op.type
+
+ def profile_data_generator(device_step_stats):
+ for node_stats in device_step_stats.node_stats:
+ if node_stats.node_name == "_SOURCE" or node_stats.node_name == "_SINK":
+ continue
+ yield ProfileDatum(
+ node_stats,
+ node_to_file_line.get(node_stats.node_name, ""),
+ node_to_op_type.get(node_stats.node_name, ""))
+ return profile_data_generator
+
+ def _get_list_profile_lines(
+ self, device_name, device_index, device_count,
+ profile_datum_list, sort_by, sort_reverse):
+ """Get `RichTextLines` object for list_profile command for a given device.
+
+ Args:
+ device_name: (string) Device name.
+ device_index: (int) Device index.
+ device_count: (int) Number of devices.
+ profile_datum_list: List of `ProfileDatum` objects.
+ sort_by: (string) Identifier of column to sort. Sort identifier
+        must match one of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TIME,
+        SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_START_TIME or SORT_OPS_BY_LINE.
+ sort_reverse: (bool) Whether to sort in descending instead of default
+ (ascending) order.
+
+ Returns:
+ `RichTextLines` object containing a table that displays profiling
+ information for each op.
+ """
+ profile_data = ProfileDataTableView(profile_datum_list)
+
+    # Compute the time totals first so that they factor into the column widths.
+ total_op_time = sum(datum.op_time for datum in profile_datum_list)
+ total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
+ for datum in profile_datum_list)
+ device_total_row = [
+ "Device Total", cli_shared.time_to_readable_str(total_op_time),
+ cli_shared.time_to_readable_str(total_exec_time)]
+
+ # Calculate column widths.
+ column_widths = [
+ len(column_name) for column_name in profile_data.column_names()]
+ for col in range(len(device_total_row)):
+ column_widths[col] = max(column_widths[col], len(device_total_row[col]))
+ for col in range(len(column_widths)):
+ for row in range(profile_data.row_count()):
+ column_widths[col] = max(
+ column_widths[col], len(str(profile_data.value(row, col))))
+ column_widths[col] += 2 # add margin between columns
+
+ # Add device name.
+ output = debugger_cli_common.RichTextLines(["-"*80])
+ device_row = "Device %d of %d: %s" % (
+ device_index + 1, device_count, device_name)
+ output.extend(debugger_cli_common.RichTextLines([device_row, ""]))
+
+ # Add headers.
+ base_command = "list_profile"
+ attr_segs = {0: []}
+ row = ""
+ for col in range(profile_data.column_count()):
+ column_name = profile_data.column_names()[col]
+ sort_id = profile_data.column_sort_id(col)
+ command = "%s -s %s" % (base_command, sort_id)
+ if sort_by == sort_id and not sort_reverse:
+ command += " -r"
+ curr_row = ("{:<%d}" % column_widths[col]).format(column_name)
+ prev_len = len(row)
+ row += curr_row
+ attr_segs[0].append(
+ (prev_len, prev_len + len(column_name),
+ [debugger_cli_common.MenuItem(None, command), "bold"]))
+
+ output.extend(
+ debugger_cli_common.RichTextLines([row], font_attr_segs=attr_segs))
+
+ # Add data rows.
+ for row in range(profile_data.row_count()):
+ row_str = ""
+ for col in range(profile_data.column_count()):
+ row_str += ("{:<%d}" % column_widths[col]).format(
+ profile_data.value(row, col))
+ output.extend(debugger_cli_common.RichTextLines([row_str]))
+
+ # Add stat totals.
+ row_str = ""
+ for col in range(len(device_total_row)):
+ row_str += ("{:<%d}" % column_widths[col]).format(device_total_row[col])
+ output.extend(debugger_cli_common.RichTextLines(""))
+ output.extend(debugger_cli_common.RichTextLines(row_str))
+ return output
+
+ def _measure_list_profile_column_widths(self, profile_data):
+ """Determine the maximum column widths for each data list.
+
+ Args:
+ profile_data: list of ProfileDatum objects.
+
+ Returns:
+ List of column widths in the same order as columns in data.
+ """
+ num_columns = len(profile_data.column_names())
+ widths = [len(column_name) for column_name in profile_data.column_names()]
+ for row in range(profile_data.row_count()):
+ for col in range(num_columns):
+ widths[col] = max(
+            widths[col], len(str(profile_data.value(row, col))) + 2)
+ return widths
+
+ def get_help(self, handler_name):
+ return self._arg_parsers[handler_name].format_help()
+
+
+def create_profiler_ui(graph,
+ run_metadata,
+ ui_type="curses",
+ on_ui_exit=None):
+  """Create a profiler UI based on a `tf.Graph` and `RunMetadata`.
+
+ Args:
+ graph: Python `Graph` object.
+ run_metadata: A `RunMetadata` protobuf object.
+ ui_type: (str) requested UI type, e.g., "curses", "readline".
+ on_ui_exit: (`Callable`) the callback to be called when the UI exits.
+
+ Returns:
+ (base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
+ commands and tab-completions registered.
+ """
+
+ analyzer = ProfileAnalyzer(graph, run_metadata)
+
+ cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit)
+ cli.register_command_handler(
+ "list_profile",
+ analyzer.list_profile,
+ analyzer.get_help("list_profile"),
+ prefix_aliases=["lp"])
+
+ return cli
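
create_profiler_ui wires the analyzer's single list_profile handler (alias: lp) into whichever UI type is requested. A hedged end-to-end sketch of how a caller might use it, assuming the standard tfdbg run_ui() entry point on the returned UI object; the graph/RunMetadata plumbing mirrors testWithSession in the test file below:

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.python.client import session
    from tensorflow.python.debug.cli import profile_analyzer_cli
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import math_ops

    # Collect full-trace timings for a single Session.run() call.
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    with session.Session() as sess:
      x = constant_op.constant([1.0, 2.0, 3.0])
      y = math_ops.add(x, x)
      sess.run(y, options=run_options, run_metadata=run_metadata)

      # Build the profiler UI; list_profile (alias: lp) is the registered command.
      ui = profile_analyzer_cli.create_profiler_ui(
          sess.graph, run_metadata, ui_type="readline")
      ui.run_ui()  # assumed BaseUI entry point; blocks until the user exits
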
diff --git a/tensorflow/python/debug/cli/profile_analyzer_cli_test.py b/tensorflow/python/debug/cli/profile_analyzer_cli_test.py
new file mode 100644
index 0000000000..6cc043e1cb
--- /dev/null
+++ b/tensorflow/python/debug/cli/profile_analyzer_cli_test.py
@@ -0,0 +1,264 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for profile_analyzer_cli."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.core.framework import step_stats_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.debug.cli import profile_analyzer_cli
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+class ProfileAnalyzerTest(test_util.TensorFlowTestCase):
+
+ def testNodeInfoEmpty(self):
+ graph = ops.Graph()
+ run_metadata = config_pb2.RunMetadata()
+
+ prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
+ prof_output = prof_analyzer.list_profile([]).lines
+ self.assertEquals([""], prof_output)
+
+ def testSingleDevice(self):
+ node1 = step_stats_pb2.NodeExecStats(
+ node_name="Add/123",
+ op_start_rel_micros=3,
+ op_end_rel_micros=5,
+ all_end_rel_micros=4)
+
+ node2 = step_stats_pb2.NodeExecStats(
+ node_name="Mul/456",
+ op_start_rel_micros=1,
+ op_end_rel_micros=2,
+ all_end_rel_micros=3)
+
+ run_metadata = config_pb2.RunMetadata()
+ device1 = run_metadata.step_stats.dev_stats.add()
+ device1.device = "deviceA"
+ device1.node_stats.extend([node1, node2])
+
+ graph = test.mock.MagicMock()
+ op1 = test.mock.MagicMock()
+ op1.name = "Add/123"
+ op1.traceback = [("a/b/file1", 10, "some_var")]
+ op1.type = "add"
+ op2 = test.mock.MagicMock()
+ op2.name = "Mul/456"
+ op2.traceback = [("a/b/file1", 11, "some_var")]
+ op2.type = "mul"
+ graph.get_operations.return_value = [op1, op2]
+
+ prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
+ prof_output = prof_analyzer.list_profile([]).lines
+
+ self._assertAtLeastOneLineMatches(r"Device 1 of 1: deviceA", prof_output)
+ self._assertAtLeastOneLineMatches(r"^Add/123.*2us.*4us", prof_output)
+ self._assertAtLeastOneLineMatches(r"^Mul/456.*1us.*3us", prof_output)
+
+ def testMultipleDevices(self):
+ node1 = step_stats_pb2.NodeExecStats(
+ node_name="Add/123",
+ op_start_rel_micros=3,
+ op_end_rel_micros=5,
+ all_end_rel_micros=3)
+
+ run_metadata = config_pb2.RunMetadata()
+ device1 = run_metadata.step_stats.dev_stats.add()
+ device1.device = "deviceA"
+ device1.node_stats.extend([node1])
+
+ device2 = run_metadata.step_stats.dev_stats.add()
+ device2.device = "deviceB"
+ device2.node_stats.extend([node1])
+
+ graph = test.mock.MagicMock()
+ op = test.mock.MagicMock()
+ op.name = "Add/123"
+ op.traceback = [("a/b/file1", 10, "some_var")]
+ op.type = "abc"
+ graph.get_operations.return_value = [op]
+
+ prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
+ prof_output = prof_analyzer.list_profile([]).lines
+
+ self._assertAtLeastOneLineMatches(r"Device 1 of 2: deviceA", prof_output)
+ self._assertAtLeastOneLineMatches(r"Device 2 of 2: deviceB", prof_output)
+
+ # Try filtering by device.
+ prof_output = prof_analyzer.list_profile(["-d", "deviceB"]).lines
+ self._assertAtLeastOneLineMatches(r"Device 2 of 2: deviceB", prof_output)
+ self._assertNoLinesMatch(r"Device 1 of 2: deviceA", prof_output)
+
+ def testWithSession(self):
+ options = config_pb2.RunOptions()
+ options.trace_level = config_pb2.RunOptions.FULL_TRACE
+ run_metadata = config_pb2.RunMetadata()
+
+ with session.Session() as sess:
+ a = constant_op.constant([1, 2, 3])
+ b = constant_op.constant([2, 2, 1])
+ result = math_ops.add(a, b)
+
+ sess.run(result, options=options, run_metadata=run_metadata)
+
+ prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(
+ sess.graph, run_metadata)
+ prof_output = prof_analyzer.list_profile([]).lines
+
+ self._assertAtLeastOneLineMatches("Device 1 of", prof_output)
+ expected_headers = [
+ "Node", "Op Time", "Exec Time", r"Filename:Lineno\(function\)"]
+ self._assertAtLeastOneLineMatches(
+ ".*".join(expected_headers), prof_output)
+ self._assertAtLeastOneLineMatches(r"^Add/", prof_output)
+ self._assertAtLeastOneLineMatches(r"Device Total", prof_output)
+
+ def testSorting(self):
+ node1 = step_stats_pb2.NodeExecStats(
+ node_name="Add/123",
+ all_start_micros=123,
+ op_start_rel_micros=3,
+ op_end_rel_micros=5,
+ all_end_rel_micros=4)
+
+ node2 = step_stats_pb2.NodeExecStats(
+ node_name="Mul/456",
+ all_start_micros=122,
+ op_start_rel_micros=1,
+ op_end_rel_micros=2,
+ all_end_rel_micros=5)
+
+ run_metadata = config_pb2.RunMetadata()
+ device1 = run_metadata.step_stats.dev_stats.add()
+ device1.device = "deviceA"
+ device1.node_stats.extend([node1, node2])
+
+ graph = test.mock.MagicMock()
+ op1 = test.mock.MagicMock()
+ op1.name = "Add/123"
+ op1.traceback = [("a/b/file2", 10, "some_var")]
+ op1.type = "add"
+ op2 = test.mock.MagicMock()
+ op2.name = "Mul/456"
+ op2.traceback = [("a/b/file1", 11, "some_var")]
+ op2.type = "mul"
+ graph.get_operations.return_value = [op1, op2]
+
+ prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
+
+ # Default sort by start time (i.e. all_start_micros).
+ prof_output = prof_analyzer.list_profile([]).lines
+ self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
+ # Default sort in reverse.
+ prof_output = prof_analyzer.list_profile(["-r"]).lines
+ self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
+ # Sort by name.
+ prof_output = prof_analyzer.list_profile(["-s", "node"]).lines
+ self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
+ # Sort by op time (i.e. op_end_rel_micros - op_start_rel_micros).
+ prof_output = prof_analyzer.list_profile(["-s", "op_time"]).lines
+ self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
+ # Sort by exec time (i.e. all_end_rel_micros).
+ prof_output = prof_analyzer.list_profile(["-s", "exec_time"]).lines
+ self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
+ # Sort by line number.
+ prof_output = prof_analyzer.list_profile(["-s", "line"]).lines
+ self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
+
+ def testFiltering(self):
+ node1 = step_stats_pb2.NodeExecStats(
+ node_name="Add/123",
+ all_start_micros=123,
+ op_start_rel_micros=3,
+ op_end_rel_micros=5,
+ all_end_rel_micros=4)
+
+ node2 = step_stats_pb2.NodeExecStats(
+ node_name="Mul/456",
+ all_start_micros=122,
+ op_start_rel_micros=1,
+ op_end_rel_micros=2,
+ all_end_rel_micros=5)
+
+ run_metadata = config_pb2.RunMetadata()
+ device1 = run_metadata.step_stats.dev_stats.add()
+ device1.device = "deviceA"
+ device1.node_stats.extend([node1, node2])
+
+ graph = test.mock.MagicMock()
+ op1 = test.mock.MagicMock()
+ op1.name = "Add/123"
+ op1.traceback = [("a/b/file2", 10, "some_var")]
+ op1.type = "add"
+ op2 = test.mock.MagicMock()
+ op2.name = "Mul/456"
+ op2.traceback = [("a/b/file1", 11, "some_var")]
+ op2.type = "mul"
+ graph.get_operations.return_value = [op1, op2]
+
+ prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
+
+ # Filter by name
+ prof_output = prof_analyzer.list_profile(["-n", "Add"]).lines
+ self._assertAtLeastOneLineMatches(r"Add/123", prof_output)
+ self._assertNoLinesMatch(r"Mul/456", prof_output)
+ # Filter by op_type
+ prof_output = prof_analyzer.list_profile(["-t", "mul"]).lines
+ self._assertAtLeastOneLineMatches(r"Mul/456", prof_output)
+ self._assertNoLinesMatch(r"Add/123", prof_output)
+ # Filter by file name.
+ prof_output = prof_analyzer.list_profile(["-f", "file2"]).lines
+ self._assertAtLeastOneLineMatches(r"Add/123", prof_output)
+ self._assertNoLinesMatch(r"Mul/456", prof_output)
+    # Filter by execution time.
+ prof_output = prof_analyzer.list_profile(["-e", "[5, 10]"]).lines
+ self._assertAtLeastOneLineMatches(r"Mul/456", prof_output)
+ self._assertNoLinesMatch(r"Add/123", prof_output)
+    # Filter by op time.
+ prof_output = prof_analyzer.list_profile(["-o", ">=2"]).lines
+ self._assertAtLeastOneLineMatches(r"Add/123", prof_output)
+ self._assertNoLinesMatch(r"Mul/456", prof_output)
+
+ def _atLeastOneLineMatches(self, pattern, lines):
+ pattern_re = re.compile(pattern)
+ for line in lines:
+ if pattern_re.match(line):
+ return True
+ return False
+
+ def _assertAtLeastOneLineMatches(self, pattern, lines):
+ if not self._atLeastOneLineMatches(pattern, lines):
+ raise AssertionError(
+ "%s does not match any line in %s." % (pattern, str(lines)))
+
+ def _assertNoLinesMatch(self, pattern, lines):
+ if self._atLeastOneLineMatches(pattern, lines):
+ raise AssertionError(
+ "%s matched at least one line in %s." % (pattern, str(lines)))
+
+
+if __name__ == "__main__":
+ googletest.main()
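
For reference, the filter and sort flags defined in profile_analyzer_cli.py compose freely. A short illustrative sketch of driving the handler directly, in the style of the filtering and sorting tests above (prof_analyzer is assumed to be a ProfileAnalyzer constructed as in those tests):

    # Nodes whose names match "Add", sorted by per-op compute time, slowest first.
    lines = prof_analyzer.list_profile(["-n", "Add", "-s", "op_time", "-r"]).lines

    # Ops on devices matching "deviceA" whose total execution time is in [5us, 10us].
    lines = prof_analyzer.list_profile(["-d", "deviceA", "-e", "[5, 10]"]).lines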