aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-01 13:07:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 13:11:43 -0700
commit3648cb0198690d551ea5c8eefcf706c8fa67f4f0 (patch)
tree8939c7c6f97bc1b2221090a8157a84096c548b38 /tensorflow/python/tools
parent5c8c48df7fd4ccbe4a9dec035fdec6b02a5d6016 (diff)
Add option to initialize the TPU system.
PiperOrigin-RevId: 215266241
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r--tensorflow/python/tools/saved_model_cli.py19
1 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 3dbccd1409..2fcb0fa029 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -267,7 +267,8 @@ def scan_meta_graph_def(meta_graph_def):
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
input_tensor_key_feed_dict, outdir,
- overwrite_flag, worker=None, tf_debug=False):
+ overwrite_flag, worker=None, init_tpu=False,
+ tf_debug=False):
"""Runs SavedModel and fetch all outputs.
Runs the input dictionary through the MetaGraphDef within a SavedModel
@@ -287,6 +288,8 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
the same name exists.
worker: If provided, the session will be run on the worker. Valid worker
specification is a bns or gRPC path.
+ init_tpu: If true, the TPU system will be initialized after the session
+ is created.
tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
intermediate Tensor values and runtime GraphDefs while running the
SavedModel.
@@ -328,6 +331,12 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
]
with session.Session(worker, graph=ops_lib.Graph()) as sess:
+ if init_tpu:
+ print('Initializing TPU System ...')
+ # This is needed for freshly started worker, or if the job
+ # restarts after a preemption.
+ sess.run(tf.contrib.tpu.initialize_system())
+
loader.load(sess, tag_set.split(','), saved_model_dir)
if tf_debug:
@@ -632,7 +641,7 @@ def run(args):
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
tensor_key_feed_dict, args.outdir,
args.overwrite, worker=args.worker,
- tf_debug=args.tf_debug)
+ init_tpu=args.init_tpu, tf_debug=args.tf_debug)
def scan(args):
@@ -775,6 +784,12 @@ def create_parser():
default=None,
help='if specified, a Session will be run on the worker. '
'Valid worker specification is a bns or gRPC path.')
+ parser_run.add_argument(
+ '--init_tpu',
+ action='store_true',
+ default=None,
+ help='if specified, tpu.initialize_system will be called on the Session. '
+ 'This option should be only used if the worker is a TPU job.')
parser_run.set_defaults(func=run)
# scan command