path: root/tensorflow/tools/dist_test
author	Yifei Feng <yifeif@google.com>	2018-01-26 16:53:59 -0800
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-01-26 16:59:01 -0800
commit	aee7f95a027accc94f1f9130f0cfaecd9399bc1d (patch)
tree	6b8484915bf631f18b2fa0561a73549d9bf19fad /tensorflow/tools/dist_test
parent	e95537708f070a98607393a8f60bc61f1611a77b (diff)
Add C0301 line-too-long error to pylint sanity check.
PiperOrigin-RevId: 183467186
Diffstat (limited to 'tensorflow/tools/dist_test')
-rw-r--r--	tensorflow/tools/dist_test/python/mnist_replica.py	| 32
1 file changed, 13 insertions(+), 19 deletions(-)
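For reference, C0301 is pylint's line-too-long message, which this commit enables in the sanity check and which motivates the reformatting below. The exact configuration used by TensorFlow's sanity script is not shown on this page; the following is a minimal sketch of running only that check against the touched file, assuming pylint is installed and using an illustrative 80-column limit.

import subprocess
import sys

# Sketch: run only pylint's C0301 (line-too-long) check on the file touched
# by this commit. The 80-column limit is assumed for illustration; the real
# sanity-check configuration may differ.
result = subprocess.run(
    [sys.executable, "-m", "pylint",
     "--disable=all", "--enable=C0301", "--max-line-length=80",
     "tensorflow/tools/dist_test/python/mnist_replica.py"],
    capture_output=True, text=True)
print(result.stdout)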
diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py
index e40ecb43f9..a2d12442c4 100644
--- a/tensorflow/tools/dist_test/python/mnist_replica.py
+++ b/tensorflow/tools/dist_test/python/mnist_replica.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Distributed MNIST training and validation, with model replicas.
A simple softmax model with one hidden layer is defined. The parameters
@@ -32,7 +31,6 @@ perform forward computation and gradient calculation in parallel, which
should lead to increased training speed for the simple model.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -45,7 +43,6 @@ import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
-
flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
"Directory for storing mnist data")
@@ -56,8 +53,7 @@ flags.DEFINE_integer("task_index", None,
"Worker task index, should be >= 0. task_index=0 is "
"the master worker task the performs the variable "
"initialization ")
-flags.DEFINE_integer("num_gpus", 1,
- "Total number of gpus for each machine."
+flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine."
"If you don't use GPU, please set it to '0'")
flags.DEFINE_integer("replicas_to_aggregate", None,
"Number of replicas to aggregate before parameter update"
@@ -69,24 +65,24 @@ flags.DEFINE_integer("train_steps", 200,
"Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
-flags.DEFINE_boolean("sync_replicas", False,
- "Use the sync_replicas (synchronized replicas) mode, "
- "wherein the parameter updates from workers are aggregated "
- "before applied to avoid stale gradients")
+flags.DEFINE_boolean(
+ "sync_replicas", False,
+ "Use the sync_replicas (synchronized replicas) mode, "
+ "wherein the parameter updates from workers are aggregated "
+ "before applied to avoid stale gradients")
flags.DEFINE_boolean(
"existing_servers", False, "Whether servers already exists. If True, "
"will use the worker hosts via their GRPC URLs (one client process "
"per worker host). Otherwise, will create an in-process TensorFlow "
"server.")
-flags.DEFINE_string("ps_hosts","localhost:2222",
+flags.DEFINE_string("ps_hosts", "localhost:2222",
"Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
"Comma-separated list of hostname:port pairs")
-flags.DEFINE_string("job_name", None,"job name: worker or ps")
+flags.DEFINE_string("job_name", None, "job name: worker or ps")
FLAGS = flags.FLAGS
-
IMAGE_PIXELS = 28
@@ -97,7 +93,7 @@ def main(unused_argv):
if FLAGS.job_name is None or FLAGS.job_name == "":
raise ValueError("Must specify an explicit `job_name`")
- if FLAGS.task_index is None or FLAGS.task_index =="":
+ if FLAGS.task_index is None or FLAGS.task_index == "":
raise ValueError("Must specify an explicit `task_index`")
print("job name = %s" % FLAGS.job_name)
@@ -110,9 +106,7 @@ def main(unused_argv):
# Get the number of workers.
num_workers = len(worker_spec)
- cluster = tf.train.ClusterSpec({
- "ps": ps_spec,
- "worker": worker_spec})
+ cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})
if not FLAGS.existing_servers:
# Not using existing servers. Create an in-process server.
@@ -217,7 +211,8 @@ def main(unused_argv):
sess_config = tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=False,
- device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
+ device_filters=["/job:ps",
+ "/job:worker/task:%d" % FLAGS.task_index])
# The chief worker (task_index==0) session will prepare the session,
# while the remaining workers will wait for the preparation to complete.
@@ -231,8 +226,7 @@ def main(unused_argv):
server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
print("Using existing server at: %s" % server_grpc_url)
- sess = sv.prepare_or_wait_for_session(server_grpc_url,
- config=sess_config)
+ sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config)
else:
sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
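The reformatted lines above sit inside the TF 1.x distributed-runtime setup used by mnist_replica.py: a tf.train.ClusterSpec built from the ps/worker host lists, an in-process tf.train.Server, and a tf.ConfigProto whose device_filters restrict each worker to the ps job and its own task. Below is a minimal sketch of that wiring with placeholder host lists and task index; in the actual script these values come from the ps_hosts, worker_hosts, job_name, and task_index flags.

import tensorflow as tf

# Placeholder values for illustration; the script derives these from flags.
ps_spec = ["localhost:2222"]
worker_spec = ["localhost:2223", "localhost:2224"]
task_index = 0

# Describe the cluster: one "ps" job and one "worker" job.
cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})

# Create an in-process server for this worker task
# (the existing_servers=False path in the script).
server = tf.train.Server(cluster, job_name="worker", task_index=task_index)

# Restrict this session to the ps devices and this worker's own devices,
# mirroring the device_filters line in the diff.
sess_config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    device_filters=["/job:ps", "/job:worker/task:%d" % task_index])

A tf.train.Supervisor, as in the script, would then obtain a session via prepare_or_wait_for_session(server.target, config=sess_config), or via the worker's gRPC URL when existing_servers is True.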