aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-02-07 14:36:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-07 14:39:49 -0800
commitd90054e7c0f41f4bab81df0548577a73b939a87a (patch)
treea15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
parent8461760f9f6cde8ed97507484d2a879140141032 (diff)
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py')
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py26
1 files changed, 19 insertions, 7 deletions
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
index 7970c20a26..78d237e6a2 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
@@ -17,6 +17,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl import flags
import os
import subprocess
@@ -24,13 +25,21 @@ import sys
import tensorflow as tf
-tf.flags.DEFINE_string('service_addr', '',
- 'Address of TPU profiler service e.g. localhost:8466')
-tf.flags.DEFINE_string('logdir', '',
- 'Path of TensorBoard log directory e.g. /tmp/tb_log')
-tf.flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.')
+flags.DEFINE_string(
+ 'service_addr', None, 'Address of TPU profiler service e.g. '
+ 'localhost:8466')
+flags.DEFINE_string(
+ 'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, '
+ 'gs://tb_bucket')
+flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.')
+flags.DEFINE_integer(
+ 'num_tracing_attempts', 3, 'Automatically retry N times when no trace '
+ 'event is collected.')
+flags.DEFINE_boolean(
+ 'include_dataset_ops', True, 'Set to false to profile longer TPU '
+ 'device traces.')
-FLAGS = tf.flags.FLAGS
+FLAGS = flags.FLAGS
EXECUTABLE = 'data/capture_tpu_profile'
@@ -42,10 +51,13 @@ def main(unused_argv=None):
if not FLAGS.service_addr or not FLAGS.logdir:
sys.exit('service_addr and logdir must be provided.')
executable_path = os.path.join(os.path.dirname(__file__), EXECUTABLE)
+ logdir = os.path.expandvars(os.path.expanduser(FLAGS.logdir))
cmd = [executable_path]
- cmd.append('--logdir='+FLAGS.logdir)
+ cmd.append('--logdir='+logdir)
cmd.append('--service_addr='+FLAGS.service_addr)
cmd.append('--duration_ms='+str(FLAGS.duration_ms))
+ cmd.append('--num_tracing_attempts='+str(FLAGS.num_tracing_attempts))
+ cmd.append('--include_dataset_ops='+str(FLAGS.include_dataset_ops).lower())
subprocess.call(cmd)