diff options
author | 2017-04-21 13:39:05 -0800 | |
---|---|---|
committer | 2017-04-21 14:52:23 -0700 | |
commit | ed6b1578090c8914042f9d6b2594d13d21bde213 (patch) | |
tree | f4528f21174bcbd65e3de774eab3aacc368c8101 /tensorflow | |
parent | 3bf7bc9ebb3168791d8f217f46f6413888ccea92 (diff) |
Add --tf_debug option flag to save_model_cli
Change: 153873095
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/tools/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/tools/saved_model_cli.py | 23 | ||||
-rw-r--r-- | tensorflow/python/tools/saved_model_cli_test.py | 50 |
3 files changed, 63 insertions, 11 deletions
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index eaf0a5c837..48b84f9a96 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -205,6 +205,7 @@ py_binary( deps = [ "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python", + "//tensorflow/python/debug:local_cli_wrapper", ], ) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index d14748b492..17ef8ef9c2 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -66,6 +66,12 @@ tensors to files: --signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy --outdir /tmp/out +To observe the intermediate Tensor values in the runtime graph, use the +--tf_debug flag, e.g.: + $saved_model_cli run --dir /tmp/saved_model --tag_set serve + --signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy + --outdir /tmp/out --tf_debug + To build this tool from source, run: $bazel build tensorflow/python/tools:saved_model_cli @@ -87,6 +93,7 @@ from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session +from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.framework import ops as ops_lib from tensorflow.python.platform import app from tensorflow.python.saved_model import loader @@ -282,7 +289,7 @@ def get_signature_def_map(saved_model_dir, tag_set): def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, input_tensor_key_feed_dict, outdir, - overwrite_flag): + overwrite_flag, tf_debug=False): """Runs SavedModel and fetch all outputs. Runs the input dictionary through the MetaGraphDef within a SavedModel @@ -300,6 +307,9 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, it will be created. overwrite_flag: A boolean flag to allow overwrite output file if file with the same name exists. + tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the + intermediate Tensor values and runtime GraphDefs while running the + SavedModel. Raises: RuntimeError: An error when output file already exists and overwrite is not @@ -329,6 +339,9 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, with session.Session(graph=ops_lib.Graph()) as sess: loader.load(sess, tag_set.split(','), saved_model_dir) + if tf_debug: + sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess) + outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict) for i, output in enumerate(outputs): @@ -520,7 +533,7 @@ def run(args): tensor_key_feed_dict = load_inputs_from_input_arg_string(args.inputs) run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def, tensor_key_feed_dict, args.outdir, - args.overwrite) + args.overwrite, tf_debug=args.tf_debug) def create_parser(): @@ -620,6 +633,12 @@ def create_parser(): '--overwrite', action='store_true', help='if set, output file will be overwritten if it already exists.') + parser_run.add_argument( + '--tf_debug', + action='store_true', + help='if set, will use TensorFlow Debugger (tfdbg) to watch the ' + 'intermediate Tensors and runtime GraphDefs while running the ' + 'SavedModel.') parser_run.set_defaults(func=run) return parser diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py index b9d28794cc..c481dba2e9 100644 --- a/tensorflow/python/tools/saved_model_cli_test.py +++ b/tensorflow/python/tools/saved_model_cli_test.py @@ -28,6 +28,7 @@ import sys import numpy as np from six import StringIO +from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.platform import test from tensorflow.python.tools import saved_model_cli @@ -299,9 +300,9 @@ Method name is: tensorflow/serving/predict""" test.get_temp_dir() ]) saved_model_cli.run(args) - y = np.load(output_file) - y_exp = np.array([[3.5], [4.0]]) - self.assertTrue(np.allclose(y, y_exp)) + y_actual = np.load(output_file) + y_expected = np.array([[3.5], [4.0]]) + self.assertAllClose(y_expected, y_actual) def testRunCommandNewOutdir(self): self.parser = saved_model_cli.create_parser() @@ -320,9 +321,9 @@ Method name is: tensorflow/serving/predict""" output_dir ]) saved_model_cli.run(args) - y = np.load(os.path.join(output_dir, 'y.npy')) - y_exp = np.array([[2.5], [3.0]]) - self.assertTrue(np.allclose(y, y_exp)) + y_actual = np.load(os.path.join(output_dir, 'y.npy')) + y_expected = np.array([[2.5], [3.0]]) + self.assertAllClose(y_expected, y_actual) def testRunCommandOutOverwrite(self): self.parser = saved_model_cli.create_parser() @@ -340,9 +341,9 @@ Method name is: tensorflow/serving/predict""" test.get_temp_dir(), '--overwrite' ]) saved_model_cli.run(args) - y = np.load(output_file) - y_exp = np.array([[2.5], [3.0]]) - self.assertTrue(np.allclose(y, y_exp)) + y_actual = np.load(output_file) + y_expected = np.array([[2.5], [3.0]]) + self.assertAllClose(y_expected, y_actual) def testRunCommandOutputFileExistError(self): self.parser = saved_model_cli.create_parser() @@ -362,6 +363,37 @@ Method name is: tensorflow/serving/predict""" with self.assertRaises(RuntimeError): saved_model_cli.run(args) + def testRunCommandWithDebuggerEnabled(self): + self.parser = saved_model_cli.create_parser() + base_path = test.test_src_dir_path(SAVED_MODEL_PATH) + x = np.array([[1], [2]]) + x_notused = np.zeros((6, 3)) + input_path = os.path.join(test.get_temp_dir(), + 'testRunCommandNewOutdir_inputs.npz') + output_dir = os.path.join(test.get_temp_dir(), 'new_dir') + if os.path.isdir(output_dir): + shutil.rmtree(output_dir) + np.savez(input_path, x0=x, x1=x_notused) + args = self.parser.parse_args([ + 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def', + 'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir', + output_dir, '--tf_debug' + ]) + + def fake_wrapper_session(sess): + return sess + + with test.mock.patch.object(local_cli_wrapper, + 'LocalCLIDebugWrapperSession', + side_effect=fake_wrapper_session, + autospec=True) as fake: + saved_model_cli.run(args) + fake.assert_called_with(test.mock.ANY) + + y_actual = np.load(os.path.join(output_dir, 'y.npy')) + y_expected = np.array([[2.5], [3.0]]) + self.assertAllClose(y_expected, y_actual) + if __name__ == '__main__': test.main() |