diff options
author | 2017-09-06 08:46:57 -0700 | |
---|---|---|
committer | 2017-09-06 08:51:00 -0700 | |
commit | fa0a40a51f6f9cce48342d40c19e184470148242 (patch) | |
tree | 469b86a68f6e054d754637bdae1a05b4db7c86a2 /tensorflow/python/debug | |
parent | 0365fc5606ec36c0a559e50d1d65890de7a76ddd (diff) |
tfdbg: Refactor graph-processing code out of debug_data.py
The basic idea is to separate the code in debug_data.py that handles graph structures into its own module (debug_graphs.py). This tackles an existing TODO item to simplify the code in debug_data.DebugDumpDir.
In a later CL, code will be added to debug_graphs.DebugGraph to allow reconstruction of the original GraphDef, i.e., the GraphDef without the Copy* and Debug* nodes inserted by tfdbg. This will be useful for, among other things, the TensorBoard Debugger Plugin.
PiperOrigin-RevId: 167726113
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r-- | tensorflow/python/debug/BUILD | 34 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/analyzer_cli.py | 16 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/debug_data.py | 543 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/debug_data_test.py | 85 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/debug_gradients.py | 5 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/debug_graphs.py | 430 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/debug_graphs_test.py | 112 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/grpc_debug_server.py | 6 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/session_debug_file_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/session_debug_testlib.py | 3 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/stepper.py | 7 |
11 files changed, 681 insertions, 562 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 8eb2212069..c092616999 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -50,10 +50,25 @@ py_library( ) py_library( + name = "debug_graphs", + srcs = ["lib/debug_graphs.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework", + "//tensorflow/python:op_def_registry", + "//tensorflow/python:platform", + "//tensorflow/python:tensor_util", + "@six_archive//:six", + ], +) + +py_library( name = "debug_data", srcs = ["lib/debug_data.py"], srcs_version = "PY2AND3", deps = [ + ":debug_graphs", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework", "//tensorflow/python:op_def_registry", @@ -70,6 +85,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":debug_data", + ":debug_graphs", "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform", @@ -99,6 +115,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":debug_data", + ":debug_graphs", ":debug_utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", @@ -181,7 +198,7 @@ py_library( deps = [ ":cli_shared", ":command_parser", - ":debug_data", + ":debug_graphs", ":debugger_cli_common", ":evaluator", ":source_utils", @@ -401,6 +418,18 @@ py_binary( ) py_test( + name = "debug_graphs_test", + size = "small", + srcs = ["lib/debug_graphs_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":debug_graphs", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +py_test( name = "debug_data_test", size = "small", srcs = ["lib/debug_data_test.py"], @@ -569,6 +598,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":debug_data", + ":debug_graphs", ":debug_utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -608,7 +638,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ 
- ":debug_data", + ":debug_graphs", ":debug_service_pb2_grpc", "//tensorflow/core/debug:debug_service_proto_py", "@six_archive//:six", diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py index 22e451e38c..50850bbc0d 100644 --- a/tensorflow/python/debug/cli/analyzer_cli.py +++ b/tensorflow/python/debug/cli/analyzer_cli.py @@ -34,7 +34,7 @@ from tensorflow.python.debug.cli import command_parser from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import evaluator from tensorflow.python.debug.cli import ui_factory -from tensorflow.python.debug.lib import debug_data +from tensorflow.python.debug.lib import debug_graphs from tensorflow.python.debug.lib import source_utils RL = debugger_cli_common.RichLine @@ -716,7 +716,7 @@ class DebugAnalyzer(object): # Get a node name, regardless of whether the input is a node name (without # output slot attached) or a tensor name (with output slot attached). - node_name, unused_slot = debug_data.parse_node_or_tensor_name( + node_name, unused_slot = debug_graphs.parse_node_or_tensor_name( parsed.node_name) if not self._debug_dump.node_exists(node_name): @@ -840,7 +840,7 @@ class DebugAnalyzer(object): parsed.op_type, do_outputs=False) - node_name = debug_data.get_node_name(parsed.node_name) + node_name = debug_graphs.get_node_name(parsed.node_name) _add_main_menu(output, node_name=node_name, enable_list_inputs=False) return output @@ -871,7 +871,7 @@ class DebugAnalyzer(object): tensor_name, tensor_slicing = ( command_parser.parse_tensor_name_with_slicing(parsed.tensor_name)) - node_name, output_slot = debug_data.parse_node_or_tensor_name(tensor_name) + node_name, output_slot = debug_graphs.parse_node_or_tensor_name(tensor_name) if (self._debug_dump.loaded_partition_graphs() and not self._debug_dump.node_exists(node_name)): output = cli_shared.error( @@ -1016,7 +1016,7 @@ class DebugAnalyzer(object): parsed.op_type, do_outputs=True) - node_name 
= debug_data.get_node_name(parsed.node_name) + node_name = debug_graphs.get_node_name(parsed.node_name) _add_main_menu(output, node_name=node_name, enable_list_outputs=False) return output @@ -1087,7 +1087,7 @@ class DebugAnalyzer(object): label = RL(" " * 4) if self._debug_dump.debug_watch_keys( - debug_data.get_node_name(element)): + debug_graphs.get_node_name(element)): attribute = debugger_cli_common.MenuItem("", "pt %s" % element) else: attribute = cli_shared.COLOR_BLUE @@ -1246,7 +1246,7 @@ class DebugAnalyzer(object): font_attr_segs = {} # Check if this is a tensor name, instead of a node name. - node_name, _ = debug_data.parse_node_or_tensor_name(node_name) + node_name, _ = debug_graphs.parse_node_or_tensor_name(node_name) # Check if node exists. if not self._debug_dump.node_exists(node_name): @@ -1395,7 +1395,7 @@ class DebugAnalyzer(object): # Recursive call. # The input's/output's name can be a tensor name, in the case of node # with >1 output slots. - inp_node_name, _ = debug_data.parse_node_or_tensor_name(inp) + inp_node_name, _ = debug_graphs.parse_node_or_tensor_name(inp) self._dfs_from_node( lines, attr_segs, diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index b2b3ec5d47..9ea279c004 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -26,14 +26,14 @@ import platform import numpy as np import six -from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.util import event_pb2 -from tensorflow.python.framework import op_def_registry +from tensorflow.python.debug.lib import debug_graphs from tensorflow.python.framework import tensor_util from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat @@ -155,30 +155,6 @@ def 
_load_log_message_from_event_file(event_file_path): return event.log_message.message -def parse_node_or_tensor_name(name): - """Get the node name from a string that can be node or tensor name. - - Args: - name: An input node name (e.g., "node_a") or tensor name (e.g., - "node_a:0"), as a str. - - Returns: - 1) The node name, as a str. If the input name is a tensor name, i.e., - consists of a colon, the final colon and the following output slot - will be stripped. - 2) If the input name is a tensor name, the output slot, as an int. If - the input name is not a tensor name, None. - """ - - if ":" in name and not name.endswith(":"): - node_name = name[:name.rfind(":")] - output_slot = int(name[name.rfind(":") + 1:]) - - return node_name, output_slot - else: - return name, None - - def _is_graph_file(file_name): return file_name.startswith(METADATA_FILE_PREFIX + GRAPH_FILE_TAG) @@ -191,25 +167,6 @@ def _is_run_feed_keys_info_file(file_name): return file_name == METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG -def get_node_name(element_name): - return element_name.split(":")[0] if ":" in element_name else element_name - - -def get_output_slot(element_name): - """Get the output slot number from the name of a graph element. - - If element_name is a node name without output slot at the end, 0 will be - assumed. - - Args: - element_name: (`str`) name of the graph element in question. - - Returns: - (`int`) output slot number. - """ - return int(element_name.split(":")[-1]) if ":" in element_name else 0 - - def _get_tensor_name(node_name, output_slot): """Get tensor name given node name and output slot index. @@ -241,78 +198,6 @@ def _get_tensor_watch_key(node_name, output_slot, debug_op): return "%s:%s" % (_get_tensor_name(node_name, output_slot), debug_op) -def is_copy_node(node_name): - """Determine whether a node name is that of a debug Copy node. - - Such nodes are inserted by TensorFlow core upon request in - RunOptions.debug_options.debug_tensor_watch_opts. 
- - Args: - node_name: Name of the node. - - Returns: - A bool indicating whether the input argument is the name of a debug Copy - node. - """ - return node_name.startswith("__copy_") - - -def is_debug_node(node_name): - """Determine whether a node name is that of a debug node. - - Such nodes are inserted by TensorFlow core upon request in - RunOptions.debug_options.debug_tensor_watch_opts. - - Args: - node_name: Name of the node. - - Returns: - A bool indicating whether the input argument is the name of a debug node. - """ - return node_name.startswith("__dbg_") - - -def parse_debug_node_name(node_name): - """Parse the name of a debug node. - - Args: - node_name: Name of the debug node. - - Returns: - 1. Name of the watched node, as a str. - 2. Output slot index of the watched tensor, as an int. - 3. Index of the debug node, as an int. - 4. Name of the debug op, as a str, e.g, "DebugIdentity". - - Raises: - ValueError: If the input node name is not a valid debug node name. - """ - prefix = "__dbg_" - - name = node_name - if not name.startswith(prefix): - raise ValueError("Invalid prefix in debug node name: '%s'" % node_name) - - name = name[len(prefix):] - - if name.count("_") < 2: - raise ValueError("Invalid debug node name: '%s'" % node_name) - - debug_op = name[name.rindex("_") + 1:] - name = name[:name.rindex("_")] - - debug_op_index = int(name[name.rindex("_") + 1:]) - name = name[:name.rindex("_")] - - if name.count(":") != 1: - raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name) - - watched_node_name = name[:name.index(":")] - watched_output_slot = int(name[name.index(":") + 1:]) - - return watched_node_name, watched_output_slot, debug_op_index, debug_op - - def has_inf_or_nan(datum, tensor): """A predicate for whether a tensor consists of any bad numerical values. 
@@ -573,88 +458,6 @@ class WatchKeyDoesNotExistInDebugDumpDirError(ValueError): pass -class _GraphTracingReachedDestination(Exception): - pass - - -class _DFSGraphTracer(object): - """Graph input tracer using depth-first search.""" - - def __init__(self, - input_lists, - skip_node_names=None, - destination_node_name=None): - """Constructor of _DFSGraphTracer. - - Args: - input_lists: A list of dicts. Each dict is an adjacency (input) map from - the recipient node name as the key and the list of input node names - as the value. - skip_node_names: Optional: a list of node names to skip tracing. - destination_node_name: Optional: destination node name. If not `None`, it - should be the name of a destination not as a str and the graph tracing - will raise GraphTracingReachedDestination as soon as the node has been - reached. - - Raises: - _GraphTracingReachedDestination: if stop_at_node_name is not None and - the specified node is reached. - """ - - self._input_lists = input_lists - self._skip_node_names = skip_node_names - - self._inputs = [] - self._visited_nodes = [] - self._depth_count = 0 - self._depth_list = [] - - self._destination_node_name = destination_node_name - - def trace(self, graph_element_name): - """Trace inputs. - - Args: - graph_element_name: Name of the node or an output tensor of the node, as a - str. - - Raises: - _GraphTracingReachedDestination: if destination_node_name of this tracer - object is not None and the specified node is reached. 
- """ - self._depth_count += 1 - - node_name = get_node_name(graph_element_name) - - if node_name == self._destination_node_name: - raise _GraphTracingReachedDestination() - - if node_name in self._skip_node_names: - return - if node_name in self._visited_nodes: - return - - self._visited_nodes.append(node_name) - - for input_list in self._input_lists: - for inp in input_list[node_name]: - if get_node_name(inp) in self._visited_nodes: - continue - self._inputs.append(inp) - self._depth_list.append(self._depth_count) - self.trace(inp) - - self._depth_count -= 1 - - def inputs(self): - return self._inputs - - def depth_list(self): - return self._depth_list - - -# TODO(cais): This class is getting too large in line count. Refactor to make it -# smaller and easier to maintain. class DebugDumpDir(object): """Data set from a debug-dump directory on filesystem. @@ -963,52 +766,36 @@ class DebugDumpDir(object): ValueError: If the partition GraphDef of one or more devices fail to be loaded. """ - - self._node_attributes = {} - self._node_inputs = {} - self._node_reversed_ref_inputs = {} - self._node_ctrl_inputs = {} - self._node_recipients = {} - self._node_ctrl_recipients = {} + self._debug_graphs = {} self._node_devices = {} - self._node_op_types = {} - self._copy_send_nodes = {} - self._ref_args = {} - - self._partition_graphs = {} - for device_name in self._device_names: - partition_graph = None - if device_name in self._dump_graph_file_paths: - partition_graph = _load_graph_def_from_event_file( - self._dump_graph_file_paths[device_name]) - else: - partition_graph = self._find_partition_graph(partition_graphs, - device_name) - - if partition_graph: - self._partition_graphs[device_name] = partition_graph - self._node_attributes[device_name] = {} - self._node_inputs[device_name] = {} - self._node_reversed_ref_inputs[device_name] = {} - self._node_ctrl_inputs[device_name] = {} - self._node_recipients[device_name] = {} - self._node_ctrl_recipients[device_name] = {} - 
self._node_op_types[device_name] = {} - self._copy_send_nodes[device_name] = [] - self._ref_args[device_name] = [] - - if partition_graph: - for node in partition_graph.node: - self._process_partition_graph_node(device_name, node) - - self._prune_non_control_edges_of_debug_ops(device_name) - self._prune_control_edges_of_debug_ops(device_name) + if partition_graphs: + partition_graphs_and_device_names = [ + (partition_graph, None) for partition_graph in partition_graphs] + else: + partition_graphs_and_device_names = [] + for device_name in self._device_names: + partition_graph = None + if device_name in self._dump_graph_file_paths: + partition_graph = _load_graph_def_from_event_file( + self._dump_graph_file_paths[device_name]) + else: + partition_graph = self._find_partition_graph(partition_graphs, + device_name) + if partition_graph: + partition_graphs_and_device_names.append((partition_graph, + device_name)) + else: + logging.warn("Failed to load partition graphs from disk.") - self._populate_recipient_maps(device_name) + for partition_graph, maybe_device_name in partition_graphs_and_device_names: + debug_graph = debug_graphs.DebugGraph(partition_graph, + device_name=maybe_device_name) + self._debug_graphs[debug_graph.device_name] = debug_graph + self._collect_node_devices(debug_graph) - if device_name in self._partition_graphs and validate: - self._validate_dump_with_graphs(device_name) + if validate and debug_graph.device_name in self._dump_tensor_data: + self._validate_dump_with_graphs(debug_graph.device_name) def _find_partition_graph(self, partition_graphs, device_name): if partition_graphs is None: @@ -1020,167 +807,13 @@ class DebugDumpDir(object): return graph_def return None - def _get_ref_args(self, node): - """Determine whether an input of an op is ref-type. - - Args: - node: A `NodeDef`. - - Returns: - A list of the arg names (as strs) that are ref-type. 
- """ - - op_def = op_def_registry.get_registered_ops().get(node.op) - ref_args = [] - if op_def: - for i, output_arg in enumerate(op_def.output_arg): - if output_arg.is_ref: - arg_name = node.name if i == 0 else (node.name + ":%d" % i) - ref_args.append(arg_name) - return ref_args - - def _process_partition_graph_node(self, device_name, node): - """Process a node from the partition graphs. - - Args: - device_name: (str) device name. - node: (NodeDef) A partition-graph node to be processed. - - Raises: - ValueError: If duplicate node names are encountered. - """ - - if is_debug_node(node.name): - # This is a debug node. Parse the node name and retrieve the - # information about debug watches on tensors. But do not include - # the node in the graph. - (watched_node_name, watched_output_slot, _, - debug_op) = parse_debug_node_name(node.name) - - self._debug_watches[device_name][watched_node_name][ - watched_output_slot].add(debug_op) - - return - - if node.name in self._node_inputs[device_name]: - raise ValueError("Duplicate node name on device %s: '%s'" % - (device_name, node.name)) - - self._node_attributes[device_name][node.name] = node.attr - - self._node_inputs[device_name][node.name] = [] - self._node_ctrl_inputs[device_name][node.name] = [] - self._node_recipients[device_name][node.name] = [] - self._node_ctrl_recipients[device_name][node.name] = [] - - if node.name not in self._node_devices: - self._node_devices[node.name] = set() - self._node_devices[node.name].add(node.device) - self._node_op_types[device_name][node.name] = node.op - self._ref_args[device_name].extend(self._get_ref_args(node)) - - for inp in node.input: - if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"): - self._copy_send_nodes[device_name].append(node.name) - - if inp.startswith("^"): - cinp = inp[1:] - self._node_ctrl_inputs[device_name][node.name].append(cinp) + def _collect_node_devices(self, debug_graph): + for node_name in debug_graph.node_devices: + if node_name 
in self._node_devices: + self._node_devices[node_name] = self._node_devices[node_name].union( + debug_graph.node_devices[node_name]) else: - self._node_inputs[device_name][node.name].append(inp) - - def _prune_nodes_from_input_and_recipient_maps(self, - device_name, - nodes_to_prune): - """Prune nodes out of input and recipient maps. - - Args: - device_name: (`str`) device name. - nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned. - """ - - for node in nodes_to_prune: - del self._node_inputs[device_name][node] - del self._node_ctrl_inputs[device_name][node] - del self._node_recipients[device_name][node] - del self._node_ctrl_recipients[device_name][node] - - def _prune_non_control_edges_of_debug_ops(self, device_name): - """Prune (non-control) edges related to debug ops. - - Prune the Copy ops and associated _Send ops inserted by the debugger out - from the non-control inputs and output recipients map. Replace the inputs - and recipients with original ones. - - Args: - device_name: (`str`) device name. - """ - - copy_nodes = [] - for node in self._node_inputs[device_name]: - if node in self._copy_send_nodes[device_name]: - continue - - if is_copy_node(node): - copy_nodes.append(node) - - inputs = self._node_inputs[device_name][node] - - for i in xrange(len(inputs)): - inp = inputs[i] - if is_copy_node(inp): - # Find the input to the Copy node, which should be the original - # input to the node. 
- orig_inp = self._node_inputs[device_name][inp][0] - inputs[i] = orig_inp - - self._prune_nodes_from_input_and_recipient_maps(device_name, copy_nodes) - self._prune_nodes_from_input_and_recipient_maps( - device_name, self._copy_send_nodes[device_name]) - - def _prune_control_edges_of_debug_ops(self, device_name): - """Prune control edges related to the debug ops.""" - - for node in self._node_ctrl_inputs[device_name]: - ctrl_inputs = self._node_ctrl_inputs[device_name][node] - debug_op_inputs = [] - for ctrl_inp in ctrl_inputs: - if is_debug_node(ctrl_inp): - debug_op_inputs.append(ctrl_inp) - for debug_op_inp in debug_op_inputs: - ctrl_inputs.remove(debug_op_inp) - - def _populate_recipient_maps(self, device_name): - """Populate the map from node name to recipient(s) of its output(s). - - This method also populates the input map based on reversed ref edges. - - Args: - device_name: name of device. - """ - - for node in self._node_inputs[device_name]: - inputs = self._node_inputs[device_name][node] - for inp in inputs: - inp = get_node_name(inp) - if inp not in self._node_recipients[device_name]: - self._node_recipients[device_name][inp] = [] - self._node_recipients[device_name][inp].append(node) - - if inp in self._ref_args[device_name]: - if inp not in self._node_reversed_ref_inputs[device_name]: - self._node_reversed_ref_inputs[device_name][inp] = [] - self._node_reversed_ref_inputs[device_name][inp].append(node) - - for node in self._node_ctrl_inputs[device_name]: - ctrl_inputs = self._node_ctrl_inputs[device_name][node] - for ctrl_inp in ctrl_inputs: - if ctrl_inp in self._copy_send_nodes[device_name]: - continue - - if ctrl_inp not in self._node_ctrl_recipients[device_name]: - self._node_ctrl_recipients[device_name][ctrl_inp] = [] - self._node_ctrl_recipients[device_name][ctrl_inp].append(node) + self._node_devices[node_name] = debug_graph.node_devices[node_name] def _validate_dump_with_graphs(self, device_name): """Validate the dumped tensor data against 
the partition graphs. @@ -1197,31 +830,31 @@ class DebugDumpDir(object): Or if the temporal order of the dump's timestamps violate the input relations on the partition graphs. """ - - if not self._partition_graphs[device_name]: + if not self._debug_graphs: raise LookupError( "No partition graphs loaded for device %s" % device_name) + debug_graph = self._debug_graphs[device_name] # Verify that the node names in the dump data are all present in the # partition graphs. for datum in self._dump_tensor_data[device_name]: - if datum.node_name not in self._node_inputs[device_name]: + if datum.node_name not in debug_graph.node_inputs: raise ValueError("Node name '%s' is not found in partition graphs of " "device %s." % (datum.node_name, device_name)) pending_inputs = {} - for node in self._node_inputs[device_name]: + for node in debug_graph.node_inputs: pending_inputs[node] = [] - inputs = self._node_inputs[device_name][node] + inputs = debug_graph.node_inputs[node] for inp in inputs: - inp_node = get_node_name(inp) - inp_output_slot = get_output_slot(inp) + inp_node = debug_graphs.get_node_name(inp) + inp_output_slot = debug_graphs.get_output_slot(inp) # Inputs from Enter and NextIteration nodes are not validated because # DebugNodeInserter::InsertNodes() in the debugger core skips creating # control edges from debug ops watching these types of nodes. 
if (inp_node in self._debug_watches[device_name] and inp_output_slot in self._debug_watches[device_name][inp_node] and - self._node_op_types[device_name].get(inp) not in ( + debug_graph.node_op_types.get(inp) not in ( "Enter", "NextIteration") and (inp_node, inp_output_slot) not in pending_inputs[node]): pending_inputs[node].append((inp_node, inp_output_slot)) @@ -1240,7 +873,7 @@ class DebugDumpDir(object): "these input(s) are not satisfied: %s" % (node, datum.timestamp, repr(pending_inputs[node]))) - recipients = self._node_recipients[device_name][node] + recipients = debug_graph.node_recipients[node] for recipient in recipients: recipient_pending_inputs = pending_inputs[recipient] if (node, slot) in recipient_pending_inputs: @@ -1285,7 +918,7 @@ class DebugDumpDir(object): def loaded_partition_graphs(self): """Test whether partition graphs have been loaded.""" - return self._partition_graphs is not None + return bool(self._debug_graphs) def partition_graphs(self): """Get the partition graphs. @@ -1296,11 +929,10 @@ class DebugDumpDir(object): Raises: LookupError: If no partition graphs have been loaded. """ - - if self._partition_graphs is None: + if not self._debug_graphs: raise LookupError("No partition graphs have been loaded.") - - return self._partition_graphs.values() + return [self._debug_graphs[key].debug_graph_def + for key in self._debug_graphs] @property def run_fetches_info(self): @@ -1380,17 +1012,17 @@ class DebugDumpDir(object): LookupError: If no partition graphs have been loaded. ValueError: If specified node name does not exist. 
""" - if self._partition_graphs is None: + if not self._debug_graphs: raise LookupError("No partition graphs have been loaded.") if device_name is None: nodes = [] - for device_name in self._node_inputs: - nodes.extend(self._node_inputs[device_name].keys()) + for device_name in self._debug_graphs: + nodes.extend(self._debug_graphs[device_name].node_inputs.keys()) return nodes else: - if device_name not in self._node_inputs: + if device_name not in self._debug_graphs: raise ValueError("Invalid device name: %s" % device_name) - return self._node_inputs[device_name].keys() + return self._debug_graphs[device_name].node_inputs.keys() def node_attributes(self, node_name, device_name=None): """Get the attributes of a node. @@ -1406,11 +1038,11 @@ class DebugDumpDir(object): Raises: LookupError: If no partition graphs have been loaded. """ - if self._partition_graphs is None: + if not self._debug_graphs: raise LookupError("No partition graphs have been loaded.") device_name = self._infer_device_name(device_name, node_name) - return self._node_attributes[device_name][node_name] + return self._debug_graphs[device_name].node_attributes[node_name] def node_inputs(self, node_name, is_control=False, device_name=None): """Get the inputs of given node according to partition graphs. @@ -1429,16 +1061,15 @@ class DebugDumpDir(object): LookupError: If node inputs and control inputs have not been loaded from partition graphs yet. 
""" - - if self._partition_graphs is None: + if not self._debug_graphs: raise LookupError( "Node inputs are not loaded from partition graphs yet.") device_name = self._infer_device_name(device_name, node_name) if is_control: - return self._node_ctrl_inputs[device_name][node_name] + return self._debug_graphs[device_name].node_ctrl_inputs[node_name] else: - return self._node_inputs[device_name][node_name] + return self._debug_graphs[device_name].node_inputs[node_name] def transitive_inputs(self, node_name, @@ -1466,19 +1097,19 @@ class DebugDumpDir(object): LookupError: If node inputs and control inputs have not been loaded from partition graphs yet. """ - - if self._partition_graphs is None: + if not self._debug_graphs: raise LookupError( "Node inputs are not loaded from partition graphs yet.") device_name = self._infer_device_name(device_name, node_name) - input_lists = [self._node_inputs[device_name]] + input_lists = [self._debug_graphs[device_name].node_inputs] if include_control: - input_lists.append(self._node_ctrl_inputs[device_name]) + input_lists.append(self._debug_graphs[device_name].node_ctrl_inputs) if include_reversed_ref: - input_lists.append(self._node_reversed_ref_inputs[device_name]) - tracer = _DFSGraphTracer( + input_lists.append( + self._debug_graphs[device_name].node_reversed_ref_inputs) + tracer = debug_graphs.DFSGraphTracer( input_lists, skip_node_names=self._get_merge_node_names(device_name)) tracer.trace(node_name) @@ -1492,9 +1123,10 @@ class DebugDumpDir(object): if not hasattr(self, "_merge_node_names"): self._merge_node_names = {} if device_name not in self._merge_node_names: + debug_graph = self._debug_graphs[device_name] self._merge_node_names[device_name] = [ - node for node in self._node_op_types[device_name] - if self._node_op_types[device_name][node] == "Merge"] + node for node in debug_graph.node_op_types + if debug_graph.node_op_types[node] == "Merge"] return self._merge_node_names[device_name] def find_some_path(self, @@ -1546,12 
+1178,13 @@ class DebugDumpDir(object): "%s vs. %s" % (src_node_name, dst_node_name, src_device_name, dst_device_name)) - input_lists = [self._node_inputs[dst_device_name]] + input_lists = [self._debug_graphs[dst_device_name].node_inputs] + debug_graph = self._debug_graphs[dst_device_name] if include_control: - input_lists.append(self._node_ctrl_inputs[dst_device_name]) + input_lists.append(debug_graph.node_ctrl_inputs) if include_reversed_ref: - input_lists.append(self._node_reversed_ref_inputs[dst_device_name]) - tracer = _DFSGraphTracer( + input_lists.append(debug_graph.node_reversed_ref_inputs) + tracer = debug_graphs.DFSGraphTracer( input_lists, skip_node_names=self._get_merge_node_names(dst_device_name), destination_node_name=src_node_name) @@ -1561,7 +1194,7 @@ class DebugDumpDir(object): try: tracer.trace(dst_node_name) - except _GraphTracingReachedDestination: + except debug_graphs.GraphTracingReachedDestination: # Prune nodes not on the path. inputs = [dst_node_name] + tracer.inputs() depth_list = [0] + tracer.depth_list() @@ -1592,15 +1225,16 @@ class DebugDumpDir(object): from partition graphs yet. """ - if self._partition_graphs is None: + if not self._debug_graphs: raise LookupError( "Node recipients are not loaded from partition graphs yet.") device_name = self._infer_device_name(device_name, node_name) + debug_graph = self._debug_graphs[device_name] if is_control: - return self._node_ctrl_recipients[device_name][node_name] + return debug_graph.node_ctrl_recipients[node_name] else: - return self._node_recipients[device_name][node_name] + return debug_graph.node_recipients[node_name] def devices(self): """Get the list of device names. @@ -1608,7 +1242,6 @@ class DebugDumpDir(object): Returns: (`list` of `str`) names of the devices. """ - return self._device_names def node_exists(self, node_name, device_name=None): @@ -1627,20 +1260,18 @@ class DebugDumpDir(object): LookupError: If no partition graphs have been loaded yet. 
ValueError: If device_name is specified but cannot be found. """ - - if self._node_inputs is None: + if not self._debug_graphs: raise LookupError( "Nodes have not been loaded from partition graphs yet.") - if (device_name is not None) and device_name not in self._node_inputs: + if (device_name is not None) and device_name not in self._debug_graphs: raise ValueError( "The specified device_name '%s' cannot be found." % device_name) - node_inputs_all_devices = (self._node_inputs if device_name is None - else (self._node_inputs[device_name],)) - - return any(node_name in node_inputs_all_devices[dev_name] - for dev_name in node_inputs_all_devices) + for _, debug_graph in self._debug_graphs.items(): + if node_name in debug_graph.node_inputs: + return True + return False def node_device(self, node_name): """Get the names of the devices that has nodes of the specified name. @@ -1658,8 +1289,7 @@ class DebugDumpDir(object): from partition graphs yet. ValueError: If the node does not exist in partition graphs. """ - - if self._partition_graphs is None: + if not self._debug_graphs: raise LookupError( "Node devices are not loaded from partition graphs yet.") @@ -1685,13 +1315,12 @@ class DebugDumpDir(object): LookupError: If node op types have not been loaded from partition graphs yet. """ - - if self._partition_graphs is None: + if not self._debug_graphs: raise LookupError( "Node op types are not loaded from partition graphs yet.") device_name = self._infer_device_name(device_name, node_name) - return self._node_op_types[device_name][node_name] + return self._debug_graphs[device_name].node_op_types[node_name] def debug_watch_keys(self, node_name, device_name=None): """Get all tensor watch keys of given node according to partition graphs. 
@@ -1957,7 +1586,7 @@ class DebugDumpDir(object): if self._python_graph is None: raise LookupError("Python graph is not available for traceback lookup") - node_name = get_node_name(element_name) + node_name = debug_graphs.get_node_name(element_name) if node_name not in self._node_traceback: raise KeyError("Cannot find node \"%s\" in Python graph" % node_name) diff --git a/tensorflow/python/debug/lib/debug_data_test.py b/tensorflow/python/debug/lib/debug_data_test.py index 694010a23c..7ce7ef6a97 100644 --- a/tensorflow/python/debug/lib/debug_data_test.py +++ b/tensorflow/python/debug/lib/debug_data_test.py @@ -49,77 +49,6 @@ class DeviceNamePathConversionTest(test_util.TensorFlowTestCase): ",job_ps,replica_1,task_2,cpu_0")) -class ParseNodeOrTensorNameTest(test_util.TensorFlowTestCase): - - def testParseNodeName(self): - node_name, slot = debug_data.parse_node_or_tensor_name("namespace1/node_1") - - self.assertEqual("namespace1/node_1", node_name) - self.assertIsNone(slot) - - def testParseTensorName(self): - node_name, slot = debug_data.parse_node_or_tensor_name( - "namespace1/node_2:3") - - self.assertEqual("namespace1/node_2", node_name) - self.assertEqual(3, slot) - - -class NodeNameChecksTest(test_util.TensorFlowTestCase): - - def testIsCopyNode(self): - self.assertTrue(debug_data.is_copy_node("__copy_ns1/ns2/node3_0")) - - self.assertFalse(debug_data.is_copy_node("copy_ns1/ns2/node3_0")) - self.assertFalse(debug_data.is_copy_node("_copy_ns1/ns2/node3_0")) - self.assertFalse(debug_data.is_copy_node("_copyns1/ns2/node3_0")) - self.assertFalse(debug_data.is_copy_node("__dbg_ns1/ns2/node3_0")) - - def testIsDebugNode(self): - self.assertTrue( - debug_data.is_debug_node("__dbg_ns1/ns2/node3:0_0_DebugIdentity")) - - self.assertFalse( - debug_data.is_debug_node("dbg_ns1/ns2/node3:0_0_DebugIdentity")) - self.assertFalse( - debug_data.is_debug_node("_dbg_ns1/ns2/node3:0_0_DebugIdentity")) - self.assertFalse( - 
debug_data.is_debug_node("_dbgns1/ns2/node3:0_0_DebugIdentity")) - self.assertFalse(debug_data.is_debug_node("__copy_ns1/ns2/node3_0")) - - -class ParseDebugNodeNameTest(test_util.TensorFlowTestCase): - - def testParseDebugNodeName_valid(self): - debug_node_name_1 = "__dbg_ns_a/ns_b/node_c:1_0_DebugIdentity" - (watched_node, watched_output_slot, debug_op_index, - debug_op) = debug_data.parse_debug_node_name(debug_node_name_1) - - self.assertEqual("ns_a/ns_b/node_c", watched_node) - self.assertEqual(1, watched_output_slot) - self.assertEqual(0, debug_op_index) - self.assertEqual("DebugIdentity", debug_op) - - def testParseDebugNodeName_invalidPrefix(self): - invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity" - - with self.assertRaisesRegexp(ValueError, "Invalid prefix"): - debug_data.parse_debug_node_name(invalid_debug_node_name_1) - - def testParseDebugNodeName_missingDebugOpIndex(self): - invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity" - - with self.assertRaisesRegexp(ValueError, "Invalid debug node name"): - debug_data.parse_debug_node_name(invalid_debug_node_name_1) - - def testParseDebugNodeName_invalidWatchedTensorName(self): - invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity" - - with self.assertRaisesRegexp(ValueError, - "Invalid tensor name in debug node name"): - debug_data.parse_debug_node_name(invalid_debug_node_name_1) - - class HasNanOrInfTest(test_util.TensorFlowTestCase): def setUp(self): @@ -375,19 +304,5 @@ class DebugDumpDirTest(test_util.TensorFlowTestCase): fake.assert_has_calls(expected_calls, any_order=True) -class GetNodeNameAndOutputSlotTest(test_util.TensorFlowTestCase): - - def testParseTensorNameInputWorks(self): - self.assertEqual("a", debug_data.get_node_name("a:0")) - self.assertEqual(0, debug_data.get_output_slot("a:0")) - - self.assertEqual("_b", debug_data.get_node_name("_b:1")) - self.assertEqual(1, debug_data.get_output_slot("_b:1")) - - def testParseNodeNameInputWorks(self): - 
self.assertEqual("a", debug_data.get_node_name("a")) - self.assertEqual(0, debug_data.get_output_slot("a")) - - if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/lib/debug_gradients.py b/tensorflow/python/debug/lib/debug_gradients.py index 5306391613..b01a58719c 100644 --- a/tensorflow/python/debug/lib/debug_gradients.py +++ b/tensorflow/python/debug/lib/debug_gradients.py @@ -24,6 +24,7 @@ import uuid import six from tensorflow.python.debug.lib import debug_data +from tensorflow.python.debug.lib import debug_graphs from tensorflow.python.framework import ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import variables @@ -34,7 +35,7 @@ _gradient_debuggers = {} def _tensor_to_grad_debug_op_name(tensor, grad_debugger_uuid): - op_name, slot = debug_data.parse_node_or_tensor_name(tensor.name) + op_name, slot = debug_graphs.parse_node_or_tensor_name(tensor.name) return "%s_%d/%s%s" % (op_name, slot, _GRADIENT_DEBUG_TAG, grad_debugger_uuid) @@ -407,7 +408,7 @@ def gradient_values_from_dump(grad_debugger, x_tensor, dump): (grad_debugger.graph, dump.python_graph)) gradient_tensor = grad_debugger.gradient_tensor(x_tensor) - node_name, output_slot = debug_data.parse_node_or_tensor_name( + node_name, output_slot = debug_graphs.parse_node_or_tensor_name( gradient_tensor.name) try: diff --git a/tensorflow/python/debug/lib/debug_graphs.py b/tensorflow/python/debug/lib/debug_graphs.py new file mode 100644 index 0000000000..20e2a6acfe --- /dev/null +++ b/tensorflow/python/debug/lib/debug_graphs.py @@ -0,0 +1,430 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes and methods for processing debugger-decorated graphs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.framework import op_def_registry + + +def parse_node_or_tensor_name(name): + """Get the node name from a string that can be node or tensor name. + + Args: + name: An input node name (e.g., "node_a") or tensor name (e.g., + "node_a:0"), as a str. + + Returns: + 1) The node name, as a str. If the input name is a tensor name, i.e., + consists of a colon, the final colon and the following output slot + will be stripped. + 2) If the input name is a tensor name, the output slot, as an int. If + the input name is not a tensor name, None. + """ + + if ":" in name and not name.endswith(":"): + node_name = name[:name.rfind(":")] + output_slot = int(name[name.rfind(":") + 1:]) + + return node_name, output_slot + else: + return name, None + + +def get_node_name(element_name): + node_name, _ = parse_node_or_tensor_name(element_name) + return node_name + + +def get_output_slot(element_name): + """Get the output slot number from the name of a graph element. + + If element_name is a node name without output slot at the end, 0 will be + assumed. + + Args: + element_name: (`str`) name of the graph element in question. + + Returns: + (`int`) output slot number. 
+ """ + _, output_slot = parse_node_or_tensor_name(element_name) + return output_slot if output_slot is not None else 0 + + +def is_copy_node(node_name): + """Determine whether a node name is that of a debug Copy node. + + Such nodes are inserted by TensorFlow core upon request in + RunOptions.debug_options.debug_tensor_watch_opts. + + Args: + node_name: Name of the node. + + Returns: + A bool indicating whether the input argument is the name of a debug Copy + node. + """ + return node_name.startswith("__copy_") + + +def is_debug_node(node_name): + """Determine whether a node name is that of a debug node. + + Such nodes are inserted by TensorFlow core upon request in + RunOptions.debug_options.debug_tensor_watch_opts. + + Args: + node_name: Name of the node. + + Returns: + A bool indicating whether the input argument is the name of a debug node. + """ + return node_name.startswith("__dbg_") + + +def parse_debug_node_name(node_name): + """Parse the name of a debug node. + + Args: + node_name: Name of the debug node. + + Returns: + 1. Name of the watched node, as a str. + 2. Output slot index of the watched tensor, as an int. + 3. Index of the debug node, as an int. + 4. Name of the debug op, as a str, e.g, "DebugIdentity". + + Raises: + ValueError: If the input node name is not a valid debug node name. 
+ """ + prefix = "__dbg_" + + name = node_name + if not name.startswith(prefix): + raise ValueError("Invalid prefix in debug node name: '%s'" % node_name) + + name = name[len(prefix):] + + if name.count("_") < 2: + raise ValueError("Invalid debug node name: '%s'" % node_name) + + debug_op = name[name.rindex("_") + 1:] + name = name[:name.rindex("_")] + + debug_op_index = int(name[name.rindex("_") + 1:]) + name = name[:name.rindex("_")] + + if name.count(":") != 1: + raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name) + + watched_node_name = name[:name.index(":")] + watched_output_slot = int(name[name.index(":") + 1:]) + + return watched_node_name, watched_output_slot, debug_op_index, debug_op + + +class GraphTracingReachedDestination(Exception): + pass + + +class DFSGraphTracer(object): + """Graph input tracer using depth-first search.""" + + def __init__(self, + input_lists, + skip_node_names=None, + destination_node_name=None): + """Constructor of _DFSGraphTracer. + + Args: + input_lists: A list of dicts. Each dict is an adjacency (input) map from + the recipient node name as the key and the list of input node names + as the value. + skip_node_names: Optional: a list of node names to skip tracing. + destination_node_name: Optional: destination node name. If not `None`, it + should be the name of a destination not as a str and the graph tracing + will raise GraphTracingReachedDestination as soon as the node has been + reached. + + Raises: + GraphTracingReachedDestination: if stop_at_node_name is not None and + the specified node is reached. + """ + + self._input_lists = input_lists + self._skip_node_names = skip_node_names + + self._inputs = [] + self._visited_nodes = [] + self._depth_count = 0 + self._depth_list = [] + + self._destination_node_name = destination_node_name + + def trace(self, graph_element_name): + """Trace inputs. + + Args: + graph_element_name: Name of the node or an output tensor of the node, as a + str. 
+ + Raises: + GraphTracingReachedDestination: if destination_node_name of this tracer + object is not None and the specified node is reached. + """ + self._depth_count += 1 + + node_name = get_node_name(graph_element_name) + if node_name == self._destination_node_name: + raise GraphTracingReachedDestination() + + if node_name in self._skip_node_names: + return + if node_name in self._visited_nodes: + return + + self._visited_nodes.append(node_name) + + for input_list in self._input_lists: + for inp in input_list[node_name]: + if get_node_name(inp) in self._visited_nodes: + continue + self._inputs.append(inp) + self._depth_list.append(self._depth_count) + self.trace(inp) + + self._depth_count -= 1 + + def inputs(self): + return self._inputs + + def depth_list(self): + return self._depth_list + + +class DebugGraph(object): + """Represents a debugger-decorated graph.""" + + def __init__(self, debug_graph_def, device_name=None): + self._debug_graph_def = debug_graph_def + + self._node_attributes = {} + self._node_inputs = {} + self._node_reversed_ref_inputs = {} + self._node_ctrl_inputs = {} + self._node_recipients = {} + self._node_ctrl_recipients = {} + self._node_devices = {} + self._node_op_types = {} + self._copy_send_nodes = [] + self._ref_args = {} + + self._device_name = device_name + if not self._device_name and debug_graph_def.node: + self._device_name = debug_graph_def.node[0].device + + for node in debug_graph_def.node: + self._process_debug_graph_node(node) + + self._prune_non_control_edges_of_debug_ops() + self._prune_control_edges_of_debug_ops() + + self._populate_recipient_maps() + + def _process_debug_graph_node(self, node): + """Process a node from the debug GraphDef. + + Args: + node: (NodeDef) A partition-graph node to be processed. + + Raises: + ValueError: If duplicate node names are encountered. + """ + + if is_debug_node(node.name): + # This is a debug node. Parse the node name and retrieve the + # information about debug watches on tensors. 
But do not include + # the node in the graph. + return + + if node.name in self._node_inputs: + raise ValueError("Duplicate node name on device %s: '%s'" % + (self._device_name, node.name)) + + self._node_attributes[node.name] = node.attr + + self._node_inputs[node.name] = [] + self._node_ctrl_inputs[node.name] = [] + self._node_recipients[node.name] = [] + self._node_ctrl_recipients[node.name] = [] + + if node.name not in self._node_devices: + self._node_devices[node.name] = set() + self._node_devices[node.name].add(node.device) + self._node_op_types[node.name] = node.op + self._ref_args[node.name] = self._get_ref_args(node) + + for inp in node.input: + if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"): + self._copy_send_nodes.append(node.name) + + if inp.startswith("^"): + cinp = inp[1:] + self._node_ctrl_inputs[node.name].append(cinp) + else: + self._node_inputs[node.name].append(inp) + + def _get_ref_args(self, node): + """Determine whether an input of an op is ref-type. + + Args: + node: A `NodeDef`. + + Returns: + A list of the arg names (as strs) that are ref-type. + """ + op_def = op_def_registry.get_registered_ops().get(node.op) + ref_args = [] + if op_def: + for i, output_arg in enumerate(op_def.output_arg): + if output_arg.is_ref: + arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i)) + ref_args.append(arg_name) + return ref_args + + def _prune_non_control_edges_of_debug_ops(self): + """Prune (non-control) edges related to debug ops. + + Prune the Copy ops and associated _Send ops inserted by the debugger out + from the non-control inputs and output recipients map. Replace the inputs + and recipients with original ones. 
+ """ + copy_nodes = [] + for node in self._node_inputs: + if node in self._copy_send_nodes: + continue + + if is_copy_node(node): + copy_nodes.append(node) + + inputs = self._node_inputs[node] + + for i in xrange(len(inputs)): + inp = inputs[i] + if is_copy_node(inp): + # Find the input to the Copy node, which should be the original + # input to the node. + orig_inp = self._node_inputs[inp][0] + inputs[i] = orig_inp + + self._prune_nodes_from_input_and_recipient_maps(copy_nodes) + self._prune_nodes_from_input_and_recipient_maps(self._copy_send_nodes) + + def _prune_control_edges_of_debug_ops(self): + """Prune control edges related to the debug ops.""" + for node in self._node_ctrl_inputs: + ctrl_inputs = self._node_ctrl_inputs[node] + debug_op_inputs = [] + for ctrl_inp in ctrl_inputs: + if is_debug_node(ctrl_inp): + debug_op_inputs.append(ctrl_inp) + for debug_op_inp in debug_op_inputs: + ctrl_inputs.remove(debug_op_inp) + + def _populate_recipient_maps(self): + """Populate the map from node name to recipient(s) of its output(s). + + This method also populates the input map based on reversed ref edges. + """ + for node in self._node_inputs: + inputs = self._node_inputs[node] + for inp in inputs: + inp = get_node_name(inp) + if inp not in self._node_recipients: + self._node_recipients[inp] = [] + self._node_recipients[inp].append(node) + + if inp in self._ref_args: + if inp not in self._node_reversed_ref_inputs: + self._node_reversed_ref_inputs[inp] = [] + self._node_reversed_ref_inputs[inp].append(node) + + for node in self._node_ctrl_inputs: + ctrl_inputs = self._node_ctrl_inputs[node] + for ctrl_inp in ctrl_inputs: + if ctrl_inp in self._copy_send_nodes: + continue + + if ctrl_inp not in self._node_ctrl_recipients: + self._node_ctrl_recipients[ctrl_inp] = [] + self._node_ctrl_recipients[ctrl_inp].append(node) + + def _prune_nodes_from_input_and_recipient_maps(self, nodes_to_prune): + """Prune nodes out of input and recipient maps. 
+ + Args: + nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned. + """ + for node in nodes_to_prune: + del self._node_inputs[node] + del self._node_ctrl_inputs[node] + del self._node_recipients[node] + del self._node_ctrl_recipients[node] + + @property + def device_name(self): + return self._device_name + + @property + def debug_graph_def(self): + """The debugger-decorated GraphDef.""" + return self._debug_graph_def + + @property + def node_devices(self): + return self._node_devices + + @property + def node_op_types(self): + return self._node_op_types + + @property + def node_attributes(self): + return self._node_attributes + + @property + def node_inputs(self): + return self._node_inputs + + @property + def node_ctrl_inputs(self): + return self._node_ctrl_inputs + + @property + def node_reversed_ref_inputs(self): + return self._node_reversed_ref_inputs + + @property + def node_recipients(self): + return self._node_recipients + + @property + def node_ctrl_recipients(self): + return self._node_ctrl_recipients diff --git a/tensorflow/python/debug/lib/debug_graphs_test.py b/tensorflow/python/debug/lib/debug_graphs_test.py new file mode 100644 index 0000000000..34257794f1 --- /dev/null +++ b/tensorflow/python/debug/lib/debug_graphs_test.py @@ -0,0 +1,112 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tfdbg module debug_data.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.debug.lib import debug_graphs +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class ParseNodeOrTensorNameTest(test_util.TensorFlowTestCase): + + def testParseNodeName(self): + node_name, slot = debug_graphs.parse_node_or_tensor_name( + "namespace1/node_1") + + self.assertEqual("namespace1/node_1", node_name) + self.assertIsNone(slot) + + def testParseTensorName(self): + node_name, slot = debug_graphs.parse_node_or_tensor_name( + "namespace1/node_2:3") + + self.assertEqual("namespace1/node_2", node_name) + self.assertEqual(3, slot) + + +class GetNodeNameAndOutputSlotTest(test_util.TensorFlowTestCase): + + def testParseTensorNameInputWorks(self): + self.assertEqual("a", debug_graphs.get_node_name("a:0")) + self.assertEqual(0, debug_graphs.get_output_slot("a:0")) + + self.assertEqual("_b", debug_graphs.get_node_name("_b:1")) + self.assertEqual(1, debug_graphs.get_output_slot("_b:1")) + + def testParseNodeNameInputWorks(self): + self.assertEqual("a", debug_graphs.get_node_name("a")) + self.assertEqual(0, debug_graphs.get_output_slot("a")) + + +class NodeNameChecksTest(test_util.TensorFlowTestCase): + + def testIsCopyNode(self): + self.assertTrue(debug_graphs.is_copy_node("__copy_ns1/ns2/node3_0")) + + self.assertFalse(debug_graphs.is_copy_node("copy_ns1/ns2/node3_0")) + self.assertFalse(debug_graphs.is_copy_node("_copy_ns1/ns2/node3_0")) + self.assertFalse(debug_graphs.is_copy_node("_copyns1/ns2/node3_0")) + self.assertFalse(debug_graphs.is_copy_node("__dbg_ns1/ns2/node3_0")) + + def testIsDebugNode(self): + self.assertTrue( + debug_graphs.is_debug_node("__dbg_ns1/ns2/node3:0_0_DebugIdentity")) + + self.assertFalse( + 
debug_graphs.is_debug_node("dbg_ns1/ns2/node3:0_0_DebugIdentity")) + self.assertFalse( + debug_graphs.is_debug_node("_dbg_ns1/ns2/node3:0_0_DebugIdentity")) + self.assertFalse( + debug_graphs.is_debug_node("_dbgns1/ns2/node3:0_0_DebugIdentity")) + self.assertFalse(debug_graphs.is_debug_node("__copy_ns1/ns2/node3_0")) + + +class ParseDebugNodeNameTest(test_util.TensorFlowTestCase): + + def testParseDebugNodeName_valid(self): + debug_node_name_1 = "__dbg_ns_a/ns_b/node_c:1_0_DebugIdentity" + (watched_node, watched_output_slot, debug_op_index, + debug_op) = debug_graphs.parse_debug_node_name(debug_node_name_1) + + self.assertEqual("ns_a/ns_b/node_c", watched_node) + self.assertEqual(1, watched_output_slot) + self.assertEqual(0, debug_op_index) + self.assertEqual("DebugIdentity", debug_op) + + def testParseDebugNodeName_invalidPrefix(self): + invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity" + + with self.assertRaisesRegexp(ValueError, "Invalid prefix"): + debug_graphs.parse_debug_node_name(invalid_debug_node_name_1) + + def testParseDebugNodeName_missingDebugOpIndex(self): + invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity" + + with self.assertRaisesRegexp(ValueError, "Invalid debug node name"): + debug_graphs.parse_debug_node_name(invalid_debug_node_name_1) + + def testParseDebugNodeName_invalidWatchedTensorName(self): + invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity" + + with self.assertRaisesRegexp(ValueError, + "Invalid tensor name in debug node name"): + debug_graphs.parse_debug_node_name(invalid_debug_node_name_1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/debug/lib/grpc_debug_server.py b/tensorflow/python/debug/lib/grpc_debug_server.py index 309fdb3bce..28cf1514f6 100644 --- a/tensorflow/python/debug/lib/grpc_debug_server.py +++ b/tensorflow/python/debug/lib/grpc_debug_server.py @@ -29,7 +29,7 @@ from six.moves import queue from tensorflow.core.debug import debug_service_pb2 from 
tensorflow.core.framework import graph_pb2 -from tensorflow.python.debug.lib import debug_data +from tensorflow.python.debug.lib import debug_graphs from tensorflow.python.debug.lib import debug_service_pb2_grpc from tensorflow.python.platform import tf_logging as logging @@ -294,10 +294,10 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer): def _process_graph_def(self, graph_def): for node_def in graph_def.node: - if (debug_data.is_debug_node(node_def.name) and + if (debug_graphs.is_debug_node(node_def.name) and node_def.attr["gated_grpc"].b): node_name, output_slot, _, debug_op = ( - debug_data.parse_debug_node_name(node_def.name)) + debug_graphs.parse_debug_node_name(node_def.name)) self._gated_grpc_debug_watches.add( DebugWatch(node_name, output_slot, debug_op)) diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py index 48f31771db..aa5314dda5 100644 --- a/tensorflow/python/debug/lib/session_debug_file_test.py +++ b/tensorflow/python/debug/lib/session_debug_file_test.py @@ -34,7 +34,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -class SessionDebugTest(session_debug_testlib.SessionDebugTestBase): +class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase): def _no_rewrite_session_config(self): rewriter_config = rewriter_config_pb2.RewriterConfig( diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py index 08b3e75e7c..d4b9d06b54 100644 --- a/tensorflow/python/debug/lib/session_debug_testlib.py +++ b/tensorflow/python/debug/lib/session_debug_testlib.py @@ -33,6 +33,7 @@ from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.util import event_pb2 from tensorflow.python.client import session from tensorflow.python.debug.lib import debug_data +from tensorflow.python.debug.lib import debug_graphs from 
tensorflow.python.debug.lib import debug_utils from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -242,7 +243,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase): v_copy_node_def = None for partition_graph in run_metadata.partition_graphs: for node_def in partition_graph.node: - if debug_data.is_copy_node(node_def.name): + if debug_graphs.is_copy_node(node_def.name): if node_def.name == "__copy_u_0": u_copy_node_def = node_def elif node_def.name == "__copy_v_0": diff --git a/tensorflow/python/debug/lib/stepper.py b/tensorflow/python/debug/lib/stepper.py index c814520b7e..1fa0b3dba2 100644 --- a/tensorflow/python/debug/lib/stepper.py +++ b/tensorflow/python/debug/lib/stepper.py @@ -27,6 +27,7 @@ import six from tensorflow.core.protobuf import config_pb2 from tensorflow.python.debug.lib import debug_data +from tensorflow.python.debug.lib import debug_graphs from tensorflow.python.debug.lib import debug_utils from tensorflow.python.framework import ops from tensorflow.python.ops import session_ops @@ -706,8 +707,8 @@ class NodeStepper(object): if ":" in element_name: debug_utils.add_debug_tensor_watch( run_options, - debug_data.get_node_name(element_name), - output_slot=debug_data.get_output_slot(element_name), + debug_graphs.get_node_name(element_name), + output_slot=debug_graphs.get_output_slot(element_name), debug_urls=["file://" + dump_path]) return dump_path, run_options @@ -961,5 +962,5 @@ class NodeStepper(object): The node associated with element in the graph. """ - node_name, _ = debug_data.parse_node_or_tensor_name(element.name) + node_name, _ = debug_graphs.parse_node_or_tensor_name(element.name) return self._sess.graph.as_graph_element(node_name) |