aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/dist_test
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-01-11 19:00:10 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-11 19:05:47 -0800
commit2bdf0b4e199ad18db41412a4164affe5d6d7b06f (patch)
tree27d18fa42f1ccfdccff9d1a784c93efb18424aa8 /tensorflow/tools/dist_test
parente11fae78b6812624f71f2551dd9f5c568ed014db (diff)
Convert more flags use to argparse in dist_test
Change: 144278086
Diffstat (limited to 'tensorflow/tools/dist_test')
-rw-r--r--tensorflow/tools/dist_test/python/census_widendeep.py84
-rwxr-xr-xtensorflow/tools/dist_test/server/grpc_tensorflow_server.py63
2 files changed, 105 insertions, 42 deletions
diff --git a/tensorflow/tools/dist_test/python/census_widendeep.py b/tensorflow/tools/dist_test/python/census_widendeep.py
index 62ff7f8f2f..db56a687f6 100644
--- a/tensorflow/tools/dist_test/python/census_widendeep.py
+++ b/tensorflow/tools/dist_test/python/census_widendeep.py
@@ -20,8 +20,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import json
import os
+import sys
from six.moves import urllib
import tensorflow as tf
@@ -30,28 +32,6 @@ from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.estimators import run_config
-# Define command-line flags
-flags = tf.app.flags
-flags.DEFINE_string("data_dir", "/tmp/census-data",
- "Directory for storing the cesnsus data")
-flags.DEFINE_string("model_dir", "/tmp/census_wide_and_deep_model",
- "Directory for storing the model")
-flags.DEFINE_string("output_dir", "", "Base output directory.")
-flags.DEFINE_string("schedule", "local_run",
- "Schedule to run for this experiment.")
-flags.DEFINE_string("master_grpc_url", "",
- "URL to master GRPC tensorflow server, e.g.,"
- "grpc://127.0.0.1:2222")
-flags.DEFINE_integer("num_parameter_servers", 0,
- "Number of parameter servers")
-flags.DEFINE_integer("worker_index", 0,
- "Worker index (>=0)")
-flags.DEFINE_integer("train_steps", 1000, "Number of training steps")
-flags.DEFINE_integer("eval_steps", 1, "Number of evaluation steps")
-
-FLAGS = flags.FLAGS
-
-
# Constants: Data download URLs
TRAIN_DATA_URL = "http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data"
TEST_DATA_URL = "http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test"
@@ -277,4 +257,62 @@ def main(unused_argv):
if __name__ == "__main__":
- tf.app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--data_dir",
+ type=str,
+ default="/tmp/census-data",
+ help="Directory for storing the cesnsus data"
+ )
+ parser.add_argument(
+ "--model_dir",
+ type=str,
+ default="/tmp/census_wide_and_deep_model",
+ help="Directory for storing the model"
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="",
+ help="Base output directory."
+ )
+ parser.add_argument(
+ "--schedule",
+ type=str,
+ default="local_run",
+ help="Schedule to run for this experiment."
+ )
+ parser.add_argument(
+ "--master_grpc_url",
+ type=str,
+ default="",
+ help="URL to master GRPC tensorflow server, e.g.,grpc://127.0.0.1:2222"
+ )
+ parser.add_argument(
+ "--num_parameter_servers",
+ type=int,
+ default=0,
+ help="Number of parameter servers"
+ )
+ parser.add_argument(
+ "--worker_index",
+ type=int,
+ default=0,
+ help="Worker index (>=0)"
+ )
+ parser.add_argument(
+ "--train_steps",
+ type=int,
+ default=1000,
+ help="Number of training steps"
+ )
+ parser.add_argument(
+ "--eval_steps",
+ type=int,
+ default=1,
+ help="Number of evaluation steps"
+ )
+ global FLAGS # pylint:disable=global-at-module-level
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
index 5e36eaf748..2d774577b6 100755
--- a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
+++ b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
@@ -33,32 +33,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
+import sys
+
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
from tensorflow.python.training import server_lib
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string("cluster_spec", "", """Cluster spec: SPEC.
- SPEC is <JOB>(,<JOB>)*,"
- JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*,"
- NAME is a valid job name ([a-z][0-9a-z]*),"
- HOST is a hostname or IP address,"
- PORT is a port number."
-E.g., local|localhost:2222;localhost:2223, ps|ps0:2222;ps1:2222""")
-flags.DEFINE_string("job_name", "", "Job name: e.g., local")
-flags.DEFINE_integer("task_id", 0, "Task index, e.g., 0")
-flags.DEFINE_boolean("verbose", False, "Verbose mode")
-
-def parse_cluster_spec(cluster_spec, cluster):
+def parse_cluster_spec(cluster_spec, cluster, verbose=False):
"""Parse content of cluster_spec string and inject info into cluster protobuf.
Args:
cluster_spec: cluster specification string, e.g.,
"local|localhost:2222;localhost:2223"
cluster: cluster protobuf.
+ verbose: If verbose logging is requested.
Raises:
ValueError: if the cluster_spec string is invalid.
@@ -82,7 +72,7 @@ def parse_cluster_spec(cluster_spec, cluster):
job_def.name = job_name
- if FLAGS.verbose:
+ if verbose:
print("Added job named \"%s\"" % job_name)
job_tasks = job_string.split("|")[1].split(";")
@@ -92,7 +82,7 @@ def parse_cluster_spec(cluster_spec, cluster):
job_def.tasks[i] = job_tasks[i]
- if FLAGS.verbose:
+ if verbose:
print(" Added task \"%s\" to job \"%s\"" % (job_tasks[i], job_name))
@@ -101,7 +91,7 @@ def main(unused_args):
server_def = tensorflow_server_pb2.ServerDef(protocol="grpc")
# Cluster info
- parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster)
+ parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster, FLAGS.verbose)
# Job name
if not FLAGS.job_name:
@@ -121,4 +111,39 @@ def main(unused_args):
if __name__ == "__main__":
- app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--cluster_spec",
+ type=str,
+ default="",
+ help="""\
+ Cluster spec: SPEC. SPEC is <JOB>(,<JOB>)*," JOB is
+ <NAME>|<HOST:PORT>(;<HOST:PORT>)*," NAME is a valid job name
+ ([a-z][0-9a-z]*)," HOST is a hostname or IP address," PORT is a
+ port number." E.g., local|localhost:2222;localhost:2223,
+ ps|ps0:2222;ps1:2222\
+ """
+ )
+ parser.add_argument(
+ "--job_name",
+ type=str,
+ default="",
+ help="Job name: e.g., local"
+ )
+ parser.add_argument(
+ "--task_id",
+ type=int,
+ default=0,
+ help="Task index, e.g., 0"
+ )
+ parser.add_argument(
+ "--verbose",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="Verbose mode"
+ )
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)