aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-04-21 13:39:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-21 14:52:23 -0700
commited6b1578090c8914042f9d6b2594d13d21bde213 (patch)
treef4528f21174bcbd65e3de774eab3aacc368c8101 /tensorflow
parent3bf7bc9ebb3168791d8f217f46f6413888ccea92 (diff)
Add --tf_debug option flag to save_model_cli
Change: 153873095
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/tools/BUILD1
-rw-r--r--tensorflow/python/tools/saved_model_cli.py23
-rw-r--r--tensorflow/python/tools/saved_model_cli_test.py50
3 files changed, 63 insertions, 11 deletions
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index eaf0a5c837..48b84f9a96 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -205,6 +205,7 @@ py_binary(
deps = [
"//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python",
+ "//tensorflow/python/debug:local_cli_wrapper",
],
)
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index d14748b492..17ef8ef9c2 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -66,6 +66,12 @@ tensors to files:
--signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy
--outdir /tmp/out
+To observe the intermediate Tensor values in the runtime graph, use the
+--tf_debug flag, e.g.:
+ $saved_model_cli run --dir /tmp/saved_model --tag_set serve
+ --signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy
+ --outdir /tmp/out --tf_debug
+
To build this tool from source, run:
$bazel build tensorflow/python/tools:saved_model_cli
@@ -87,6 +93,7 @@ from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
+from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.platform import app
from tensorflow.python.saved_model import loader
@@ -282,7 +289,7 @@ def get_signature_def_map(saved_model_dir, tag_set):
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
input_tensor_key_feed_dict, outdir,
- overwrite_flag):
+ overwrite_flag, tf_debug=False):
"""Runs SavedModel and fetch all outputs.
Runs the input dictionary through the MetaGraphDef within a SavedModel
@@ -300,6 +307,9 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
it will be created.
overwrite_flag: A boolean flag to allow overwrite output file if file with
the same name exists.
+ tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
+ intermediate Tensor values and runtime GraphDefs while running the
+ SavedModel.
Raises:
RuntimeError: An error when output file already exists and overwrite is not
@@ -329,6 +339,9 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
with session.Session(graph=ops_lib.Graph()) as sess:
loader.load(sess, tag_set.split(','), saved_model_dir)
+ if tf_debug:
+ sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess)
+
outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict)
for i, output in enumerate(outputs):
@@ -520,7 +533,7 @@ def run(args):
tensor_key_feed_dict = load_inputs_from_input_arg_string(args.inputs)
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
tensor_key_feed_dict, args.outdir,
- args.overwrite)
+ args.overwrite, tf_debug=args.tf_debug)
def create_parser():
@@ -620,6 +633,12 @@ def create_parser():
'--overwrite',
action='store_true',
help='if set, output file will be overwritten if it already exists.')
+ parser_run.add_argument(
+ '--tf_debug',
+ action='store_true',
+ help='if set, will use TensorFlow Debugger (tfdbg) to watch the '
+ 'intermediate Tensors and runtime GraphDefs while running the '
+ 'SavedModel.')
parser_run.set_defaults(func=run)
return parser
diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py
index b9d28794cc..c481dba2e9 100644
--- a/tensorflow/python/tools/saved_model_cli_test.py
+++ b/tensorflow/python/tools/saved_model_cli_test.py
@@ -28,6 +28,7 @@ import sys
import numpy as np
from six import StringIO
+from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.platform import test
from tensorflow.python.tools import saved_model_cli
@@ -299,9 +300,9 @@ Method name is: tensorflow/serving/predict"""
test.get_temp_dir()
])
saved_model_cli.run(args)
- y = np.load(output_file)
- y_exp = np.array([[3.5], [4.0]])
- self.assertTrue(np.allclose(y, y_exp))
+ y_actual = np.load(output_file)
+ y_expected = np.array([[3.5], [4.0]])
+ self.assertAllClose(y_expected, y_actual)
def testRunCommandNewOutdir(self):
self.parser = saved_model_cli.create_parser()
@@ -320,9 +321,9 @@ Method name is: tensorflow/serving/predict"""
output_dir
])
saved_model_cli.run(args)
- y = np.load(os.path.join(output_dir, 'y.npy'))
- y_exp = np.array([[2.5], [3.0]])
- self.assertTrue(np.allclose(y, y_exp))
+ y_actual = np.load(os.path.join(output_dir, 'y.npy'))
+ y_expected = np.array([[2.5], [3.0]])
+ self.assertAllClose(y_expected, y_actual)
def testRunCommandOutOverwrite(self):
self.parser = saved_model_cli.create_parser()
@@ -340,9 +341,9 @@ Method name is: tensorflow/serving/predict"""
test.get_temp_dir(), '--overwrite'
])
saved_model_cli.run(args)
- y = np.load(output_file)
- y_exp = np.array([[2.5], [3.0]])
- self.assertTrue(np.allclose(y, y_exp))
+ y_actual = np.load(output_file)
+ y_expected = np.array([[2.5], [3.0]])
+ self.assertAllClose(y_expected, y_actual)
def testRunCommandOutputFileExistError(self):
self.parser = saved_model_cli.create_parser()
@@ -362,6 +363,37 @@ Method name is: tensorflow/serving/predict"""
with self.assertRaises(RuntimeError):
saved_model_cli.run(args)
+ def testRunCommandWithDebuggerEnabled(self):
+ self.parser = saved_model_cli.create_parser()
+ base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+ x = np.array([[1], [2]])
+ x_notused = np.zeros((6, 3))
+ input_path = os.path.join(test.get_temp_dir(),
+ 'testRunCommandNewOutdir_inputs.npz')
+ output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
+ if os.path.isdir(output_dir):
+ shutil.rmtree(output_dir)
+ np.savez(input_path, x0=x, x1=x_notused)
+ args = self.parser.parse_args([
+ 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
+ 'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir',
+ output_dir, '--tf_debug'
+ ])
+
+ def fake_wrapper_session(sess):
+ return sess
+
+ with test.mock.patch.object(local_cli_wrapper,
+ 'LocalCLIDebugWrapperSession',
+ side_effect=fake_wrapper_session,
+ autospec=True) as fake:
+ saved_model_cli.run(args)
+ fake.assert_called_with(test.mock.ANY)
+
+ y_actual = np.load(os.path.join(output_dir, 'y.npy'))
+ y_expected = np.array([[2.5], [3.0]])
+ self.assertAllClose(y_expected, y_actual)
+
if __name__ == '__main__':
test.main()