aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2018-06-28 14:23:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-28 14:26:57 -0700
commit901d82d736988e8dbc47385cfde1d97cdc69ba26 (patch)
tree31c85435954d45f3024a6fcb5848abd3db31ea9a /tensorflow/python/debug
parent1a4820e1c52569e89d48141cf328cfe35c23a98a (diff)
tfdbg: Fix compatibility with C++ MakeCallable and _make_callable_from_options
Fixes #20160 REL_NOTES: tfdbg: Fix compatibility with `tf.keras.Model`s training on `tf.data.Dataset`s. PiperOrigin-RevId: 202543231
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--tensorflow/python/debug/BUILD13
-rw-r--r--tensorflow/python/debug/examples/debug_keras.py89
-rwxr-xr-xtensorflow/python/debug/examples/examples_test.sh7
-rw-r--r--tensorflow/python/debug/wrappers/framework.py87
-rw-r--r--tensorflow/python/debug/wrappers/grpc_wrapper.py6
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper.py2
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper_test.py118
7 files changed, 299 insertions, 23 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 6941cacf23..c025dc8aa5 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -454,6 +454,17 @@ py_binary(
],
)
+py_binary(
+ name = "debug_keras",
+ srcs = ["examples/debug_keras.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":debug_py",
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ ],
+)
+
py_test(
name = "common_test",
size = "small",
@@ -1086,6 +1097,7 @@ py_test(
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//third_party/py/numpy",
],
)
@@ -1096,6 +1108,7 @@ sh_test(
data = [
":debug_errors",
":debug_fibonacci",
+ ":debug_keras",
":debug_mnist",
":debug_tflearn_iris",
":offline_analyzer",
diff --git a/tensorflow/python/debug/examples/debug_keras.py b/tensorflow/python/debug/examples/debug_keras.py
new file mode 100644
index 0000000000..3272d85ade
--- /dev/null
+++ b/tensorflow/python/debug/examples/debug_keras.py
@@ -0,0 +1,89 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tfdbg example: debugging tf.keras models training on tf.data.Dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python import debug as tf_debug
+
+
+def main(_):
+ # Create a dummy dataset.
+ num_examples = 8
+ steps_per_epoch = 2
+ input_dims = 3
+ output_dims = 1
+ xs = np.zeros([num_examples, input_dims])
+ ys = np.zeros([num_examples, output_dims])
+ dataset = tf.data.Dataset.from_tensor_slices(
+ (xs, ys)).repeat(num_examples).batch(int(num_examples / steps_per_epoch))
+
+ sess = tf.Session()
+ if FLAGS.debug:
+ # Use the command-line interface (CLI) of tfdbg.
+ sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type)
+ elif FLAGS.tensorboard_debug_address:
+ # Use the TensorBoard Debugger Plugin (GUI of tfdbg).
+ sess = tf_debug.TensorBoardDebugWrapperSession(
+ sess, FLAGS.tensorboard_debug_address)
+ tf.keras.backend.set_session(sess)
+
+ # Create a dummy model.
+ model = tf.keras.Sequential([
+ tf.keras.layers.Dense(1, input_shape=[input_dims])])
+ model.compile(loss="mse", optimizer="sgd")
+
+ # Train the model using the dummy dataset created above.
+ model.fit(dataset, epochs=FLAGS.epochs, steps_per_epoch=steps_per_epoch)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--debug",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="Use debugger to track down bad values during training. "
+ "Mutually exclusive with the --tensorboard_debug_address flag.")
+ parser.add_argument(
+ "--ui_type",
+ type=str,
+ default="curses",
+ help="Command-line user interface type (curses | readline).")
+ parser.add_argument(
+ "--tensorboard_debug_address",
+ type=str,
+ default=None,
+ help="Connect to the TensorBoard Debugger Plugin backend specified by "
+ "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
+ "--debug flag.")
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=2,
+ help="Number of epochs to train the model for.")
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/examples/examples_test.sh b/tensorflow/python/debug/examples/examples_test.sh
index e9c45a7e6e..2d35b2d8bb 100755
--- a/tensorflow/python/debug/examples/examples_test.sh
+++ b/tensorflow/python/debug/examples/examples_test.sh
@@ -48,12 +48,14 @@ if [[ -z "${PYTHON_BIN_PATH}" ]]; then
DEBUG_ERRORS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_errors"
DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist"
DEBUG_TFLEARN_IRIS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_tflearn_iris"
+ DEBUG_KERAS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_keras"
OFFLINE_ANALYZER_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/offline_analyzer"
else
DEBUG_FIBONACCI_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_fibonacci"
DEBUG_ERRORS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_errors"
DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_mnist"
DEBUG_TFLEARN_IRIS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_tflearn_iris"
+ DEBUG_KERAS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_keras"
OFFLINE_ANALYZER_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.cli.offline_analyzer"
fi
@@ -96,6 +98,11 @@ if [[ -d "${CUSTOM_DUMP_ROOT}" ]]; then
exit 1
fi
+# Test debugging of tf.keras.
+cat << EOF | "${DEBUG_KERAS_BIN}" --debug --ui_type=readline
+run -f has_inf_or_nan
+EOF
+
# Test offline_analyzer.
echo
echo "Testing offline_analyzer"
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index c530204bbf..b9524ce649 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -392,6 +392,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
self._default_session_context_manager = None
+ # A cache for callables created from CallableOptions.
+ self._cached_callables_from_options = dict()
+
@property
def graph(self):
return self._sess.graph
@@ -414,7 +417,8 @@ class BaseDebugWrapperSession(session.SessionInterface):
options=None,
run_metadata=None,
callable_runner=None,
- callable_runner_args=None):
+ callable_runner_args=None,
+ callable_options=None):
"""Wrapper around Session.run() that inserts tensor watch options.
Args:
@@ -424,7 +428,12 @@ class BaseDebugWrapperSession(session.SessionInterface):
run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
callable_runner: A `callable` returned by `Session.make_callable()`.
If not `None`, `fetches` and `feed_dict` must both be `None`.
- callable_runner_args: An optional list of arguments to `callable_runner`.
+ Mutually exclusive with `callable_options`.
+ callable_runner_args: An optional list of arguments to `callable_runner`
+ or for `callable_options`.
+ callable_options: An instance of `config_pb2.CallableOptions`, to be
+ used with `Session._make_callable_from_options()`. Mutually exclusive
+ with `callable_runner`.
Returns:
Simply forwards the output of the wrapped `Session.run()` call.
@@ -433,13 +442,17 @@ class BaseDebugWrapperSession(session.SessionInterface):
ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
is not `None` and either or both of `fetches` and `feed_dict` is `None`.
"""
- if not callable_runner:
+ if callable_runner and callable_options:
+ raise ValueError(
+ "callable_runner and callable_options are mutually exclusive, but "
+ "are both specified in this call to BaseDebugWrapperSession.run().")
+
+ if not (callable_runner or callable_options):
self.increment_run_call_count()
- else:
- if fetches or feed_dict:
- raise ValueError(
- "callable_runner and fetches/feed_dict are mutually exclusive, but "
- "are used simultaneously.")
+ elif callable_runner and (fetches or feed_dict):
+ raise ValueError(
+ "callable_runner and fetches/feed_dict are mutually exclusive, "
+ "but are used simultaneously.")
empty_fetches = not nest.flatten(fetches)
if empty_fetches:
@@ -449,6 +462,11 @@ class BaseDebugWrapperSession(session.SessionInterface):
if self._is_disabled_thread() or empty_fetches:
if callable_runner:
return callable_runner(*callable_runner_args)
+ elif callable_options:
+ # pylint:disable=protected-access
+ return self._sess._make_callable_from_options(
+ callable_options)(*callable_runner_args)
+ # pylint:enable=protected-access
else:
return self._sess.run(fetches,
feed_dict=feed_dict,
@@ -464,19 +482,30 @@ class BaseDebugWrapperSession(session.SessionInterface):
if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
# Decorate RunOption to fill in debugger tensor watch specifications.
- decorated_run_options = options or config_pb2.RunOptions()
+ decorated_run_options = None
+ if callable_options:
+ callable_options_id = id(callable_options)
+ if callable_options_id not in self._cached_callables_from_options:
+ # Make a copy of callable_options to avoid mutating it.
+ new_callable_options = config_pb2.CallableOptions()
+ new_callable_options.CopyFrom(callable_options)
+ decorated_run_options = new_callable_options.run_options
+ else:
+ decorated_run_options = options or config_pb2.RunOptions()
+
run_metadata = run_metadata or config_pb2.RunMetadata()
- self._decorate_run_options_for_debug(
- decorated_run_options,
- run_start_resp.debug_urls,
- debug_ops=run_start_resp.debug_ops,
- node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
- op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
- tensor_dtype_regex_whitelist=(
- run_start_resp.tensor_dtype_regex_whitelist),
- tolerate_debug_op_creation_failures=(
- run_start_resp.tolerate_debug_op_creation_failures))
+ if decorated_run_options:
+ self._decorate_run_options_for_debug(
+ decorated_run_options,
+ run_start_resp.debug_urls,
+ debug_ops=run_start_resp.debug_ops,
+ node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
+ op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
+ tensor_dtype_regex_whitelist=(
+ run_start_resp.tensor_dtype_regex_whitelist),
+ tolerate_debug_op_creation_failures=(
+ run_start_resp.tolerate_debug_op_creation_failures))
# Invoke the run() method of the wrapped Session. Catch any TensorFlow
# runtime errors.
@@ -486,6 +515,19 @@ class BaseDebugWrapperSession(session.SessionInterface):
retvals = callable_runner(*callable_runner_args,
options=decorated_run_options,
run_metadata=run_metadata)
+ elif callable_options:
+ # pylint:disable=protected-access
+ if callable_options_id in self._cached_callables_from_options:
+ callable_object = self._cached_callables_from_options[
+ callable_options_id]
+ else:
+ callable_object = self._sess._make_callable_from_options(
+ new_callable_options)
+ self._cached_callables_from_options[
+ callable_options_id] = callable_object
+ # pylint:enable=protected-access
+ retvals = callable_object(
+ *callable_runner_args, run_metadata=run_metadata)
else:
retvals = self._sess.run(fetches,
feed_dict=feed_dict,
@@ -590,7 +632,14 @@ class BaseDebugWrapperSession(session.SessionInterface):
run_metadata=kwargs.get("run_metadata", None),
callable_runner=runner,
callable_runner_args=runner_args)
+ return wrapped_runner
+ def _make_callable_from_options(self, callable_options):
+ def wrapped_runner(*feed_values, **kwargs):
+ return self.run(None,
+ run_metadata=kwargs.get("run_metadata", None),
+ callable_options=callable_options,
+ callable_runner_args=feed_values)
return wrapped_runner
@property
diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py
index 1f9c8fa5a9..85944fa611 100644
--- a/tensorflow/python/debug/wrappers/grpc_wrapper.py
+++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py
@@ -215,7 +215,8 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
options=None,
run_metadata=None,
callable_runner=None,
- callable_runner_args=None):
+ callable_runner_args=None,
+ callable_options=None):
if self._send_traceback_and_source_code:
self._sent_graph_version = publish_traceback(
self._grpc_debug_server_urls, self.graph, feed_dict, fetches,
@@ -226,4 +227,5 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
options=options,
run_metadata=run_metadata,
callable_runner=callable_runner,
- callable_runner_args=callable_runner_args)
+ callable_runner_args=callable_runner_args,
+ callable_options=callable_options)
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
index 4e551ab995..668ffb57f1 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
@@ -596,7 +596,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
# Register tab completion for the filter names.
curses_cli.register_tab_comp_context(["run", "r"],
list(self._tensor_filters.keys()))
- if self._feed_dict:
+ if self._feed_dict and hasattr(self._feed_dict, "keys"):
# Register tab completion for feed_dict keys.
feed_keys = [common.get_graph_element_name(key)
for key in self._feed_dict.keys()]
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index b06fa26a93..05c9eaa4d2 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -21,7 +21,10 @@ import os
import shutil
import tempfile
+import numpy as np
+
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import debugger_cli_common
@@ -149,7 +152,13 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
dtypes.float32, shape=([5, 5]), name="sparse_placeholder")
self.sparse_add = sparse_ops.sparse_add(self.sparse_ph, self.sparse_ph)
- self.sess = session.Session()
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ config_proto = config_pb2.ConfigProto(graph_options=graph_options)
+ self.sess = session.Session(config=config_proto)
# Initialize variable.
self.sess.run(variables.global_variables_initializer())
@@ -393,6 +402,113 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertAllClose(42.0, tensor_runner(41.0, 1.0))
self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
+ def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self):
+ variable_1 = variables.Variable(
+ 10.5, dtype=dtypes.float32, name="variable_1")
+ a = math_ops.add(variable_1, variable_1, "callable_a")
+ math_ops.add(a, a, "callable_b")
+ self.sess.run(variable_1.initializer)
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.fetch.append("callable_b")
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ for _ in range(2):
+ callable_output = sess_callable()
+ self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(2, len(debug_dumps))
+ for debug_dump in debug_dumps:
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(
+ ["callable_a", "callable_b", "variable_1", "variable_1/read"],
+ node_names)
+
+ def testDebuggingMakeCallableFromOptionsWithOneFeedWorks(self):
+ ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1")
+ a = math_ops.add(ph1, ph1, "callable_a")
+ math_ops.add(a, a, "callable_b")
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.feed.append("callable_ph1")
+ callable_options.fetch.append("callable_b")
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ ph1_value = np.array([10.5, -10.5], dtype=np.float32)
+
+ for _ in range(2):
+ callable_output = sess_callable(ph1_value)
+ self.assertAllClose(
+ np.array([42.0, -42.0], dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(2, len(debug_dumps))
+ for debug_dump in debug_dumps:
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(["callable_a", "callable_b"], node_names)
+
+ def testDebuggingMakeCallableFromOptionsWithTwoFeedsWorks(self):
+ ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1")
+ ph2 = array_ops.placeholder(dtypes.float32, name="callable_ph2")
+ a = math_ops.add(ph1, ph2, "callable_a")
+ math_ops.add(a, a, "callable_b")
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.feed.append("callable_ph1")
+ callable_options.feed.append("callable_ph2")
+ callable_options.fetch.append("callable_b")
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ ph1_value = np.array(5.0, dtype=np.float32)
+ ph2_value = np.array(16.0, dtype=np.float32)
+
+ for _ in range(2):
+ callable_output = sess_callable(ph1_value, ph2_value)
+ self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(2, len(debug_dumps))
+ for debug_dump in debug_dumps:
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(["callable_a", "callable_b"], node_names)
+
+ def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self):
+ variable_1 = variables.Variable(
+ 10.5, dtype=dtypes.float32, name="variable_1")
+ a = math_ops.add(variable_1, variable_1, "callable_a")
+ math_ops.add(a, a, "callable_b")
+ self.sess.run(variable_1.initializer)
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.fetch.append("callable_b")
+ callable_options.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
+
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ run_metadata = config_pb2.RunMetadata()
+ # Call the callable with a custom run_metadata.
+ callable_output = sess_callable(run_metadata=run_metadata)
+ # Verify that step_stats is populated in the custom run_metadata.
+ self.assertTrue(run_metadata.step_stats)
+ self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(1, len(debug_dumps))
+ debug_dump = debug_dumps[0]
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(
+ ["callable_a", "callable_b", "variable_1", "variable_1/read"],
+ node_names)
+
def testRuntimeErrorShouldBeCaught(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)