diff options
author | 2017-03-30 20:27:38 -0800 | |
---|---|---|
committer | 2017-03-30 21:48:33 -0700 | |
commit | 99559a3455d7c3b3663968a55bf56801743201d0 (patch) | |
tree | b5a57628ea1479cf8da74e7edcaaabf0e34e2866 /tensorflow/python/debug/lib | |
parent | fb56fc90167c3919cb59f753f233ef2a41469cb2 (diff) |
tfdbg: fix a bug in graph validation related to tf.while_loops
CL/147488620 fixed a bug where the debugger would hang at Enter and NextIteration nodes under certain conditions. But it introduced another bug where the debug dumps from Enter and NextIteration may get generated later than downstream nodes in the tf.while_loop body, causing "causility violation" during debug_data.DebugDumpDir's validation process under certain conditions (e.g., backpropagation on a dynamic_rnn). This CL fixes that by excluding Enter and NextIteration nodes from the validation process.
Fixes: #8337
Change: 151787432
Diffstat (limited to 'tensorflow/python/debug/lib')
-rw-r--r-- | tensorflow/python/debug/lib/debug_data.py | 4 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/session_debug_testlib.py | 58 |
2 files changed, 62 insertions, 0 deletions
diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index 9d0b360cde..96772aed8f 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -934,8 +934,12 @@ class DebugDumpDir(object): for inp in inputs: inp_node = get_node_name(inp) inp_output_slot = get_output_slot(inp) + # Inputs from Enter and NextIteration nodes are not validated because + # DebugNodeInserter::InsertNodes() in the debugger core skips creating + # control edges from debug ops watching these types of nodes. if (inp_node in self._debug_watches and inp_output_slot in self._debug_watches[inp_node] and + self._node_op_types.get(inp) not in ("Enter", "NextIteration") and (inp_node, inp_output_slot) not in pending_inputs[node]): pending_inputs[node].append((inp_node, inp_output_slot)) diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py index 42a7e6d0a3..511ddb1673 100644 --- a/tensorflow/python/debug/lib/session_debug_testlib.py +++ b/tensorflow/python/debug/lib/session_debug_testlib.py @@ -43,13 +43,36 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables +import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import from tensorflow.python.platform import googletest from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent +class _RNNCellForTest(rnn_cell_impl._RNNCell): # pylint: disable=protected-access + """RNN cell for testing.""" + + def __init__(self, input_output_size, state_size): + self._input_output_size = input_output_size + self._state_size = state_size + self._w = variables.Variable(1.0, dtype=dtypes.float32, name="w") + + @property + def output_size(self): + return self._input_output_size + + @property + def state_size(self): + return self._state_size + + def __call__(self, input_, state, scope=None): + return (math_ops.multiply(self._w, input_), state) + + class SessionDebugTestBase(test_util.TensorFlowTestCase): """Base class for unit tests of tfdbg running with tf.Session.""" @@ -436,6 +459,41 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase): [[12], [14], [16]], dump.get_tensors("while/NextIteration", 0, "DebugIdentity")) + def testDebugTrainingDynamicRNNWorks(self): + with session.Session() as sess: + input_size = 3 + state_size = 2 + time_steps = 4 + batch_size = 2 + + input_values = np.random.randn(time_steps, batch_size, input_size) + sequence_length = np.random.randint(0, time_steps, size=batch_size) + concat_inputs = array_ops.placeholder( + dtypes.float32, shape=(time_steps, batch_size, input_size)) + + outputs_dynamic, _ = rnn.dynamic_rnn( + _RNNCellForTest(input_size, state_size), + inputs=concat_inputs, + sequence_length=sequence_length, + time_major=True, + dtype=dtypes.float32) + toy_loss = math_ops.reduce_sum(outputs_dynamic * outputs_dynamic) + train_op = gradient_descent.GradientDescentOptimizer( + learning_rate=0.1).minimize(toy_loss, name="train_op") + + sess.run(variables.global_variables_initializer()) + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + debug_utils.watch_graph(run_options, + sess.graph, + debug_urls=self._debug_urls()) + run_metadata = config_pb2.RunMetadata() + sess.run(train_op, feed_dict={concat_inputs: input_values}, + options=run_options, run_metadata=run_metadata) + + debug_data.DebugDumpDir( + self._dump_root, partition_graphs=run_metadata.partition_graphs) + def testDebugCondWatchingWholeGraphWorks(self): with session.Session() as sess: x = variables.Variable(10.0, name="x") |