aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-30 17:42:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 17:51:07 -0700
commit6fcb16a7235362ebd36c7f8824fb9b2731234d34 (patch)
treee6a4e2c9216bdc0dd6661077f5de76b9d2f3ed4a /tensorflow/python/tools
parentbab7d5e3b57478309ca7829c2192ddc85fa4f5e1 (diff)
Add an option to connect to a worker.
PiperOrigin-RevId: 211013090
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r--tensorflow/python/tools/saved_model_cli.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 9b232865dd..6716c79f87 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -40,8 +40,8 @@ from tensorflow.python.client import session
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.framework import meta_graph as meta_graph_lib
from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.platform import app # pylint: disable=unused-import
from tensorflow.python.lib.io import file_io
+from tensorflow.python.platform import app # pylint: disable=unused-import
from tensorflow.python.saved_model import loader
from tensorflow.python.tools import saved_model_utils
@@ -140,7 +140,7 @@ def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, indent=0):
outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
meta_graph_def, signature_def_key)
- indent_str = " " * indent
+ indent_str = ' ' * indent
def in_print(s):
print(indent_str + s)
@@ -166,7 +166,7 @@ def _print_tensor_info(tensor_info, indent=0):
tensor_info: TensorInfo object to be printed.
indent: How far (in increments of 2 spaces) to indent each line output
"""
- indent_str = " " * indent
+ indent_str = ' ' * indent
def in_print(s):
print(indent_str + s)
@@ -270,7 +270,7 @@ def scan_meta_graph_def(meta_graph_def):
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
input_tensor_key_feed_dict, outdir,
- overwrite_flag, tf_debug=False):
+ overwrite_flag, worker=None, tf_debug=False):
"""Runs SavedModel and fetch all outputs.
Runs the input dictionary through the MetaGraphDef within a SavedModel
@@ -288,6 +288,8 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
it will be created.
overwrite_flag: A boolean flag to allow overwrite output file if file with
the same name exists.
+ worker: If provided, the session will be run on the worker. Valid worker
+ specification is a bns or gRPC path.
tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
intermediate Tensor values and runtime GraphDefs while running the
SavedModel.
@@ -328,7 +330,7 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
for tensor_key in output_tensor_keys_sorted
]
- with session.Session(graph=ops_lib.Graph()) as sess:
+ with session.Session(worker, graph=ops_lib.Graph()) as sess:
loader.load(sess, tag_set.split(','), saved_model_dir)
if tf_debug:
@@ -632,7 +634,8 @@ def run(args):
args.inputs, args.input_exprs, args.input_examples)
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
tensor_key_feed_dict, args.outdir,
- args.overwrite, tf_debug=args.tf_debug)
+ args.overwrite, worker=args.worker,
+ tf_debug=args.tf_debug)
def scan(args):
@@ -769,6 +772,12 @@ def create_parser():
help='if set, will use TensorFlow Debugger (tfdbg) to watch the '
'intermediate Tensors and runtime GraphDefs while running the '
'SavedModel.')
+ parser_run.add_argument(
+ '--worker',
+ type=str,
+ default=None,
+ help='if specified, a Session will be run on the worker. '
+ 'Valid worker specification is a bns or gRPC path.')
parser_run.set_defaults(func=run)
# scan command