aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/error_interpolation.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/error_interpolation.py')
-rw-r--r--tensorflow/python/framework/error_interpolation.py164
1 files changed, 160 insertions, 4 deletions
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
index 9ccae76147..a79073b748 100644
--- a/tensorflow/python/framework/error_interpolation.py
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -24,11 +24,15 @@ from __future__ import print_function
import collections
import itertools
+import os
import re
import string
import six
+from tensorflow.python.util import tf_stack
+
+
_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
_FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+"
_TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format(
@@ -38,6 +42,11 @@ _INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX)
_ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"])
+_BAD_FILE_SUBSTRINGS = [
+ os.path.join("tensorflow", "python"),
+ "<embedded",
+]
+
def _parse_message(message):
"""Parses the message.
@@ -48,6 +57,12 @@ def _parse_message(message):
"123^^node:Foo:${file}^^456^^node:Bar:${line}^^789", there are two tags and
three separators. The separators are the numeric characters.
+ Supported tags after node:<node_name>
+ file: Replaced with the filename in which the node was defined.
+ line: Replaced by the line number at which the node was defined.
+ colocations: Replaced by a multi-line message describing the file and
+ line numbers at which this node was colocated with other nodes.
+
Args:
message: String to parse
@@ -72,9 +87,135 @@ def _parse_message(message):
return seps, tags
-# TODO(jtkeeling): Modify to actually interpolate format strings rather than
-# echoing them.
-def interpolate(error_message):
+def _compute_colocation_summary_from_dict(colocation_dict, prefix=""):
+ """Return a summary of an op's colocation stack.
+
+ Args:
+ colocation_dict: The op._colocation_dict.
+ prefix: An optional string prefix used before each line of the multi-
+ line string returned by this function.
+
+ Returns:
+ A multi-line string similar to:
+ Node-device colocations active during op creation:
+ with tf.colocate_with(test_node_1): <test_1.py:27>
+ with tf.colocate_with(test_node_2): <test_2.py:38>
+ The first line will have no padding to its left by default. Subsequent
+ lines will have two spaces of left-padding. Use the prefix argument
+ to increase indentation.
+ """
+ if not colocation_dict:
+ message = "No node-device colocations were active during op creation."
+ return prefix + message
+
+ str_list = []
+ str_list.append("%sNode-device colocations active during op creation:"
+ % prefix)
+
+ for name, location in colocation_dict.items():
+ location_summary = "<{file}:{line}>".format(file=location.filename,
+ line=location.lineno)
+ subs = {
+ "prefix": prefix,
+ "indent": " ",
+ "name": name,
+ "loc": location_summary,
+ }
+ str_list.append(
+ "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(**subs))
+
+ return "\n".join(str_list)
+
+
+def _compute_colocation_summary_from_op(op, prefix=""):
+ """Fetch colocation file, line, and nesting and return a summary string."""
+ if not op:
+ return ""
+ # pylint: disable=protected-access
+ return _compute_colocation_summary_from_dict(op._colocation_dict, prefix)
+ # pylint: enable=protected-access
+
+
+def _find_index_of_defining_frame_for_op(op):
+ """Return index in op._traceback with first 'useful' frame.
+
+ This method reads through the stack stored in op._traceback looking for the
+ innermost frame which (hopefully) belongs to the caller. It accomplishes this
+ by rejecting frames whose filename appears to come from TensorFlow (see
+ error_interpolation._BAD_FILE_SUBSTRINGS for the list of rejected substrings).
+
+ Args:
+ op: the Operation object for which we would like to find the defining
+ location.
+
+ Returns:
+ Integer index into op._traceback where the first non-TF file was found
+ (innermost to outermost), or 0 (for the outermost stack frame) if all files
+ came from TensorFlow.
+ """
+ # pylint: disable=protected-access
+ # Index 0 of tf_traceback is the outermost frame.
+ tf_traceback = tf_stack.convert_stack(op._traceback)
+ size = len(tf_traceback)
+ # pylint: enable=protected-access
+ filenames = [frame[tf_stack.TB_FILENAME] for frame in tf_traceback]
+ # We process the filenames from the innermost frame to outermost.
+ for idx, filename in enumerate(reversed(filenames)):
+ contains_bad_substrings = [ss in filename for ss in _BAD_FILE_SUBSTRINGS]
+ if not any(contains_bad_substrings):
+ return size - idx - 1
+ return 0
+
+
+def _get_defining_frame_from_op(op):
+ """Find and return stack frame where op was defined."""
+ frame = None
+ if op:
+ # pylint: disable=protected-access
+ frame_index = _find_index_of_defining_frame_for_op(op)
+ frame = op._traceback[frame_index]
+ # pylint: enable=protected-access
+ return frame
+
+
+def _compute_field_dict(op):
+ """Return a dictionary mapping interpolation tokens to values.
+
+ Args:
+ op: op.Operation object having a _traceback member.
+
+ Returns:
+ A dictionary mapping string tokens to string values. The keys are shown
+ below along with example values.
+ {
+ "file": "tool_utils.py",
+ "line": "124",
+ "colocations":
+ '''Node-device colocations active during op creation:
+ with tf.colocate_with(test_node_1): <test_1.py:27>
+ with tf.colocate_with(test_node_2): <test_2.py:38>'''
+ }
+ If op is None or lacks a _traceback field, the returned values will be
+ "<NA>".
+ """
+ default_value = "<NA>"
+ field_dict = {
+ "file": default_value,
+ "line": default_value,
+ "colocations": default_value,
+ }
+ frame = _get_defining_frame_from_op(op)
+ if frame:
+ field_dict["file"] = frame[tf_stack.TB_FILENAME]
+ field_dict["line"] = frame[tf_stack.TB_LINENO]
+ colocation_summary = _compute_colocation_summary_from_op(op)
+ if colocation_summary:
+ field_dict["colocations"] = colocation_summary
+
+ return field_dict
+
+
+def interpolate(error_message, graph):
"""Interpolates an error message.
The error message can contain tags of the form ^^type:name:format^^ which will
@@ -82,11 +223,26 @@ def interpolate(error_message):
Args:
error_message: A string to interpolate.
+ graph: ops.Graph object containing all nodes referenced in the error
+ message.
Returns:
The string with tags of the form ^^type:name:format^^ interpolated.
"""
seps, tags = _parse_message(error_message)
- subs = [string.Template(tag.format).safe_substitute({}) for tag in tags]
+
+ node_name_to_substitution_dict = {}
+ for name in [t.name for t in tags]:
+ try:
+ op = graph.get_operation_by_name(name)
+ except KeyError:
+ op = None
+
+ node_name_to_substitution_dict[name] = _compute_field_dict(op)
+
+ subs = [
+ string.Template(tag.format).safe_substitute(
+ node_name_to_substitution_dict[tag.name]) for tag in tags
+ ]
return "".join(
itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue="")))