author     Shanqing Cai <cais@google.com>  2017-09-06 08:46:57 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-09-06 08:51:00 -0700
commit     fa0a40a51f6f9cce48342d40c19e184470148242 (patch)
tree       469b86a68f6e054d754637bdae1a05b4db7c86a2 /tensorflow/python/debug
parent     0365fc5606ec36c0a559e50d1d65890de7a76ddd (diff)
tfdbg: Refactor graph-processing code out of debug_data.py
The basic idea is to separate the code in debug_data.py that handles graph structures into its own module (debug_graphs.py). This tackles an existing TODO item to simplify the code of debug_data.DebugDumpDir.

In a later CL, code will be added to debug_graphs.DebugGraph to allow reconstruction of the original GraphDef, i.e., the GraphDef without the Copy* and Debug* nodes inserted by tfdbg. This will be useful for, among other things, the TensorBoard Debugger Plugin.

PiperOrigin-RevId: 167726113
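As an illustrative aside (not part of the commit), here is a minimal sketch of the name-parsing helpers this CL relocates into debug_graphs.py; their signatures and semantics are visible in the new file below, and the node/tensor names used here are made up:

    from tensorflow.python.debug.lib import debug_graphs

    # Split a graph-element name into (node_name, output_slot); a plain node
    # name yields a None slot.
    debug_graphs.parse_node_or_tensor_name("ns/node_a:1")  # ("ns/node_a", 1)
    debug_graphs.parse_node_or_tensor_name("ns/node_a")    # ("ns/node_a", None)

    # get_output_slot() assumes slot 0 when the name carries no slot suffix.
    debug_graphs.get_node_name("ns/node_a:1")  # "ns/node_a"
    debug_graphs.get_output_slot("ns/node_a")  # 0

    # Recover (watched node, output slot, debug op index, debug op) from the
    # name of a tfdbg-inserted debug node.
    debug_graphs.parse_debug_node_name("__dbg_ns/node_a:1_0_DebugIdentity")
    # ("ns/node_a", 1, 0, "DebugIdentity")

The new debug_graphs.DebugGraph class, also shown below, wraps a single partition GraphDef and exposes the pruned adjacency maps (node_inputs, node_recipients, node_op_types, etc.) that DebugDumpDir previously maintained per device.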
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--  tensorflow/python/debug/BUILD                            |  34
-rw-r--r--  tensorflow/python/debug/cli/analyzer_cli.py              |  16
-rw-r--r--  tensorflow/python/debug/lib/debug_data.py                | 543
-rw-r--r--  tensorflow/python/debug/lib/debug_data_test.py           |  85
-rw-r--r--  tensorflow/python/debug/lib/debug_gradients.py           |   5
-rw-r--r--  tensorflow/python/debug/lib/debug_graphs.py              | 430
-rw-r--r--  tensorflow/python/debug/lib/debug_graphs_test.py         | 112
-rw-r--r--  tensorflow/python/debug/lib/grpc_debug_server.py         |   6
-rw-r--r--  tensorflow/python/debug/lib/session_debug_file_test.py   |   2
-rw-r--r--  tensorflow/python/debug/lib/session_debug_testlib.py     |   3
-rw-r--r--  tensorflow/python/debug/lib/stepper.py                   |   7
11 files changed, 681 insertions, 562 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 8eb2212069..c092616999 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -50,10 +50,25 @@ py_library(
)
py_library(
+ name = "debug_graphs",
+ srcs = ["lib/debug_graphs.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:op_def_registry",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:tensor_util",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "debug_data",
srcs = ["lib/debug_data.py"],
srcs_version = "PY2AND3",
deps = [
+ ":debug_graphs",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework",
"//tensorflow/python:op_def_registry",
@@ -70,6 +85,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":debug_data",
+ ":debug_graphs",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform",
@@ -99,6 +115,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":debug_data",
+ ":debug_graphs",
":debug_utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
@@ -181,7 +198,7 @@ py_library(
deps = [
":cli_shared",
":command_parser",
- ":debug_data",
+ ":debug_graphs",
":debugger_cli_common",
":evaluator",
":source_utils",
@@ -401,6 +418,18 @@ py_binary(
)
py_test(
+ name = "debug_graphs_test",
+ size = "small",
+ srcs = ["lib/debug_graphs_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":debug_graphs",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_test(
name = "debug_data_test",
size = "small",
srcs = ["lib/debug_data_test.py"],
@@ -569,6 +598,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":debug_data",
+ ":debug_graphs",
":debug_utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -608,7 +638,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- ":debug_data",
+ ":debug_graphs",
":debug_service_pb2_grpc",
"//tensorflow/core/debug:debug_service_proto_py",
"@six_archive//:six",
diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py
index 22e451e38c..50850bbc0d 100644
--- a/tensorflow/python/debug/cli/analyzer_cli.py
+++ b/tensorflow/python/debug/cli/analyzer_cli.py
@@ -34,7 +34,7 @@ from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import evaluator
from tensorflow.python.debug.cli import ui_factory
-from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.debug.lib import source_utils
RL = debugger_cli_common.RichLine
@@ -716,7 +716,7 @@ class DebugAnalyzer(object):
# Get a node name, regardless of whether the input is a node name (without
# output slot attached) or a tensor name (with output slot attached).
- node_name, unused_slot = debug_data.parse_node_or_tensor_name(
+ node_name, unused_slot = debug_graphs.parse_node_or_tensor_name(
parsed.node_name)
if not self._debug_dump.node_exists(node_name):
@@ -840,7 +840,7 @@ class DebugAnalyzer(object):
parsed.op_type,
do_outputs=False)
- node_name = debug_data.get_node_name(parsed.node_name)
+ node_name = debug_graphs.get_node_name(parsed.node_name)
_add_main_menu(output, node_name=node_name, enable_list_inputs=False)
return output
@@ -871,7 +871,7 @@ class DebugAnalyzer(object):
tensor_name, tensor_slicing = (
command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))
- node_name, output_slot = debug_data.parse_node_or_tensor_name(tensor_name)
+ node_name, output_slot = debug_graphs.parse_node_or_tensor_name(tensor_name)
if (self._debug_dump.loaded_partition_graphs() and
not self._debug_dump.node_exists(node_name)):
output = cli_shared.error(
@@ -1016,7 +1016,7 @@ class DebugAnalyzer(object):
parsed.op_type,
do_outputs=True)
- node_name = debug_data.get_node_name(parsed.node_name)
+ node_name = debug_graphs.get_node_name(parsed.node_name)
_add_main_menu(output, node_name=node_name, enable_list_outputs=False)
return output
@@ -1087,7 +1087,7 @@ class DebugAnalyzer(object):
label = RL(" " * 4)
if self._debug_dump.debug_watch_keys(
- debug_data.get_node_name(element)):
+ debug_graphs.get_node_name(element)):
attribute = debugger_cli_common.MenuItem("", "pt %s" % element)
else:
attribute = cli_shared.COLOR_BLUE
@@ -1246,7 +1246,7 @@ class DebugAnalyzer(object):
font_attr_segs = {}
# Check if this is a tensor name, instead of a node name.
- node_name, _ = debug_data.parse_node_or_tensor_name(node_name)
+ node_name, _ = debug_graphs.parse_node_or_tensor_name(node_name)
# Check if node exists.
if not self._debug_dump.node_exists(node_name):
@@ -1395,7 +1395,7 @@ class DebugAnalyzer(object):
# Recursive call.
# The input's/output's name can be a tensor name, in the case of node
# with >1 output slots.
- inp_node_name, _ = debug_data.parse_node_or_tensor_name(inp)
+ inp_node_name, _ = debug_graphs.parse_node_or_tensor_name(inp)
self._dfs_from_node(
lines,
attr_segs,
diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py
index b2b3ec5d47..9ea279c004 100644
--- a/tensorflow/python/debug/lib/debug_data.py
+++ b/tensorflow/python/debug/lib/debug_data.py
@@ -26,14 +26,14 @@ import platform
import numpy as np
import six
-from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.util import event_pb2
-from tensorflow.python.framework import op_def_registry
+from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
@@ -155,30 +155,6 @@ def _load_log_message_from_event_file(event_file_path):
return event.log_message.message
-def parse_node_or_tensor_name(name):
- """Get the node name from a string that can be node or tensor name.
-
- Args:
- name: An input node name (e.g., "node_a") or tensor name (e.g.,
- "node_a:0"), as a str.
-
- Returns:
- 1) The node name, as a str. If the input name is a tensor name, i.e.,
- consists of a colon, the final colon and the following output slot
- will be stripped.
- 2) If the input name is a tensor name, the output slot, as an int. If
- the input name is not a tensor name, None.
- """
-
- if ":" in name and not name.endswith(":"):
- node_name = name[:name.rfind(":")]
- output_slot = int(name[name.rfind(":") + 1:])
-
- return node_name, output_slot
- else:
- return name, None
-
-
def _is_graph_file(file_name):
return file_name.startswith(METADATA_FILE_PREFIX + GRAPH_FILE_TAG)
@@ -191,25 +167,6 @@ def _is_run_feed_keys_info_file(file_name):
return file_name == METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG
-def get_node_name(element_name):
- return element_name.split(":")[0] if ":" in element_name else element_name
-
-
-def get_output_slot(element_name):
- """Get the output slot number from the name of a graph element.
-
- If element_name is a node name without output slot at the end, 0 will be
- assumed.
-
- Args:
- element_name: (`str`) name of the graph element in question.
-
- Returns:
- (`int`) output slot number.
- """
- return int(element_name.split(":")[-1]) if ":" in element_name else 0
-
-
def _get_tensor_name(node_name, output_slot):
"""Get tensor name given node name and output slot index.
@@ -241,78 +198,6 @@ def _get_tensor_watch_key(node_name, output_slot, debug_op):
return "%s:%s" % (_get_tensor_name(node_name, output_slot), debug_op)
-def is_copy_node(node_name):
- """Determine whether a node name is that of a debug Copy node.
-
- Such nodes are inserted by TensorFlow core upon request in
- RunOptions.debug_options.debug_tensor_watch_opts.
-
- Args:
- node_name: Name of the node.
-
- Returns:
- A bool indicating whether the input argument is the name of a debug Copy
- node.
- """
- return node_name.startswith("__copy_")
-
-
-def is_debug_node(node_name):
- """Determine whether a node name is that of a debug node.
-
- Such nodes are inserted by TensorFlow core upon request in
- RunOptions.debug_options.debug_tensor_watch_opts.
-
- Args:
- node_name: Name of the node.
-
- Returns:
- A bool indicating whether the input argument is the name of a debug node.
- """
- return node_name.startswith("__dbg_")
-
-
-def parse_debug_node_name(node_name):
- """Parse the name of a debug node.
-
- Args:
- node_name: Name of the debug node.
-
- Returns:
- 1. Name of the watched node, as a str.
- 2. Output slot index of the watched tensor, as an int.
- 3. Index of the debug node, as an int.
- 4. Name of the debug op, as a str, e.g, "DebugIdentity".
-
- Raises:
- ValueError: If the input node name is not a valid debug node name.
- """
- prefix = "__dbg_"
-
- name = node_name
- if not name.startswith(prefix):
- raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)
-
- name = name[len(prefix):]
-
- if name.count("_") < 2:
- raise ValueError("Invalid debug node name: '%s'" % node_name)
-
- debug_op = name[name.rindex("_") + 1:]
- name = name[:name.rindex("_")]
-
- debug_op_index = int(name[name.rindex("_") + 1:])
- name = name[:name.rindex("_")]
-
- if name.count(":") != 1:
- raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)
-
- watched_node_name = name[:name.index(":")]
- watched_output_slot = int(name[name.index(":") + 1:])
-
- return watched_node_name, watched_output_slot, debug_op_index, debug_op
-
-
def has_inf_or_nan(datum, tensor):
"""A predicate for whether a tensor consists of any bad numerical values.
@@ -573,88 +458,6 @@ class WatchKeyDoesNotExistInDebugDumpDirError(ValueError):
pass
-class _GraphTracingReachedDestination(Exception):
- pass
-
-
-class _DFSGraphTracer(object):
- """Graph input tracer using depth-first search."""
-
- def __init__(self,
- input_lists,
- skip_node_names=None,
- destination_node_name=None):
- """Constructor of _DFSGraphTracer.
-
- Args:
- input_lists: A list of dicts. Each dict is an adjacency (input) map from
- the recipient node name as the key and the list of input node names
- as the value.
- skip_node_names: Optional: a list of node names to skip tracing.
- destination_node_name: Optional: destination node name. If not `None`, it
- should be the name of a destination not as a str and the graph tracing
- will raise GraphTracingReachedDestination as soon as the node has been
- reached.
-
- Raises:
- _GraphTracingReachedDestination: if stop_at_node_name is not None and
- the specified node is reached.
- """
-
- self._input_lists = input_lists
- self._skip_node_names = skip_node_names
-
- self._inputs = []
- self._visited_nodes = []
- self._depth_count = 0
- self._depth_list = []
-
- self._destination_node_name = destination_node_name
-
- def trace(self, graph_element_name):
- """Trace inputs.
-
- Args:
- graph_element_name: Name of the node or an output tensor of the node, as a
- str.
-
- Raises:
- _GraphTracingReachedDestination: if destination_node_name of this tracer
- object is not None and the specified node is reached.
- """
- self._depth_count += 1
-
- node_name = get_node_name(graph_element_name)
-
- if node_name == self._destination_node_name:
- raise _GraphTracingReachedDestination()
-
- if node_name in self._skip_node_names:
- return
- if node_name in self._visited_nodes:
- return
-
- self._visited_nodes.append(node_name)
-
- for input_list in self._input_lists:
- for inp in input_list[node_name]:
- if get_node_name(inp) in self._visited_nodes:
- continue
- self._inputs.append(inp)
- self._depth_list.append(self._depth_count)
- self.trace(inp)
-
- self._depth_count -= 1
-
- def inputs(self):
- return self._inputs
-
- def depth_list(self):
- return self._depth_list
-
-
-# TODO(cais): This class is getting too large in line count. Refactor to make it
-# smaller and easier to maintain.
class DebugDumpDir(object):
"""Data set from a debug-dump directory on filesystem.
@@ -963,52 +766,36 @@ class DebugDumpDir(object):
ValueError: If the partition GraphDef of one or more devices fail to be
loaded.
"""
-
- self._node_attributes = {}
- self._node_inputs = {}
- self._node_reversed_ref_inputs = {}
- self._node_ctrl_inputs = {}
- self._node_recipients = {}
- self._node_ctrl_recipients = {}
+ self._debug_graphs = {}
self._node_devices = {}
- self._node_op_types = {}
- self._copy_send_nodes = {}
- self._ref_args = {}
-
- self._partition_graphs = {}
- for device_name in self._device_names:
- partition_graph = None
- if device_name in self._dump_graph_file_paths:
- partition_graph = _load_graph_def_from_event_file(
- self._dump_graph_file_paths[device_name])
- else:
- partition_graph = self._find_partition_graph(partition_graphs,
- device_name)
-
- if partition_graph:
- self._partition_graphs[device_name] = partition_graph
- self._node_attributes[device_name] = {}
- self._node_inputs[device_name] = {}
- self._node_reversed_ref_inputs[device_name] = {}
- self._node_ctrl_inputs[device_name] = {}
- self._node_recipients[device_name] = {}
- self._node_ctrl_recipients[device_name] = {}
- self._node_op_types[device_name] = {}
- self._copy_send_nodes[device_name] = []
- self._ref_args[device_name] = []
-
- if partition_graph:
- for node in partition_graph.node:
- self._process_partition_graph_node(device_name, node)
-
- self._prune_non_control_edges_of_debug_ops(device_name)
- self._prune_control_edges_of_debug_ops(device_name)
+ if partition_graphs:
+ partition_graphs_and_device_names = [
+ (partition_graph, None) for partition_graph in partition_graphs]
+ else:
+ partition_graphs_and_device_names = []
+ for device_name in self._device_names:
+ partition_graph = None
+ if device_name in self._dump_graph_file_paths:
+ partition_graph = _load_graph_def_from_event_file(
+ self._dump_graph_file_paths[device_name])
+ else:
+ partition_graph = self._find_partition_graph(partition_graphs,
+ device_name)
+ if partition_graph:
+ partition_graphs_and_device_names.append((partition_graph,
+ device_name))
+ else:
+ logging.warn("Failed to load partition graphs from disk.")
- self._populate_recipient_maps(device_name)
+ for partition_graph, maybe_device_name in partition_graphs_and_device_names:
+ debug_graph = debug_graphs.DebugGraph(partition_graph,
+ device_name=maybe_device_name)
+ self._debug_graphs[debug_graph.device_name] = debug_graph
+ self._collect_node_devices(debug_graph)
- if device_name in self._partition_graphs and validate:
- self._validate_dump_with_graphs(device_name)
+ if validate and debug_graph.device_name in self._dump_tensor_data:
+ self._validate_dump_with_graphs(debug_graph.device_name)
def _find_partition_graph(self, partition_graphs, device_name):
if partition_graphs is None:
@@ -1020,167 +807,13 @@ class DebugDumpDir(object):
return graph_def
return None
- def _get_ref_args(self, node):
- """Determine whether an input of an op is ref-type.
-
- Args:
- node: A `NodeDef`.
-
- Returns:
- A list of the arg names (as strs) that are ref-type.
- """
-
- op_def = op_def_registry.get_registered_ops().get(node.op)
- ref_args = []
- if op_def:
- for i, output_arg in enumerate(op_def.output_arg):
- if output_arg.is_ref:
- arg_name = node.name if i == 0 else (node.name + ":%d" % i)
- ref_args.append(arg_name)
- return ref_args
-
- def _process_partition_graph_node(self, device_name, node):
- """Process a node from the partition graphs.
-
- Args:
- device_name: (str) device name.
- node: (NodeDef) A partition-graph node to be processed.
-
- Raises:
- ValueError: If duplicate node names are encountered.
- """
-
- if is_debug_node(node.name):
- # This is a debug node. Parse the node name and retrieve the
- # information about debug watches on tensors. But do not include
- # the node in the graph.
- (watched_node_name, watched_output_slot, _,
- debug_op) = parse_debug_node_name(node.name)
-
- self._debug_watches[device_name][watched_node_name][
- watched_output_slot].add(debug_op)
-
- return
-
- if node.name in self._node_inputs[device_name]:
- raise ValueError("Duplicate node name on device %s: '%s'" %
- (device_name, node.name))
-
- self._node_attributes[device_name][node.name] = node.attr
-
- self._node_inputs[device_name][node.name] = []
- self._node_ctrl_inputs[device_name][node.name] = []
- self._node_recipients[device_name][node.name] = []
- self._node_ctrl_recipients[device_name][node.name] = []
-
- if node.name not in self._node_devices:
- self._node_devices[node.name] = set()
- self._node_devices[node.name].add(node.device)
- self._node_op_types[device_name][node.name] = node.op
- self._ref_args[device_name].extend(self._get_ref_args(node))
-
- for inp in node.input:
- if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
- self._copy_send_nodes[device_name].append(node.name)
-
- if inp.startswith("^"):
- cinp = inp[1:]
- self._node_ctrl_inputs[device_name][node.name].append(cinp)
+ def _collect_node_devices(self, debug_graph):
+ for node_name in debug_graph.node_devices:
+ if node_name in self._node_devices:
+ self._node_devices[node_name] = self._node_devices[node_name].union(
+ debug_graph.node_devices[node_name])
else:
- self._node_inputs[device_name][node.name].append(inp)
-
- def _prune_nodes_from_input_and_recipient_maps(self,
- device_name,
- nodes_to_prune):
- """Prune nodes out of input and recipient maps.
-
- Args:
- device_name: (`str`) device name.
- nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.
- """
-
- for node in nodes_to_prune:
- del self._node_inputs[device_name][node]
- del self._node_ctrl_inputs[device_name][node]
- del self._node_recipients[device_name][node]
- del self._node_ctrl_recipients[device_name][node]
-
- def _prune_non_control_edges_of_debug_ops(self, device_name):
- """Prune (non-control) edges related to debug ops.
-
- Prune the Copy ops and associated _Send ops inserted by the debugger out
- from the non-control inputs and output recipients map. Replace the inputs
- and recipients with original ones.
-
- Args:
- device_name: (`str`) device name.
- """
-
- copy_nodes = []
- for node in self._node_inputs[device_name]:
- if node in self._copy_send_nodes[device_name]:
- continue
-
- if is_copy_node(node):
- copy_nodes.append(node)
-
- inputs = self._node_inputs[device_name][node]
-
- for i in xrange(len(inputs)):
- inp = inputs[i]
- if is_copy_node(inp):
- # Find the input to the Copy node, which should be the original
- # input to the node.
- orig_inp = self._node_inputs[device_name][inp][0]
- inputs[i] = orig_inp
-
- self._prune_nodes_from_input_and_recipient_maps(device_name, copy_nodes)
- self._prune_nodes_from_input_and_recipient_maps(
- device_name, self._copy_send_nodes[device_name])
-
- def _prune_control_edges_of_debug_ops(self, device_name):
- """Prune control edges related to the debug ops."""
-
- for node in self._node_ctrl_inputs[device_name]:
- ctrl_inputs = self._node_ctrl_inputs[device_name][node]
- debug_op_inputs = []
- for ctrl_inp in ctrl_inputs:
- if is_debug_node(ctrl_inp):
- debug_op_inputs.append(ctrl_inp)
- for debug_op_inp in debug_op_inputs:
- ctrl_inputs.remove(debug_op_inp)
-
- def _populate_recipient_maps(self, device_name):
- """Populate the map from node name to recipient(s) of its output(s).
-
- This method also populates the input map based on reversed ref edges.
-
- Args:
- device_name: name of device.
- """
-
- for node in self._node_inputs[device_name]:
- inputs = self._node_inputs[device_name][node]
- for inp in inputs:
- inp = get_node_name(inp)
- if inp not in self._node_recipients[device_name]:
- self._node_recipients[device_name][inp] = []
- self._node_recipients[device_name][inp].append(node)
-
- if inp in self._ref_args[device_name]:
- if inp not in self._node_reversed_ref_inputs[device_name]:
- self._node_reversed_ref_inputs[device_name][inp] = []
- self._node_reversed_ref_inputs[device_name][inp].append(node)
-
- for node in self._node_ctrl_inputs[device_name]:
- ctrl_inputs = self._node_ctrl_inputs[device_name][node]
- for ctrl_inp in ctrl_inputs:
- if ctrl_inp in self._copy_send_nodes[device_name]:
- continue
-
- if ctrl_inp not in self._node_ctrl_recipients[device_name]:
- self._node_ctrl_recipients[device_name][ctrl_inp] = []
- self._node_ctrl_recipients[device_name][ctrl_inp].append(node)
+ self._node_devices[node_name] = debug_graph.node_devices[node_name]
def _validate_dump_with_graphs(self, device_name):
"""Validate the dumped tensor data against the partition graphs.
@@ -1197,31 +830,31 @@ class DebugDumpDir(object):
Or if the temporal order of the dump's timestamps violate the
input relations on the partition graphs.
"""
-
- if not self._partition_graphs[device_name]:
+ if not self._debug_graphs:
raise LookupError(
"No partition graphs loaded for device %s" % device_name)
+ debug_graph = self._debug_graphs[device_name]
# Verify that the node names in the dump data are all present in the
# partition graphs.
for datum in self._dump_tensor_data[device_name]:
- if datum.node_name not in self._node_inputs[device_name]:
+ if datum.node_name not in debug_graph.node_inputs:
raise ValueError("Node name '%s' is not found in partition graphs of "
"device %s." % (datum.node_name, device_name))
pending_inputs = {}
- for node in self._node_inputs[device_name]:
+ for node in debug_graph.node_inputs:
pending_inputs[node] = []
- inputs = self._node_inputs[device_name][node]
+ inputs = debug_graph.node_inputs[node]
for inp in inputs:
- inp_node = get_node_name(inp)
- inp_output_slot = get_output_slot(inp)
+ inp_node = debug_graphs.get_node_name(inp)
+ inp_output_slot = debug_graphs.get_output_slot(inp)
# Inputs from Enter and NextIteration nodes are not validated because
# DebugNodeInserter::InsertNodes() in the debugger core skips creating
# control edges from debug ops watching these types of nodes.
if (inp_node in self._debug_watches[device_name] and
inp_output_slot in self._debug_watches[device_name][inp_node] and
- self._node_op_types[device_name].get(inp) not in (
+ debug_graph.node_op_types.get(inp) not in (
"Enter", "NextIteration") and
(inp_node, inp_output_slot) not in pending_inputs[node]):
pending_inputs[node].append((inp_node, inp_output_slot))
@@ -1240,7 +873,7 @@ class DebugDumpDir(object):
"these input(s) are not satisfied: %s" %
(node, datum.timestamp, repr(pending_inputs[node])))
- recipients = self._node_recipients[device_name][node]
+ recipients = debug_graph.node_recipients[node]
for recipient in recipients:
recipient_pending_inputs = pending_inputs[recipient]
if (node, slot) in recipient_pending_inputs:
@@ -1285,7 +918,7 @@ class DebugDumpDir(object):
def loaded_partition_graphs(self):
"""Test whether partition graphs have been loaded."""
- return self._partition_graphs is not None
+ return bool(self._debug_graphs)
def partition_graphs(self):
"""Get the partition graphs.
@@ -1296,11 +929,10 @@ class DebugDumpDir(object):
Raises:
LookupError: If no partition graphs have been loaded.
"""
-
- if self._partition_graphs is None:
+ if not self._debug_graphs:
raise LookupError("No partition graphs have been loaded.")
-
- return self._partition_graphs.values()
+ return [self._debug_graphs[key].debug_graph_def
+ for key in self._debug_graphs]
@property
def run_fetches_info(self):
@@ -1380,17 +1012,17 @@ class DebugDumpDir(object):
LookupError: If no partition graphs have been loaded.
ValueError: If specified node name does not exist.
"""
- if self._partition_graphs is None:
+ if not self._debug_graphs:
raise LookupError("No partition graphs have been loaded.")
if device_name is None:
nodes = []
- for device_name in self._node_inputs:
- nodes.extend(self._node_inputs[device_name].keys())
+ for device_name in self._debug_graphs:
+ nodes.extend(self._debug_graphs[device_name].node_inputs.keys())
return nodes
else:
- if device_name not in self._node_inputs:
+ if device_name not in self._debug_graphs:
raise ValueError("Invalid device name: %s" % device_name)
- return self._node_inputs[device_name].keys()
+ return self._debug_graphs[device_name].node_inputs.keys()
def node_attributes(self, node_name, device_name=None):
"""Get the attributes of a node.
@@ -1406,11 +1038,11 @@ class DebugDumpDir(object):
Raises:
LookupError: If no partition graphs have been loaded.
"""
- if self._partition_graphs is None:
+ if not self._debug_graphs:
raise LookupError("No partition graphs have been loaded.")
device_name = self._infer_device_name(device_name, node_name)
- return self._node_attributes[device_name][node_name]
+ return self._debug_graphs[device_name].node_attributes[node_name]
def node_inputs(self, node_name, is_control=False, device_name=None):
"""Get the inputs of given node according to partition graphs.
@@ -1429,16 +1061,15 @@ class DebugDumpDir(object):
LookupError: If node inputs and control inputs have not been loaded
from partition graphs yet.
"""
-
- if self._partition_graphs is None:
+ if not self._debug_graphs:
raise LookupError(
"Node inputs are not loaded from partition graphs yet.")
device_name = self._infer_device_name(device_name, node_name)
if is_control:
- return self._node_ctrl_inputs[device_name][node_name]
+ return self._debug_graphs[device_name].node_ctrl_inputs[node_name]
else:
- return self._node_inputs[device_name][node_name]
+ return self._debug_graphs[device_name].node_inputs[node_name]
def transitive_inputs(self,
node_name,
@@ -1466,19 +1097,19 @@ class DebugDumpDir(object):
LookupError: If node inputs and control inputs have not been loaded
from partition graphs yet.
"""
-
- if self._partition_graphs is None:
+ if not self._debug_graphs:
raise LookupError(
"Node inputs are not loaded from partition graphs yet.")
device_name = self._infer_device_name(device_name, node_name)
- input_lists = [self._node_inputs[device_name]]
+ input_lists = [self._debug_graphs[device_name].node_inputs]
if include_control:
- input_lists.append(self._node_ctrl_inputs[device_name])
+ input_lists.append(self._debug_graphs[device_name].node_ctrl_inputs)
if include_reversed_ref:
- input_lists.append(self._node_reversed_ref_inputs[device_name])
- tracer = _DFSGraphTracer(
+ input_lists.append(
+ self._debug_graphs[device_name].node_reversed_ref_inputs)
+ tracer = debug_graphs.DFSGraphTracer(
input_lists,
skip_node_names=self._get_merge_node_names(device_name))
tracer.trace(node_name)
@@ -1492,9 +1123,10 @@ class DebugDumpDir(object):
if not hasattr(self, "_merge_node_names"):
self._merge_node_names = {}
if device_name not in self._merge_node_names:
+ debug_graph = self._debug_graphs[device_name]
self._merge_node_names[device_name] = [
- node for node in self._node_op_types[device_name]
- if self._node_op_types[device_name][node] == "Merge"]
+ node for node in debug_graph.node_op_types
+ if debug_graph.node_op_types[node] == "Merge"]
return self._merge_node_names[device_name]
def find_some_path(self,
@@ -1546,12 +1178,13 @@ class DebugDumpDir(object):
"%s vs. %s" % (src_node_name, dst_node_name, src_device_name,
dst_device_name))
- input_lists = [self._node_inputs[dst_device_name]]
+ input_lists = [self._debug_graphs[dst_device_name].node_inputs]
+ debug_graph = self._debug_graphs[dst_device_name]
if include_control:
- input_lists.append(self._node_ctrl_inputs[dst_device_name])
+ input_lists.append(debug_graph.node_ctrl_inputs)
if include_reversed_ref:
- input_lists.append(self._node_reversed_ref_inputs[dst_device_name])
- tracer = _DFSGraphTracer(
+ input_lists.append(debug_graph.node_reversed_ref_inputs)
+ tracer = debug_graphs.DFSGraphTracer(
input_lists,
skip_node_names=self._get_merge_node_names(dst_device_name),
destination_node_name=src_node_name)
@@ -1561,7 +1194,7 @@ class DebugDumpDir(object):
try:
tracer.trace(dst_node_name)
- except _GraphTracingReachedDestination:
+ except debug_graphs.GraphTracingReachedDestination:
# Prune nodes not on the path.
inputs = [dst_node_name] + tracer.inputs()
depth_list = [0] + tracer.depth_list()
@@ -1592,15 +1225,16 @@ class DebugDumpDir(object):
from partition graphs yet.
"""
- if self._partition_graphs is None:
+ if not self._debug_graphs:
raise LookupError(
"Node recipients are not loaded from partition graphs yet.")
device_name = self._infer_device_name(device_name, node_name)
+ debug_graph = self._debug_graphs[device_name]
if is_control:
- return self._node_ctrl_recipients[device_name][node_name]
+ return debug_graph.node_ctrl_recipients[node_name]
else:
- return self._node_recipients[device_name][node_name]
+ return debug_graph.node_recipients[node_name]
def devices(self):
"""Get the list of device names.
@@ -1608,7 +1242,6 @@ class DebugDumpDir(object):
Returns:
(`list` of `str`) names of the devices.
"""
-
return self._device_names
def node_exists(self, node_name, device_name=None):
@@ -1627,20 +1260,18 @@ class DebugDumpDir(object):
LookupError: If no partition graphs have been loaded yet.
ValueError: If device_name is specified but cannot be found.
"""
-
- if self._node_inputs is None:
+ if not self._debug_graphs:
raise LookupError(
"Nodes have not been loaded from partition graphs yet.")
- if (device_name is not None) and device_name not in self._node_inputs:
+ if (device_name is not None) and device_name not in self._debug_graphs:
raise ValueError(
"The specified device_name '%s' cannot be found." % device_name)
- node_inputs_all_devices = (self._node_inputs if device_name is None
- else (self._node_inputs[device_name],))
-
- return any(node_name in node_inputs_all_devices[dev_name]
- for dev_name in node_inputs_all_devices)
+ for _, debug_graph in self._debug_graphs.items():
+ if node_name in debug_graph.node_inputs:
+ return True
+ return False
def node_device(self, node_name):
"""Get the names of the devices that has nodes of the specified name.
@@ -1658,8 +1289,7 @@ class DebugDumpDir(object):
from partition graphs yet.
ValueError: If the node does not exist in partition graphs.
"""
-
- if self._partition_graphs is None:
+ if not self._debug_graphs:
raise LookupError(
"Node devices are not loaded from partition graphs yet.")
@@ -1685,13 +1315,12 @@ class DebugDumpDir(object):
LookupError: If node op types have not been loaded
from partition graphs yet.
"""
-
- if self._partition_graphs is None:
+ if not self._debug_graphs:
raise LookupError(
"Node op types are not loaded from partition graphs yet.")
device_name = self._infer_device_name(device_name, node_name)
- return self._node_op_types[device_name][node_name]
+ return self._debug_graphs[device_name].node_op_types[node_name]
def debug_watch_keys(self, node_name, device_name=None):
"""Get all tensor watch keys of given node according to partition graphs.
@@ -1957,7 +1586,7 @@ class DebugDumpDir(object):
if self._python_graph is None:
raise LookupError("Python graph is not available for traceback lookup")
- node_name = get_node_name(element_name)
+ node_name = debug_graphs.get_node_name(element_name)
if node_name not in self._node_traceback:
raise KeyError("Cannot find node \"%s\" in Python graph" % node_name)
diff --git a/tensorflow/python/debug/lib/debug_data_test.py b/tensorflow/python/debug/lib/debug_data_test.py
index 694010a23c..7ce7ef6a97 100644
--- a/tensorflow/python/debug/lib/debug_data_test.py
+++ b/tensorflow/python/debug/lib/debug_data_test.py
@@ -49,77 +49,6 @@ class DeviceNamePathConversionTest(test_util.TensorFlowTestCase):
",job_ps,replica_1,task_2,cpu_0"))
-class ParseNodeOrTensorNameTest(test_util.TensorFlowTestCase):
-
- def testParseNodeName(self):
- node_name, slot = debug_data.parse_node_or_tensor_name("namespace1/node_1")
-
- self.assertEqual("namespace1/node_1", node_name)
- self.assertIsNone(slot)
-
- def testParseTensorName(self):
- node_name, slot = debug_data.parse_node_or_tensor_name(
- "namespace1/node_2:3")
-
- self.assertEqual("namespace1/node_2", node_name)
- self.assertEqual(3, slot)
-
-
-class NodeNameChecksTest(test_util.TensorFlowTestCase):
-
- def testIsCopyNode(self):
- self.assertTrue(debug_data.is_copy_node("__copy_ns1/ns2/node3_0"))
-
- self.assertFalse(debug_data.is_copy_node("copy_ns1/ns2/node3_0"))
- self.assertFalse(debug_data.is_copy_node("_copy_ns1/ns2/node3_0"))
- self.assertFalse(debug_data.is_copy_node("_copyns1/ns2/node3_0"))
- self.assertFalse(debug_data.is_copy_node("__dbg_ns1/ns2/node3_0"))
-
- def testIsDebugNode(self):
- self.assertTrue(
- debug_data.is_debug_node("__dbg_ns1/ns2/node3:0_0_DebugIdentity"))
-
- self.assertFalse(
- debug_data.is_debug_node("dbg_ns1/ns2/node3:0_0_DebugIdentity"))
- self.assertFalse(
- debug_data.is_debug_node("_dbg_ns1/ns2/node3:0_0_DebugIdentity"))
- self.assertFalse(
- debug_data.is_debug_node("_dbgns1/ns2/node3:0_0_DebugIdentity"))
- self.assertFalse(debug_data.is_debug_node("__copy_ns1/ns2/node3_0"))
-
-
-class ParseDebugNodeNameTest(test_util.TensorFlowTestCase):
-
- def testParseDebugNodeName_valid(self):
- debug_node_name_1 = "__dbg_ns_a/ns_b/node_c:1_0_DebugIdentity"
- (watched_node, watched_output_slot, debug_op_index,
- debug_op) = debug_data.parse_debug_node_name(debug_node_name_1)
-
- self.assertEqual("ns_a/ns_b/node_c", watched_node)
- self.assertEqual(1, watched_output_slot)
- self.assertEqual(0, debug_op_index)
- self.assertEqual("DebugIdentity", debug_op)
-
- def testParseDebugNodeName_invalidPrefix(self):
- invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity"
-
- with self.assertRaisesRegexp(ValueError, "Invalid prefix"):
- debug_data.parse_debug_node_name(invalid_debug_node_name_1)
-
- def testParseDebugNodeName_missingDebugOpIndex(self):
- invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity"
-
- with self.assertRaisesRegexp(ValueError, "Invalid debug node name"):
- debug_data.parse_debug_node_name(invalid_debug_node_name_1)
-
- def testParseDebugNodeName_invalidWatchedTensorName(self):
- invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity"
-
- with self.assertRaisesRegexp(ValueError,
- "Invalid tensor name in debug node name"):
- debug_data.parse_debug_node_name(invalid_debug_node_name_1)
-
-
class HasNanOrInfTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -375,19 +304,5 @@ class DebugDumpDirTest(test_util.TensorFlowTestCase):
fake.assert_has_calls(expected_calls, any_order=True)
-class GetNodeNameAndOutputSlotTest(test_util.TensorFlowTestCase):
-
- def testParseTensorNameInputWorks(self):
- self.assertEqual("a", debug_data.get_node_name("a:0"))
- self.assertEqual(0, debug_data.get_output_slot("a:0"))
-
- self.assertEqual("_b", debug_data.get_node_name("_b:1"))
- self.assertEqual(1, debug_data.get_output_slot("_b:1"))
-
- def testParseNodeNameInputWorks(self):
- self.assertEqual("a", debug_data.get_node_name("a"))
- self.assertEqual(0, debug_data.get_output_slot("a"))
-
-
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/debug/lib/debug_gradients.py b/tensorflow/python/debug/lib/debug_gradients.py
index 5306391613..b01a58719c 100644
--- a/tensorflow/python/debug/lib/debug_gradients.py
+++ b/tensorflow/python/debug/lib/debug_gradients.py
@@ -24,6 +24,7 @@ import uuid
import six
from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import variables
@@ -34,7 +35,7 @@ _gradient_debuggers = {}
def _tensor_to_grad_debug_op_name(tensor, grad_debugger_uuid):
- op_name, slot = debug_data.parse_node_or_tensor_name(tensor.name)
+ op_name, slot = debug_graphs.parse_node_or_tensor_name(tensor.name)
return "%s_%d/%s%s" % (op_name, slot, _GRADIENT_DEBUG_TAG, grad_debugger_uuid)
@@ -407,7 +408,7 @@ def gradient_values_from_dump(grad_debugger, x_tensor, dump):
(grad_debugger.graph, dump.python_graph))
gradient_tensor = grad_debugger.gradient_tensor(x_tensor)
- node_name, output_slot = debug_data.parse_node_or_tensor_name(
+ node_name, output_slot = debug_graphs.parse_node_or_tensor_name(
gradient_tensor.name)
try:
diff --git a/tensorflow/python/debug/lib/debug_graphs.py b/tensorflow/python/debug/lib/debug_graphs.py
new file mode 100644
index 0000000000..20e2a6acfe
--- /dev/null
+++ b/tensorflow/python/debug/lib/debug_graphs.py
@@ -0,0 +1,430 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes and methods for processing debugger-decorated graphs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.framework import op_def_registry
+
+
+def parse_node_or_tensor_name(name):
+ """Get the node name from a string that can be node or tensor name.
+
+ Args:
+ name: An input node name (e.g., "node_a") or tensor name (e.g.,
+ "node_a:0"), as a str.
+
+ Returns:
+ 1) The node name, as a str. If the input name is a tensor name, i.e.,
+ contains a colon, the final colon and the following output slot
+ will be stripped.
+ 2) If the input name is a tensor name, the output slot, as an int. If
+ the input name is not a tensor name, None.
+ """
+
+ if ":" in name and not name.endswith(":"):
+ node_name = name[:name.rfind(":")]
+ output_slot = int(name[name.rfind(":") + 1:])
+
+ return node_name, output_slot
+ else:
+ return name, None
+
+
+def get_node_name(element_name):
+ node_name, _ = parse_node_or_tensor_name(element_name)
+ return node_name
+
+
+def get_output_slot(element_name):
+ """Get the output slot number from the name of a graph element.
+
+ If element_name is a node name without an output slot at the end, 0 will be
+ assumed.
+
+ Args:
+ element_name: (`str`) name of the graph element in question.
+
+ Returns:
+ (`int`) output slot number.
+ """
+ _, output_slot = parse_node_or_tensor_name(element_name)
+ return output_slot if output_slot is not None else 0
+
+
+def is_copy_node(node_name):
+ """Determine whether a node name is that of a debug Copy node.
+
+ Such nodes are inserted by TensorFlow core upon request in
+ RunOptions.debug_options.debug_tensor_watch_opts.
+
+ Args:
+ node_name: Name of the node.
+
+ Returns:
+ A bool indicating whether the input argument is the name of a debug Copy
+ node.
+ """
+ return node_name.startswith("__copy_")
+
+
+def is_debug_node(node_name):
+ """Determine whether a node name is that of a debug node.
+
+ Such nodes are inserted by TensorFlow core upon request in
+ RunOptions.debug_options.debug_tensor_watch_opts.
+
+ Args:
+ node_name: Name of the node.
+
+ Returns:
+ A bool indicating whether the input argument is the name of a debug node.
+ """
+ return node_name.startswith("__dbg_")
+
+
+def parse_debug_node_name(node_name):
+ """Parse the name of a debug node.
+
+ Args:
+ node_name: Name of the debug node.
+
+ Returns:
+ 1. Name of the watched node, as a str.
+ 2. Output slot index of the watched tensor, as an int.
+ 3. Index of the debug node, as an int.
+ 4. Name of the debug op, as a str, e.g., "DebugIdentity".
+
+ Raises:
+ ValueError: If the input node name is not a valid debug node name.
+ """
+ prefix = "__dbg_"
+
+ name = node_name
+ if not name.startswith(prefix):
+ raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)
+
+ name = name[len(prefix):]
+
+ if name.count("_") < 2:
+ raise ValueError("Invalid debug node name: '%s'" % node_name)
+
+ debug_op = name[name.rindex("_") + 1:]
+ name = name[:name.rindex("_")]
+
+ debug_op_index = int(name[name.rindex("_") + 1:])
+ name = name[:name.rindex("_")]
+
+ if name.count(":") != 1:
+ raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)
+
+ watched_node_name = name[:name.index(":")]
+ watched_output_slot = int(name[name.index(":") + 1:])
+
+ return watched_node_name, watched_output_slot, debug_op_index, debug_op
+
+
+class GraphTracingReachedDestination(Exception):
+ pass
+
+
+class DFSGraphTracer(object):
+ """Graph input tracer using depth-first search."""
+
+ def __init__(self,
+ input_lists,
+ skip_node_names=None,
+ destination_node_name=None):
+ """Constructor of _DFSGraphTracer.
+
+ Args:
+ input_lists: A list of dicts. Each dict is an adjacency (input) map from
+ the recipient node name as the key and the list of input node names
+ as the value.
+ skip_node_names: Optional: a list of node names to skip tracing.
+ destination_node_name: Optional: destination node name. If not `None`, it
+ should be the name of a destination node, as a str, and the graph
+ tracing will raise GraphTracingReachedDestination as soon as that
+ node has been reached.
+
+ Raises:
+ GraphTracingReachedDestination: if destination_node_name is not None
+ and the specified node is reached.
+ """
+
+ self._input_lists = input_lists
+ self._skip_node_names = skip_node_names
+
+ self._inputs = []
+ self._visited_nodes = []
+ self._depth_count = 0
+ self._depth_list = []
+
+ self._destination_node_name = destination_node_name
+
+ def trace(self, graph_element_name):
+ """Trace inputs.
+
+ Args:
+ graph_element_name: Name of the node or an output tensor of the node, as a
+ str.
+
+ Raises:
+ GraphTracingReachedDestination: if destination_node_name of this tracer
+ object is not None and the specified node is reached.
+ """
+ self._depth_count += 1
+
+ node_name = get_node_name(graph_element_name)
+ if node_name == self._destination_node_name:
+ raise GraphTracingReachedDestination()
+
+ if node_name in self._skip_node_names:
+ return
+ if node_name in self._visited_nodes:
+ return
+
+ self._visited_nodes.append(node_name)
+
+ for input_list in self._input_lists:
+ for inp in input_list[node_name]:
+ if get_node_name(inp) in self._visited_nodes:
+ continue
+ self._inputs.append(inp)
+ self._depth_list.append(self._depth_count)
+ self.trace(inp)
+
+ self._depth_count -= 1
+
+ def inputs(self):
+ return self._inputs
+
+ def depth_list(self):
+ return self._depth_list
+
+
+class DebugGraph(object):
+ """Represents a debugger-decorated graph."""
+
+ def __init__(self, debug_graph_def, device_name=None):
+ self._debug_graph_def = debug_graph_def
+
+ self._node_attributes = {}
+ self._node_inputs = {}
+ self._node_reversed_ref_inputs = {}
+ self._node_ctrl_inputs = {}
+ self._node_recipients = {}
+ self._node_ctrl_recipients = {}
+ self._node_devices = {}
+ self._node_op_types = {}
+ self._copy_send_nodes = []
+ self._ref_args = {}
+
+ self._device_name = device_name
+ if not self._device_name and debug_graph_def.node:
+ self._device_name = debug_graph_def.node[0].device
+
+ for node in debug_graph_def.node:
+ self._process_debug_graph_node(node)
+
+ self._prune_non_control_edges_of_debug_ops()
+ self._prune_control_edges_of_debug_ops()
+
+ self._populate_recipient_maps()
+
+ def _process_debug_graph_node(self, node):
+ """Process a node from the debug GraphDef.
+
+ Args:
+ node: (NodeDef) A partition-graph node to be processed.
+
+ Raises:
+ ValueError: If duplicate node names are encountered.
+ """
+
+ if is_debug_node(node.name):
+ # This is a debug node. Parse the node name and retrieve the
+ # information about debug watches on tensors. But do not include
+ # the node in the graph.
+ return
+
+ if node.name in self._node_inputs:
+ raise ValueError("Duplicate node name on device %s: '%s'" %
+ (self._device_name, node.name))
+
+ self._node_attributes[node.name] = node.attr
+
+ self._node_inputs[node.name] = []
+ self._node_ctrl_inputs[node.name] = []
+ self._node_recipients[node.name] = []
+ self._node_ctrl_recipients[node.name] = []
+
+ if node.name not in self._node_devices:
+ self._node_devices[node.name] = set()
+ self._node_devices[node.name].add(node.device)
+ self._node_op_types[node.name] = node.op
+ self._ref_args[node.name] = self._get_ref_args(node)
+
+ for inp in node.input:
+ if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
+ self._copy_send_nodes.append(node.name)
+
+ if inp.startswith("^"):
+ cinp = inp[1:]
+ self._node_ctrl_inputs[node.name].append(cinp)
+ else:
+ self._node_inputs[node.name].append(inp)
+
+ def _get_ref_args(self, node):
+ """Determine whether an input of an op is ref-type.
+
+ Args:
+ node: A `NodeDef`.
+
+ Returns:
+ A list of the arg names (as strs) that are ref-type.
+ """
+ op_def = op_def_registry.get_registered_ops().get(node.op)
+ ref_args = []
+ if op_def:
+ for i, output_arg in enumerate(op_def.output_arg):
+ if output_arg.is_ref:
+ arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
+ ref_args.append(arg_name)
+ return ref_args
+
+ def _prune_non_control_edges_of_debug_ops(self):
+ """Prune (non-control) edges related to debug ops.
+
+ Prune the Copy ops and associated _Send ops inserted by the debugger out
+ from the non-control inputs and output recipients map. Replace the inputs
+ and recipients with original ones.
+ """
+ copy_nodes = []
+ for node in self._node_inputs:
+ if node in self._copy_send_nodes:
+ continue
+
+ if is_copy_node(node):
+ copy_nodes.append(node)
+
+ inputs = self._node_inputs[node]
+
+ for i in xrange(len(inputs)):
+ inp = inputs[i]
+ if is_copy_node(inp):
+ # Find the input to the Copy node, which should be the original
+ # input to the node.
+ orig_inp = self._node_inputs[inp][0]
+ inputs[i] = orig_inp
+
+ self._prune_nodes_from_input_and_recipient_maps(copy_nodes)
+ self._prune_nodes_from_input_and_recipient_maps(self._copy_send_nodes)
+
+ def _prune_control_edges_of_debug_ops(self):
+ """Prune control edges related to the debug ops."""
+ for node in self._node_ctrl_inputs:
+ ctrl_inputs = self._node_ctrl_inputs[node]
+ debug_op_inputs = []
+ for ctrl_inp in ctrl_inputs:
+ if is_debug_node(ctrl_inp):
+ debug_op_inputs.append(ctrl_inp)
+ for debug_op_inp in debug_op_inputs:
+ ctrl_inputs.remove(debug_op_inp)
+
+ def _populate_recipient_maps(self):
+ """Populate the map from node name to recipient(s) of its output(s).
+
+ This method also populates the input map based on reversed ref edges.
+ """
+ for node in self._node_inputs:
+ inputs = self._node_inputs[node]
+ for inp in inputs:
+ inp = get_node_name(inp)
+ if inp not in self._node_recipients:
+ self._node_recipients[inp] = []
+ self._node_recipients[inp].append(node)
+
+ if inp in self._ref_args:
+ if inp not in self._node_reversed_ref_inputs:
+ self._node_reversed_ref_inputs[inp] = []
+ self._node_reversed_ref_inputs[inp].append(node)
+
+ for node in self._node_ctrl_inputs:
+ ctrl_inputs = self._node_ctrl_inputs[node]
+ for ctrl_inp in ctrl_inputs:
+ if ctrl_inp in self._copy_send_nodes:
+ continue
+
+ if ctrl_inp not in self._node_ctrl_recipients:
+ self._node_ctrl_recipients[ctrl_inp] = []
+ self._node_ctrl_recipients[ctrl_inp].append(node)
+
+ def _prune_nodes_from_input_and_recipient_maps(self, nodes_to_prune):
+ """Prune nodes out of input and recipient maps.
+
+ Args:
+ nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.
+ """
+ for node in nodes_to_prune:
+ del self._node_inputs[node]
+ del self._node_ctrl_inputs[node]
+ del self._node_recipients[node]
+ del self._node_ctrl_recipients[node]
+
+ @property
+ def device_name(self):
+ return self._device_name
+
+ @property
+ def debug_graph_def(self):
+ """The debugger-decorated GraphDef."""
+ return self._debug_graph_def
+
+ @property
+ def node_devices(self):
+ return self._node_devices
+
+ @property
+ def node_op_types(self):
+ return self._node_op_types
+
+ @property
+ def node_attributes(self):
+ return self._node_attributes
+
+ @property
+ def node_inputs(self):
+ return self._node_inputs
+
+ @property
+ def node_ctrl_inputs(self):
+ return self._node_ctrl_inputs
+
+ @property
+ def node_reversed_ref_inputs(self):
+ return self._node_reversed_ref_inputs
+
+ @property
+ def node_recipients(self):
+ return self._node_recipients
+
+ @property
+ def node_ctrl_recipients(self):
+ return self._node_ctrl_recipients
diff --git a/tensorflow/python/debug/lib/debug_graphs_test.py b/tensorflow/python/debug/lib/debug_graphs_test.py
new file mode 100644
index 0000000000..34257794f1
--- /dev/null
+++ b/tensorflow/python/debug/lib/debug_graphs_test.py
@@ -0,0 +1,112 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tfdbg module debug_data."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.debug.lib import debug_graphs
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class ParseNodeOrTensorNameTest(test_util.TensorFlowTestCase):
+
+ def testParseNodeName(self):
+ node_name, slot = debug_graphs.parse_node_or_tensor_name(
+ "namespace1/node_1")
+
+ self.assertEqual("namespace1/node_1", node_name)
+ self.assertIsNone(slot)
+
+ def testParseTensorName(self):
+ node_name, slot = debug_graphs.parse_node_or_tensor_name(
+ "namespace1/node_2:3")
+
+ self.assertEqual("namespace1/node_2", node_name)
+ self.assertEqual(3, slot)
+
+
+class GetNodeNameAndOutputSlotTest(test_util.TensorFlowTestCase):
+
+ def testParseTensorNameInputWorks(self):
+ self.assertEqual("a", debug_graphs.get_node_name("a:0"))
+ self.assertEqual(0, debug_graphs.get_output_slot("a:0"))
+
+ self.assertEqual("_b", debug_graphs.get_node_name("_b:1"))
+ self.assertEqual(1, debug_graphs.get_output_slot("_b:1"))
+
+ def testParseNodeNameInputWorks(self):
+ self.assertEqual("a", debug_graphs.get_node_name("a"))
+ self.assertEqual(0, debug_graphs.get_output_slot("a"))
+
+
+class NodeNameChecksTest(test_util.TensorFlowTestCase):
+
+ def testIsCopyNode(self):
+ self.assertTrue(debug_graphs.is_copy_node("__copy_ns1/ns2/node3_0"))
+
+ self.assertFalse(debug_graphs.is_copy_node("copy_ns1/ns2/node3_0"))
+ self.assertFalse(debug_graphs.is_copy_node("_copy_ns1/ns2/node3_0"))
+ self.assertFalse(debug_graphs.is_copy_node("_copyns1/ns2/node3_0"))
+ self.assertFalse(debug_graphs.is_copy_node("__dbg_ns1/ns2/node3_0"))
+
+ def testIsDebugNode(self):
+ self.assertTrue(
+ debug_graphs.is_debug_node("__dbg_ns1/ns2/node3:0_0_DebugIdentity"))
+
+ self.assertFalse(
+ debug_graphs.is_debug_node("dbg_ns1/ns2/node3:0_0_DebugIdentity"))
+ self.assertFalse(
+ debug_graphs.is_debug_node("_dbg_ns1/ns2/node3:0_0_DebugIdentity"))
+ self.assertFalse(
+ debug_graphs.is_debug_node("_dbgns1/ns2/node3:0_0_DebugIdentity"))
+ self.assertFalse(debug_graphs.is_debug_node("__copy_ns1/ns2/node3_0"))
+
+
+class ParseDebugNodeNameTest(test_util.TensorFlowTestCase):
+
+ def testParseDebugNodeName_valid(self):
+ debug_node_name_1 = "__dbg_ns_a/ns_b/node_c:1_0_DebugIdentity"
+ (watched_node, watched_output_slot, debug_op_index,
+ debug_op) = debug_graphs.parse_debug_node_name(debug_node_name_1)
+
+ self.assertEqual("ns_a/ns_b/node_c", watched_node)
+ self.assertEqual(1, watched_output_slot)
+ self.assertEqual(0, debug_op_index)
+ self.assertEqual("DebugIdentity", debug_op)
+
+ def testParseDebugNodeName_invalidPrefix(self):
+ invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity"
+
+ with self.assertRaisesRegexp(ValueError, "Invalid prefix"):
+ debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
+
+ def testParseDebugNodeName_missingDebugOpIndex(self):
+ invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity"
+
+ with self.assertRaisesRegexp(ValueError, "Invalid debug node name"):
+ debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
+
+ def testParseDebugNodeName_invalidWatchedTensorName(self):
+ invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity"
+
+ with self.assertRaisesRegexp(ValueError,
+ "Invalid tensor name in debug node name"):
+ debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/debug/lib/grpc_debug_server.py b/tensorflow/python/debug/lib/grpc_debug_server.py
index 309fdb3bce..28cf1514f6 100644
--- a/tensorflow/python/debug/lib/grpc_debug_server.py
+++ b/tensorflow/python/debug/lib/grpc_debug_server.py
@@ -29,7 +29,7 @@ from six.moves import queue
from tensorflow.core.debug import debug_service_pb2
from tensorflow.core.framework import graph_pb2
-from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.debug.lib import debug_service_pb2_grpc
from tensorflow.python.platform import tf_logging as logging
@@ -294,10 +294,10 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
def _process_graph_def(self, graph_def):
for node_def in graph_def.node:
- if (debug_data.is_debug_node(node_def.name) and
+ if (debug_graphs.is_debug_node(node_def.name) and
node_def.attr["gated_grpc"].b):
node_name, output_slot, _, debug_op = (
- debug_data.parse_debug_node_name(node_def.name))
+ debug_graphs.parse_debug_node_name(node_def.name))
self._gated_grpc_debug_watches.add(
DebugWatch(node_name, output_slot, debug_op))
diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py
index 48f31771db..aa5314dda5 100644
--- a/tensorflow/python/debug/lib/session_debug_file_test.py
+++ b/tensorflow/python/debug/lib/session_debug_file_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
-class SessionDebugTest(session_debug_testlib.SessionDebugTestBase):
+class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
def _no_rewrite_session_config(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py
index 08b3e75e7c..d4b9d06b54 100644
--- a/tensorflow/python/debug/lib/session_debug_testlib.py
+++ b/tensorflow/python/debug/lib/session_debug_testlib.py
@@ -33,6 +33,7 @@ from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -242,7 +243,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
v_copy_node_def = None
for partition_graph in run_metadata.partition_graphs:
for node_def in partition_graph.node:
- if debug_data.is_copy_node(node_def.name):
+ if debug_graphs.is_copy_node(node_def.name):
if node_def.name == "__copy_u_0":
u_copy_node_def = node_def
elif node_def.name == "__copy_v_0":
diff --git a/tensorflow/python/debug/lib/stepper.py b/tensorflow/python/debug/lib/stepper.py
index c814520b7e..1fa0b3dba2 100644
--- a/tensorflow/python/debug/lib/stepper.py
+++ b/tensorflow/python/debug/lib/stepper.py
@@ -27,6 +27,7 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import session_ops
@@ -706,8 +707,8 @@ class NodeStepper(object):
if ":" in element_name:
debug_utils.add_debug_tensor_watch(
run_options,
- debug_data.get_node_name(element_name),
- output_slot=debug_data.get_output_slot(element_name),
+ debug_graphs.get_node_name(element_name),
+ output_slot=debug_graphs.get_output_slot(element_name),
debug_urls=["file://" + dump_path])
return dump_path, run_options
@@ -961,5 +962,5 @@ class NodeStepper(object):
The node associated with element in the graph.
"""
- node_name, _ = debug_data.parse_node_or_tensor_name(element.name)
+ node_name, _ = debug_graphs.parse_node_or_tensor_name(element.name)
return self._sess.graph.as_graph_element(node_name)
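Usage note (illustrative, not part of the commit): the formerly private _DFSGraphTracer is now public as debug_graphs.DFSGraphTracer. A minimal sketch driving it against a hand-built three-node adjacency map (the graph is hypothetical); skip_node_names is passed explicitly because trace() tests membership against it directly:

    from tensorflow.python.debug.lib import debug_graphs

    # Adjacency (input) map: recipient node name -> list of input element names.
    input_map = {"c": ["a:0", "b:0"], "a": [], "b": []}

    tracer = debug_graphs.DFSGraphTracer([input_map], skip_node_names=[])
    tracer.trace("c")
    print(tracer.inputs())      # ["a:0", "b:0"]
    print(tracer.depth_list())  # [1, 1]

With destination_node_name set, the tracer raises debug_graphs.GraphTracingReachedDestination upon reaching that node, which DebugDumpDir.find_some_path() catches, as shown in the debug_data.py hunks above.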