aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r--tensorflow/python/tools/saved_model_cli.py19
1 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 3dbccd1409..2fcb0fa029 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -267,7 +267,8 @@ def scan_meta_graph_def(meta_graph_def):
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
input_tensor_key_feed_dict, outdir,
- overwrite_flag, worker=None, tf_debug=False):
+ overwrite_flag, worker=None, init_tpu=False,
+ tf_debug=False):
"""Runs SavedModel and fetch all outputs.
Runs the input dictionary through the MetaGraphDef within a SavedModel
@@ -287,6 +288,8 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
the same name exists.
worker: If provided, the session will be run on the worker. Valid worker
specification is a bns or gRPC path.
+ init_tpu: If true, the TPU system will be initialized after the session
+ is created.
tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
intermediate Tensor values and runtime GraphDefs while running the
SavedModel.
@@ -328,6 +331,12 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
]
with session.Session(worker, graph=ops_lib.Graph()) as sess:
+ if init_tpu:
+ print('Initializing TPU System ...')
+ # This is needed for freshly started worker, or if the job
+ # restarts after a preemption.
+ sess.run(tf.contrib.tpu.initialize_system())
+
loader.load(sess, tag_set.split(','), saved_model_dir)
if tf_debug:
@@ -632,7 +641,7 @@ def run(args):
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
tensor_key_feed_dict, args.outdir,
args.overwrite, worker=args.worker,
- tf_debug=args.tf_debug)
+ init_tpu=args.init_tpu, tf_debug=args.tf_debug)
def scan(args):
@@ -775,6 +784,12 @@ def create_parser():
default=None,
help='if specified, a Session will be run on the worker. '
'Valid worker specification is a bns or gRPC path.')
+ parser_run.add_argument(
+ '--init_tpu',
+ action='store_true',
+ default=None,
+ help='if specified, tpu.initialize_system will be called on the Session. '
+ 'This option should be only used if the worker is a TPU job.')
parser_run.set_defaults(func=run)
# scan command