diff options
author | Vijay Vasudevan <vrv@google.com> | 2017-01-11 19:00:10 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-11 19:05:47 -0800 |
commit | 2bdf0b4e199ad18db41412a4164affe5d6d7b06f (patch) | |
tree | 27d18fa42f1ccfdccff9d1a784c93efb18424aa8 /tensorflow/tools/dist_test | |
parent | e11fae78b6812624f71f2551dd9f5c568ed014db (diff) |
Convert more flags use to argparse in dist_test
Change: 144278086
Diffstat (limited to 'tensorflow/tools/dist_test')
-rw-r--r-- | tensorflow/tools/dist_test/python/census_widendeep.py | 84 | ||||
-rwxr-xr-x | tensorflow/tools/dist_test/server/grpc_tensorflow_server.py | 63 |
2 files changed, 105 insertions, 42 deletions
diff --git a/tensorflow/tools/dist_test/python/census_widendeep.py b/tensorflow/tools/dist_test/python/census_widendeep.py index 62ff7f8f2f..db56a687f6 100644 --- a/tensorflow/tools/dist_test/python/census_widendeep.py +++ b/tensorflow/tools/dist_test/python/census_widendeep.py @@ -20,8 +20,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import json import os +import sys from six.moves import urllib import tensorflow as tf @@ -30,28 +32,6 @@ from tensorflow.contrib.learn.python.learn import learn_runner from tensorflow.contrib.learn.python.learn.estimators import run_config -# Define command-line flags -flags = tf.app.flags -flags.DEFINE_string("data_dir", "/tmp/census-data", - "Directory for storing the cesnsus data") -flags.DEFINE_string("model_dir", "/tmp/census_wide_and_deep_model", - "Directory for storing the model") -flags.DEFINE_string("output_dir", "", "Base output directory.") -flags.DEFINE_string("schedule", "local_run", - "Schedule to run for this experiment.") -flags.DEFINE_string("master_grpc_url", "", - "URL to master GRPC tensorflow server, e.g.," - "grpc://127.0.0.1:2222") -flags.DEFINE_integer("num_parameter_servers", 0, - "Number of parameter servers") -flags.DEFINE_integer("worker_index", 0, - "Worker index (>=0)") -flags.DEFINE_integer("train_steps", 1000, "Number of training steps") -flags.DEFINE_integer("eval_steps", 1, "Number of evaluation steps") - -FLAGS = flags.FLAGS - - # Constants: Data download URLs TRAIN_DATA_URL = "http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data" TEST_DATA_URL = "http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test" @@ -277,4 +257,62 @@ def main(unused_argv): if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/census-data", + help="Directory for storing the cesnsus data" + ) + parser.add_argument( + "--model_dir", + type=str, + default="/tmp/census_wide_and_deep_model", + help="Directory for storing the model" + ) + parser.add_argument( + "--output_dir", + type=str, + default="", + help="Base output directory." + ) + parser.add_argument( + "--schedule", + type=str, + default="local_run", + help="Schedule to run for this experiment." + ) + parser.add_argument( + "--master_grpc_url", + type=str, + default="", + help="URL to master GRPC tensorflow server, e.g.,grpc://127.0.0.1:2222" + ) + parser.add_argument( + "--num_parameter_servers", + type=int, + default=0, + help="Number of parameter servers" + ) + parser.add_argument( + "--worker_index", + type=int, + default=0, + help="Worker index (>=0)" + ) + parser.add_argument( + "--train_steps", + type=int, + default=1000, + help="Number of training steps" + ) + parser.add_argument( + "--eval_steps", + type=int, + default=1, + help="Number of evaluation steps" + ) + global FLAGS # pylint:disable=global-at-module-level + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py index 5e36eaf748..2d774577b6 100755 --- a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py +++ b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py @@ -33,32 +33,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys + from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.platform import app -from tensorflow.python.platform import flags from tensorflow.python.training import server_lib -FLAGS = flags.FLAGS - -flags.DEFINE_string("cluster_spec", "", """Cluster spec: SPEC. - SPEC is <JOB>(,<JOB>)*," - JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*," - NAME is a valid job name ([a-z][0-9a-z]*)," - HOST is a hostname or IP address," - PORT is a port number." -E.g., local|localhost:2222;localhost:2223, ps|ps0:2222;ps1:2222""") -flags.DEFINE_string("job_name", "", "Job name: e.g., local") -flags.DEFINE_integer("task_id", 0, "Task index, e.g., 0") -flags.DEFINE_boolean("verbose", False, "Verbose mode") - -def parse_cluster_spec(cluster_spec, cluster): +def parse_cluster_spec(cluster_spec, cluster, verbose=False): """Parse content of cluster_spec string and inject info into cluster protobuf. Args: cluster_spec: cluster specification string, e.g., "local|localhost:2222;localhost:2223" cluster: cluster protobuf. + verbose: If verbose logging is requested. Raises: ValueError: if the cluster_spec string is invalid. @@ -82,7 +72,7 @@ def parse_cluster_spec(cluster_spec, cluster): job_def.name = job_name - if FLAGS.verbose: + if verbose: print("Added job named \"%s\"" % job_name) job_tasks = job_string.split("|")[1].split(";") @@ -92,7 +82,7 @@ def parse_cluster_spec(cluster_spec, cluster): job_def.tasks[i] = job_tasks[i] - if FLAGS.verbose: + if verbose: print(" Added task \"%s\" to job \"%s\"" % (job_tasks[i], job_name)) @@ -101,7 +91,7 @@ def main(unused_args): server_def = tensorflow_server_pb2.ServerDef(protocol="grpc") # Cluster info - parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster) + parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster, FLAGS.verbose) # Job name if not FLAGS.job_name: @@ -121,4 +111,39 @@ def main(unused_args): if __name__ == "__main__": - app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--cluster_spec", + type=str, + default="", + help="""\ + Cluster spec: SPEC. SPEC is <JOB>(,<JOB>)*," JOB is + <NAME>|<HOST:PORT>(;<HOST:PORT>)*," NAME is a valid job name + ([a-z][0-9a-z]*)," HOST is a hostname or IP address," PORT is a + port number." E.g., local|localhost:2222;localhost:2223, + ps|ps0:2222;ps1:2222\ + """ + ) + parser.add_argument( + "--job_name", + type=str, + default="", + help="Job name: e.g., local" + ) + parser.add_argument( + "--task_id", + type=int, + default=0, + help="Task index, e.g., 0" + ) + parser.add_argument( + "--verbose", + type="bool", + nargs="?", + const=True, + default=False, + help="Verbose mode" + ) + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) |