diff options
author | 2018-08-30 17:42:21 -0700 | |
---|---|---|
committer | 2018-08-30 17:51:07 -0700 | |
commit | 6fcb16a7235362ebd36c7f8824fb9b2731234d34 (patch) | |
tree | e6a4e2c9216bdc0dd6661077f5de76b9d2f3ed4a /tensorflow/python/tools | |
parent | bab7d5e3b57478309ca7829c2192ddc85fa4f5e1 (diff) |
Add an option to connect to a worker.
PiperOrigin-RevId: 211013090
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r-- | tensorflow/python/tools/saved_model_cli.py | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 9b232865dd..6716c79f87 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -40,8 +40,8 @@ from tensorflow.python.client import session from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.framework import meta_graph as meta_graph_lib from tensorflow.python.framework import ops as ops_lib -from tensorflow.python.platform import app # pylint: disable=unused-import from tensorflow.python.lib.io import file_io +from tensorflow.python.platform import app # pylint: disable=unused-import from tensorflow.python.saved_model import loader from tensorflow.python.tools import saved_model_utils @@ -140,7 +140,7 @@ def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, indent=0): outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def( meta_graph_def, signature_def_key) - indent_str = " " * indent + indent_str = ' ' * indent def in_print(s): print(indent_str + s) @@ -166,7 +166,7 @@ def _print_tensor_info(tensor_info, indent=0): tensor_info: TensorInfo object to be printed. indent: How far (in increments of 2 spaces) to indent each line output """ - indent_str = " " * indent + indent_str = ' ' * indent def in_print(s): print(indent_str + s) @@ -270,7 +270,7 @@ def scan_meta_graph_def(meta_graph_def): def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, input_tensor_key_feed_dict, outdir, - overwrite_flag, tf_debug=False): + overwrite_flag, worker=None, tf_debug=False): """Runs SavedModel and fetch all outputs. Runs the input dictionary through the MetaGraphDef within a SavedModel @@ -288,6 +288,8 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, it will be created. overwrite_flag: A boolean flag to allow overwrite output file if file with the same name exists. + worker: If provided, the session will be run on the worker. Valid worker + specification is a bns or gRPC path. tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the intermediate Tensor values and runtime GraphDefs while running the SavedModel. @@ -328,7 +330,7 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, for tensor_key in output_tensor_keys_sorted ] - with session.Session(graph=ops_lib.Graph()) as sess: + with session.Session(worker, graph=ops_lib.Graph()) as sess: loader.load(sess, tag_set.split(','), saved_model_dir) if tf_debug: @@ -632,7 +634,8 @@ def run(args): args.inputs, args.input_exprs, args.input_examples) run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def, tensor_key_feed_dict, args.outdir, - args.overwrite, tf_debug=args.tf_debug) + args.overwrite, worker=args.worker, + tf_debug=args.tf_debug) def scan(args): @@ -769,6 +772,12 @@ def create_parser(): help='if set, will use TensorFlow Debugger (tfdbg) to watch the ' 'intermediate Tensors and runtime GraphDefs while running the ' 'SavedModel.') + parser_run.add_argument( + '--worker', + type=str, + default=None, + help='if specified, a Session will be run on the worker. ' + 'Valid worker specification is a bns or gRPC path.') parser_run.set_defaults(func=run) # scan command |