From 3648cb0198690d551ea5c8eefcf706c8fa67f4f0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 1 Oct 2018 13:07:12 -0700 Subject: Add option to initialize the TPU system. PiperOrigin-RevId: 215266241 --- tensorflow/python/tools/saved_model_cli.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) (limited to 'tensorflow/python/tools') diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 3dbccd1409..2fcb0fa029 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -267,7 +267,8 @@ def scan_meta_graph_def(meta_graph_def): def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, input_tensor_key_feed_dict, outdir, - overwrite_flag, worker=None, tf_debug=False): + overwrite_flag, worker=None, init_tpu=False, + tf_debug=False): """Runs SavedModel and fetch all outputs. Runs the input dictionary through the MetaGraphDef within a SavedModel @@ -287,6 +288,8 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, the same name exists. worker: If provided, the session will be run on the worker. Valid worker specification is a bns or gRPC path. + init_tpu: If true, the TPU system will be initialized after the session + is created. tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the intermediate Tensor values and runtime GraphDefs while running the SavedModel. @@ -328,6 +331,12 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, ] with session.Session(worker, graph=ops_lib.Graph()) as sess: + if init_tpu: + print('Initializing TPU System ...') + # This is needed for freshly started worker, or if the job + # restarts after a preemption. + sess.run(tf.contrib.tpu.initialize_system()) + loader.load(sess, tag_set.split(','), saved_model_dir) if tf_debug: @@ -632,7 +641,7 @@ def run(args): run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def, tensor_key_feed_dict, args.outdir, args.overwrite, worker=args.worker, - tf_debug=args.tf_debug) + init_tpu=args.init_tpu, tf_debug=args.tf_debug) def scan(args): @@ -775,6 +784,12 @@ def create_parser(): default=None, help='if specified, a Session will be run on the worker. ' 'Valid worker specification is a bns or gRPC path.') + parser_run.add_argument( + '--init_tpu', + action='store_true', + default=None, + help='if specified, tpu.initialize_system will be called on the Session. ' + 'This option should be only used if the worker is a TPU job.') parser_run.set_defaults(func=run) # scan command -- cgit v1.2.3