aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Jianmin Chen <jmchen@google.com>2016-10-07 12:53:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-07 14:03:39 -0700
commitecdee38a534133ecd7ba18e58527cc4120277190 (patch)
tree5b76e2e8a3038cb3b11539121360c062c2719154 /tensorflow
parent2c8d270735176df1a59b5a80885b2e14b4f06953 (diff)
Switch to the new accumulators in the sync_rep optimizer (currently called V2). Please note that the gradients from replicas are now averaged instead of summed (as in the old sync_replicas_optimizer) so you need to increase the learning rate according to the number of replicas. This change is introduced to be consistent with how gradients are aggregated (averaged) within a batch in a replica.
As shown in the code change, the switch results in: 1. much cleaner and simpler code. 2. much more efficient and reliable staleness check. It is now 100% strict with no extra contention to PS servers. 3. no need for clean_up op so we can get rid of the abort_op which can confuse users. 4. number of replicas can be changed without complaints from checkpoint as the local_step is now just a local variable instead of a global vector variable. This has been tried with manual restarts of workers (chief or non chief) and ps and seems to be quite robust. Change: 135513399
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/BUILD9
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer.py414
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer_test.py279
-rw-r--r--tensorflow/python/training/training.py1
4 files changed, 703 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 18105bd1ec..06e36dc3ae 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1876,6 +1876,15 @@ cuda_py_test(
additional_deps = ["//tensorflow:tensorflow_py"],
)
+cuda_py_test(
+ name = "sync_replicas_optimizer_test",
+ size = "medium",
+ srcs = [
+ "training/sync_replicas_optimizer_test.py",
+ ],
+ additional_deps = ["//tensorflow:tensorflow_py"],
+)
+
py_library(
name = "timeline",
srcs = ["client/timeline.py"],
diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py
index ba07cd5908..24d8177f7d 100644
--- a/tensorflow/python/training/sync_replicas_optimizer.py
+++ b/tensorflow/python/training/sync_replicas_optimizer.py
@@ -31,6 +31,416 @@ from tensorflow.python.training import optimizer
from tensorflow.python.training import queue_runner
+# Please note that the gradients from replicas are averaged instead of summed
+# (as in the old sync_replicas_optimizer) so you need to increase the learning
+# rate according to the number of replicas. This change is introduced to be
+# consistent with how gradients are aggregated (averaged) within a batch in a
+# replica.
+class SyncReplicasOptimizerV2(optimizer.Optimizer):
+ """Class to synchronize, aggregate gradients and pass them to the optimizer.
+
+ In a typical asynchronous training environment, it's common to have some
+ stale gradients. For example, with a N-replica asynchronous training,
+ gradients will be applied to the variables N times independently. Depending
+ on each replica's training speed, some gradients might be calculated from
+ copies of the variable from several steps back (N-1 steps on average). This
+ optimizer avoids stale gradients by collecting gradients from all replicas,
+ averaging them, then applying them to the variables in one shot, after
+ which replicas can fetch the new variables and continue.
+
+ The following accumulators/queue are created:
+ <empty line>
+ * N `gradient accumulators`, one per variable to train. Gradients are pushed
+ to them and the chief worker will wait until enough gradients are collected
+ and then average them before applying to variables. The accumulator will
+ drop all stale gradients (more details in the accumulator op).
+ * 1 `token` queue where the optimizer pushes the new global_step value after
+ all variables are updated.
+
+ The following local variable is created:
+ * `sync_rep_local_step`, one per replica. Compared against the global_step in
+ each accumulator to check for staleness of the gradients.
+
+ The optimizer adds nodes to the graph to collect gradients and pause the
+ trainers until variables are updated.
+ For the Parameter Server job:
+ <empty line>
+ 1. An accumulator is created for each variable, and each replica pushes the
+ gradients into the accumulators instead of directly applying them to the
+ variables.
+ 2. Each accumulator averages once enough gradients (replicas_to_aggregate)
+ have been accumulated.
+ 3. Apply the averaged gradients to the variables.
+ 4. Only after all variables have been updated, increment the global step.
+ 5. Only after step 4, pushes `global_step` in the `token_queue`, once for
+ each worker replica. The workers can now fetch the global step, use it to
+ update its local_step variable and start the next batch.
+
+ For the replicas:
+ <empty line>
+ 1. Start a step: fetch variables and compute gradients.
+ 2. Once the gradients have been computed, push them into gradient
+ accumulators. Each accumulator will check the staleness and drop the stale.
+ 3. After pushing all the gradients, dequeue an updated value of global_step
+ from the token queue and record that step to its local_step variable. Note
+ that this is effectively a barrier.
+ 4. Start the next batch.
+
+ ### Usage
+
+ ```python
+ # Create any optimizer to update the variables, say a simple SGD:
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+
+ # Wrap the optimizer with sync_replicas_optimizer with 50 replicas: at each
+ # step the optimizer collects 50 gradients before applying to variables.
+ # Note that if you want to have 2 backup replicas, you can change
+ # total_num_replicas=52 and make sure this number matches how many physical
+ # replicas you started in your job.
+ opt = tf.SyncReplicasOptimizerV2(opt, replicas_to_aggregate=50,
+ total_num_replicas=50)
+
+ # Some models have startup_delays to help stabilize the model but when using
+ # sync_replicas training, set it to 0.
+
+ # Now you can call `minimize()` or `compute_gradients()` and
+ # `apply_gradients()` normally
+ grads = opt.minimize(total_loss, global_step=self.global_step)
+
+
+ # You can now call get_init_tokens_op() and get_chief_queue_runner().
+ # Note that get_init_tokens_op() must be called before creating session
+ # because it modifies the graph by adding new nodes.
+ init_token_op = opt.get_init_tokens_op()
+ chief_queue_runner = opt.get_chief_queue_runner()
+ ```
+
+ In the training program, every worker will run the train_op as if not
+ synchronized. But one worker (usually the chief) will need to execute the
+ chief_queue_runner and get_init_tokens_op from this optimizer.
+
+ ```python
+ # When you create the supervisor, you need to add the local_init_op and
+ # ready_for_local_init_op to make sure the local_step is initialized to the
+ # global_step. Here is an example:
+ sv = tf.Supervisor(graph=g,
+ is_chief=is_chief,
+ # This initialize local step.
+ local_init_op=local_init_op,
+ # This makes sure global step is initialized before using.
+ ready_for_local_init_op=ready_for_local_init_op,
+ saver=model.saver)
+
+ # After the session is created by the Supervisor and before the main while
+ # loop:
+ if is_chief and FLAGS.sync_replicas:
+ sv.start_queue_runners(sess, [chief_queue_runner])
+ # Insert initial tokens to the queue.
+ sess.run(init_token_op)
+ ```
+
+ @@__init__
+ @@compute_gradients
+ @@apply_gradients
+ @@get_chief_queue_runner
+ @@get_init_tokens_op
+ """
+
+ def __init__(self,
+ opt,
+ replicas_to_aggregate,
+ total_num_replicas=None,
+ variable_averages=None,
+ variables_to_average=None,
+ use_locking=False,
+ name="sync_replicas"):
+ """Construct a sync_replicas optimizer.
+
+ Args:
+ opt: The actual optimizer that will be used to compute and apply the
+ gradients. Must be one of the Optimizer classes.
+ replicas_to_aggregate: number of replicas to aggregate for each variable
+ update.
+ total_num_replicas: Total number of tasks/workers/replicas, could be
+ different from replicas_to_aggregate.
+ If total_num_replicas > replicas_to_aggregate: it is backup_replicas +
+ replicas_to_aggregate.
+ If total_num_replicas < replicas_to_aggregate: Replicas compute
+ multiple batches per update to variables.
+ variable_averages: Optional `ExponentialMovingAverage` object, used to
+ maintain moving averages for the variables passed in
+ `variables_to_average`.
+ variables_to_average: a list of variables that need to be averaged. Only
+ needed if variable_averages is passed in.
+ use_locking: If True use locks for update operation.
+ name: string. Optional name of the returned operation.
+ """
+ if total_num_replicas is None:
+ total_num_replicas = replicas_to_aggregate
+
+ super(SyncReplicasOptimizerV2, self).__init__(use_locking, name)
+ logging.info(
+ "SyncReplicasV2: replicas_to_aggregate=%s; total_num_replicas=%s",
+ replicas_to_aggregate, total_num_replicas)
+ self._opt = opt
+ self._replicas_to_aggregate = replicas_to_aggregate
+ self._gradients_applied = False
+ self._variable_averages = variable_averages
+ self._variables_to_average = variables_to_average
+ self._total_num_replicas = total_num_replicas
+ self._tokens_per_step = max(total_num_replicas, replicas_to_aggregate)
+ self._global_step = None
+ self._sync_token_queue = None
+
+ # The synchronization op will be executed in a queue runner which should
+ # only be executed by one of the replicas (usually the chief).
+ self._chief_queue_runner = None
+
+ # Remember which accumulator is on which device to set the initial step in
+ # the accumulator to be global step. This list contains list of the
+ # following format: (accumulator, device).
+ self._accumulator_list = []
+
+ def compute_gradients(self, *args, **kwargs):
+ """Compute gradients of "loss" for the variables in "var_list".
+
+ This simply wraps the compute_gradients() from the real optimizer. The
+ gradients will be aggregated in the apply_gradients() so that user can
+ modify the gradients like clipping with per replica global norm if needed.
+ The global norm with aggregated gradients can be bad as one replica's huge
+ gradients can hurt the gradients from other replicas.
+
+ Args:
+ *args: Arguments for compute_gradients().
+ **kwargs: Keyword arguments for compute_gradients().
+
+ Returns:
+ A list of (gradient, variable) pairs.
+ """
+ return self._opt.compute_gradients(*args, **kwargs)
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients to variables.
+
+ This contains most of the synchronization implementation and also wraps the
+ apply_gradients() from the real optimizer.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ compute_gradients().
+ global_step: Optional Variable to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Default to the
+ name passed to the Optimizer constructor.
+
+ Returns:
+ train_op: The op to dequeue a token so the replicas can exit this batch
+ and start the next one. This is executed by each replica.
+
+ Raises:
+ ValueError: If the grads_and_vars is empty.
+ ValueError: If global step is not provided, the staleness cannot be
+ checked.
+ """
+ if not grads_and_vars:
+ raise ValueError("Must supply at least one variable")
+
+ if global_step is None:
+ raise ValueError("Global step is required to check staleness")
+
+ self._global_step = global_step
+ train_ops = []
+ aggregated_grad = []
+ var_list = []
+
+ self._local_step = variables.Variable(
+ initial_value=0,
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ name="sync_rep_local_step")
+ self.local_step_init_op = state_ops.assign(self._local_step, global_step)
+ chief_init_ops = [self.local_step_init_op]
+ self.ready_for_local_init_op = variables.report_uninitialized_variables(
+ variables.all_variables())
+
+ with ops.name_scope(None, self._name):
+ for grad, var in grads_and_vars:
+ var_list.append(var)
+ with ops.device(var.device):
+ # Dense gradients.
+ if grad is None:
+ aggregated_grad.append(None) # pass-through.
+ continue
+ elif isinstance(grad, ops.Tensor):
+ grad_accum = data_flow_ops.ConditionalAccumulator(
+ grad.dtype,
+ shape=var.get_shape(),
+ shared_name=var.name + "/grad_accum")
+ train_ops.append(grad_accum.apply_grad(
+ grad, local_step=self._local_step))
+ aggregated_grad.append(grad_accum.take_grad(
+ self._replicas_to_aggregate))
+ else:
+ if not isinstance(grad, ops.IndexedSlices):
+ raise ValueError("Unknown grad type!")
+ grad_accum = data_flow_ops.SparseConditionalAccumulator(
+ grad.dtype, shape=(), shared_name=var.name + "/grad_accum")
+ train_ops.append(grad_accum.apply_indexed_slices_grad(
+ grad, local_step=self._local_step))
+ aggregated_grad.append(grad_accum.take_indexed_slices_grad(
+ self._replicas_to_aggregate))
+
+ self._accumulator_list.append((grad_accum, var.device))
+
+ aggregated_grads_and_vars = zip(aggregated_grad, var_list)
+
+ # sync_op will be assigned to the same device as the global step.
+ with ops.device(global_step.device), ops.name_scope(""):
+ update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
+ global_step)
+
+ # Create token queue.
+ with ops.device(global_step.device), ops.name_scope(""):
+ sync_token_queue = (
+ data_flow_ops.FIFOQueue(-1,
+ global_step.dtype.base_dtype,
+ shapes=(),
+ shared_name="sync_token_q"))
+ self._sync_token_queue = sync_token_queue
+
+ # dummy_queue is passed to the queue runner. Don't use the real queues
+ # because the queue runner doesn't automatically reopen it once it
+ # closed queues in PS devices.
+ dummy_queue = (
+ data_flow_ops.FIFOQueue(1,
+ types_pb2.DT_INT32,
+ shapes=(),
+ shared_name="dummy_queue"))
+
+ with ops.device(global_step.device), ops.name_scope(""):
+ # Replicas have to wait until they can get a token from the token queue.
+ with ops.control_dependencies(train_ops):
+ token = sync_token_queue.dequeue()
+ train_op = state_ops.assign(self._local_step, token)
+
+ with ops.control_dependencies([update_op]):
+ # Sync_op needs to insert tokens to the token queue at the end of the
+ # step so the replicas can fetch them to start the next step.
+ tokens = array_ops.fill([self._tokens_per_step], global_step.ref())
+ sync_op = sync_token_queue.enqueue_many((tokens,))
+
+ if self._variable_averages is not None:
+ with ops.control_dependencies([sync_op]), ops.name_scope(""):
+ sync_op = self._variable_averages.apply(
+ self._variables_to_average)
+
+ self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
+ [sync_op])
+ for accum, dev in self._accumulator_list:
+ with ops.device(dev):
+ chief_init_ops.append(
+ accum.set_global_step(
+ global_step, name="SetGlobalStep"))
+ self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
+ self._gradients_applied = True
+ return train_op
+
+ def get_chief_queue_runner(self):
+ """Returns the QueueRunner for the chief to execute.
+
+ This includes the operations to synchronize replicas: aggregate gradients,
+ apply to variables, increment global step, insert tokens to token queue.
+
+ Note that this can only be called after calling apply_gradients() which
+ actually generates this queuerunner.
+
+ Returns:
+ A `QueueRunner` for chief to execute.
+
+ Raises:
+ ValueError: If this is called before apply_gradients().
+ """
+ if self._gradients_applied is False:
+ raise ValueError("Should be called after apply_gradients().")
+
+ return self._chief_queue_runner
+
+ def get_slot(self, *args, **kwargs):
+ """Return a slot named "name" created for "var" by the Optimizer.
+
+ This simply wraps the get_slot() from the actual optimizer.
+
+ Args:
+ *args: Arguments for get_slot().
+ **kwargs: Keyword arguments for get_slot().
+
+ Returns:
+ The `Variable` for the slot if it was created, `None` otherwise.
+ """
+ return self._opt.get_slot(*args, **kwargs)
+
+ def get_slot_names(self, *args, **kwargs):
+ """Return a list of the names of slots created by the `Optimizer`.
+
+ This simply wraps the get_slot_names() from the actual optimizer.
+
+ Args:
+ *args: Arguments for get_slot().
+ **kwargs: Keyword arguments for get_slot().
+
+ Returns:
+ A list of strings.
+ """
+ return self._opt.get_slot_names(*args, **kwargs)
+
+ def get_init_tokens_op(self, num_tokens=-1):
+ """Returns the op to fill the sync_token_queue with the tokens.
+
+ This is supposed to be executed in the beginning of the chief/sync thread
+ so that even if the total_num_replicas is less than replicas_to_aggregate,
+ the model can still proceed as the replicas can compute multiple steps per
+ variable update. Make sure:
+ `num_tokens >= replicas_to_aggregate - total_num_replicas`.
+
+ Args:
+ num_tokens: Number of tokens to add to the queue.
+
+ Returns:
+ An op for the chief/sync replica to fill the token queue.
+
+ Raises:
+ ValueError: If this is called before apply_gradients().
+ ValueError: If num_tokens are smaller than replicas_to_aggregate -
+ total_num_replicas.
+ """
+ if self._gradients_applied is False:
+ raise ValueError(
+ "get_init_tokens_op() should be called after apply_gradients().")
+
+ tokens_needed = self._replicas_to_aggregate - self._total_num_replicas
+ if num_tokens == -1:
+ num_tokens = self._replicas_to_aggregate
+ elif num_tokens < tokens_needed:
+ raise ValueError(
+ "Too few tokens to finish the first step: %d (given) vs %d (needed)" %
+ (num_tokens, tokens_needed))
+
+ if num_tokens > 0:
+ with ops.device(self._global_step.device), ops.name_scope(""):
+ tokens = array_ops.fill([num_tokens],
+ self._global_step.ref())
+ init_tokens = self._sync_token_queue.enqueue_many((tokens,))
+ else:
+ init_tokens = control_flow_ops.no_op(name="no_init_tokens")
+
+ return init_tokens
+
+
+# Please switch to v2 if you are still using the old sync optimizer. V2
+# is much more efficient and stable. It also removed 100% of the stale
+# gradients which is not possible in this implementation without significant
+# overhead. This is kept here just for backward compatibility and will be
+# DEPRECATED later.
class SyncReplicasOptimizer(optimizer.Optimizer):
"""Class to synchronize, aggregate gradients and pass them to the optimizer.
@@ -170,6 +580,10 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
total_num_replicas = replicas_to_aggregate
super(SyncReplicasOptimizer, self).__init__(use_locking, name)
+ logging.info("""TO BE DEPRECATED!!!
+ This version will be deprecated. Please switch to V2 at your
+ earliest convenience.""")
+
logging.info(
"SyncReplicas enabled: replicas_to_aggregate=%s; total_num_replicas=%s",
replicas_to_aggregate, total_num_replicas)
diff --git a/tensorflow/python/training/sync_replicas_optimizer_test.py b/tensorflow/python/training/sync_replicas_optimizer_test.py
new file mode 100644
index 0000000000..e340a22374
--- /dev/null
+++ b/tensorflow/python/training/sync_replicas_optimizer_test.py
@@ -0,0 +1,279 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for sync_replicas_optimizer.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.util import net_lib
+
+
+def create_local_cluster(num_workers, num_ps, protocol="grpc"):
+ """Create local GRPC servers and return them."""
+ worker_ports = [net_lib.pick_unused_port_or_die() for _ in range(num_workers)]
+ ps_ports = [net_lib.pick_unused_port_or_die() for _ in range(num_ps)]
+ cluster_dict = {
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]}
+ cs = tf.train.ClusterSpec(cluster_dict)
+
+ workers = [
+ tf.train.Server(
+ cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_workers)]
+ ps_servers = [
+ tf.train.Server(
+ cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_ps)]
+
+ return workers, ps_servers
+
+
+# Creates the workers and return their sessions, graphs, train_ops.
+def get_workers(num_workers, replicas_to_aggregate, workers):
+ sessions = []
+ graphs = []
+ train_ops = []
+ for worker_id in range(num_workers):
+ graph = tf.Graph()
+ is_chief = (worker_id == 0)
+ with graph.as_default():
+ with tf.device("/job:ps/task:0"):
+ global_step = tf.Variable(0, name="global_step", trainable=False)
+ var_0 = tf.Variable(0.0, name="v0")
+ with tf.device("/job:ps/task:1"):
+ var_1 = tf.Variable(1.0, name="v1")
+ var_sparse = tf.Variable([[3.0], [4.0]], name="v_sparse")
+
+ with tf.device("/job:worker/task:"+str(worker_id)):
+ grads_0 = tf.constant(0.1+worker_id*0.2)
+ grads_1 = tf.constant(0.9+worker_id*0.2)
+ # This is to test against sparse gradients.
+ grads_sparse = tf.IndexedSlices(
+ tf.constant([0.1+worker_id*0.2], shape=[1, 1]),
+ tf.constant([1], dtype=tf.int64),
+ tf.constant([2, 1], dtype=tf.int64))
+ sgd_opt = tf.train.GradientDescentOptimizer(2.0)
+ sync_rep_opt = tf.train.SyncReplicasOptimizerV2(
+ sgd_opt, replicas_to_aggregate=replicas_to_aggregate,
+ total_num_replicas=num_workers)
+ train_op = [sync_rep_opt.apply_gradients(
+ zip([grads_0, grads_1, grads_sparse], [var_0, var_1, var_sparse]),
+ global_step=global_step)]
+
+ init_op = tf.initialize_all_variables()
+ # Needed ops from the sync_rep optimizer. This is mainly for the
+ # local_step initialization.
+ local_init_op = sync_rep_opt.local_step_init_op
+ if is_chief:
+ local_init_op = sync_rep_opt.chief_init_op
+ ready_for_local_init_op = sync_rep_opt.ready_for_local_init_op
+
+ # Chief_queue_runner
+ chief_queue_runner = sync_rep_opt.get_chief_queue_runner()
+ sync_init_op = sync_rep_opt.get_init_tokens_op(num_workers)
+
+ # Creates session for chief.
+ supervisor = tf.train.Supervisor(
+ graph=graph,
+ is_chief=is_chief,
+ recovery_wait_secs=1,
+ init_op=init_op,
+ local_init_op=local_init_op,
+ ready_for_local_init_op=ready_for_local_init_op)
+ session = supervisor.prepare_or_wait_for_session(workers[worker_id].target)
+
+ # Chief should execute the sync_init_op and start the chief queue runner.
+ if is_chief:
+ session.run(sync_init_op)
+ supervisor.StartQueueRunners(session, [chief_queue_runner])
+
+ sessions.append(session)
+ graphs.append(graph)
+ train_ops.append(train_op)
+
+ return sessions, graphs, train_ops
+
+
+class SyncReplicasOptimizerV2Test(tf.test.TestCase):
+
+ def _run(self, train_op, sess):
+ sess.run(train_op)
+
+ def test2Workers(self):
+ num_workers = 2
+ replicas_to_aggregate = 2
+ num_ps = 2
+ workers, _ = create_local_cluster(num_workers=num_workers, num_ps=num_ps)
+
+ # Creates and returns all the workers.
+ sessions, graphs, train_ops = get_workers(num_workers,
+ replicas_to_aggregate,
+ workers)
+
+ # Chief should have already initialized all the variables.
+ var_0_g_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1_g_0 = graphs[0].get_tensor_by_name("v1:0")
+ local_step_0 = graphs[0].get_tensor_by_name("sync_rep_local_step:0")
+ self.assertAllEqual(0.0, var_0_g_0.eval(session=sessions[0]))
+ self.assertAllEqual(1.0, var_1_g_0.eval(session=sessions[0]))
+ self.assertAllEqual(0, local_step_0.eval(session=sessions[0]))
+
+ # Will just use session 1 to verify all the variables later.
+ var_0_g_1 = graphs[1].get_tensor_by_name("v0:0")
+ var_1_g_1 = graphs[1].get_tensor_by_name("v1:0")
+ var_sparse_g_1 = graphs[1].get_tensor_by_name("v_sparse:0")
+ local_step_1 = graphs[1].get_tensor_by_name("sync_rep_local_step:0")
+ global_step = graphs[1].get_tensor_by_name("global_step:0")
+
+ # The steps should also be initialized.
+ self.assertAllEqual(0, global_step.eval(session=sessions[1]))
+ self.assertAllEqual(0, local_step_1.eval(session=sessions[1]))
+ self.assertAllClose([[3.0], [4.0]],
+ var_sparse_g_1.eval(session=sessions[1]))
+
+ # We have initial tokens in the queue so we can call this one by one. After
+ # the first step, this will no longer work as there will be no more extra
+ # tokens in the queue.
+ sessions[0].run(train_ops[0])
+ sessions[1].run(train_ops[1])
+
+ # The global step should have been updated and the variables should now have
+ # the new values after the average of the gradients are applied.
+ self.assertAllEqual(1, global_step.eval(session=sessions[1]))
+ self.assertAllClose(0-(0.1+0.3)/2*2.0, var_0_g_1.eval(session=sessions[1]))
+ self.assertAllClose(1-(0.9+1.1)/2*2.0, var_1_g_1.eval(session=sessions[1]))
+ self.assertAllClose([[3.0], [4.0-(0.1+0.3)/2*2.0]],
+ var_sparse_g_1.eval(session=sessions[1]))
+
+ # The local step for both workers should still be 0 because the initial
+ # tokens in the token queue are 0s. This means that the following
+ # computation of the gradients will be wasted as local_step is smaller than
+ # the current global step. However, this only happens once when the system
+ # just starts and this is necessary to make the system robust for the case
+ # when chief gets restarted by errors/preemption/...
+ self.assertAllEqual(0, local_step_0.eval(session=sessions[0]))
+ self.assertAllEqual(0, local_step_1.eval(session=sessions[1]))
+
+ sessions[0].run(train_ops[0])
+ sessions[1].run(train_ops[1])
+ # Although the global step should still be 1 as explained above, the local
+ # step should now be updated to 1. The variables are still the same.
+ self.assertAllEqual(1, global_step.eval(session=sessions[1]))
+ self.assertAllEqual(1, local_step_0.eval(session=sessions[0]))
+ self.assertAllEqual(1, local_step_1.eval(session=sessions[1]))
+ self.assertAllClose(0-(0.1+0.3)/2*2.0, var_0_g_1.eval(session=sessions[1]))
+ self.assertAllClose(1-(0.9+1.1)/2*2.0, var_1_g_1.eval(session=sessions[1]))
+
+ # At this step, the token queue is empty. So the 2 workers need to work
+ # together to proceed.
+ threads = []
+ threads.append(self.checkedThread(target=self._run,
+ args=(train_ops[0], sessions[0])))
+ threads.append(self.checkedThread(target=self._run,
+ args=(train_ops[1], sessions[1])))
+
+ # The two workers starts to execute the train op.
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ # The global step should now be 2 and the gradients should have been
+ # applied twice.
+ self.assertAllEqual(2, global_step.eval(session=sessions[1]))
+ self.assertAllClose(0 - 2 * (0.1 + 0.3) / 2 * 2.0,
+ var_0_g_1.eval(session=sessions[1]))
+ self.assertAllClose(1 - 2 * (0.9 + 1.1) / 2 * 2.0,
+ var_1_g_1.eval(session=sessions[1]))
+
+ # 3 workers and one of them is backup.
+ def test3Workers1Backup(self):
+ num_workers = 3
+ replicas_to_aggregate = 2
+ num_ps = 2
+ workers, _ = create_local_cluster(num_workers=num_workers, num_ps=num_ps)
+
+ # Creates and returns all the workers.
+ sessions, graphs, train_ops = get_workers(num_workers,
+ replicas_to_aggregate,
+ workers)
+
+ # Chief should have already initialized all the variables.
+ var_0_g_1 = graphs[1].get_tensor_by_name("v0:0")
+ var_1_g_1 = graphs[1].get_tensor_by_name("v1:0")
+ local_step_1 = graphs[1].get_tensor_by_name("sync_rep_local_step:0")
+ global_step = graphs[1].get_tensor_by_name("global_step:0")
+
+ # The steps should also be initilized.
+ self.assertAllEqual(0, global_step.eval(session=sessions[1]))
+ self.assertAllEqual(0, local_step_1.eval(session=sessions[1]))
+
+ # We have initial tokens in the queue so we can call this one by one. After
+ # the token queue becomes empty, they should be called concurrently.
+ # Here worker 0 and worker 2 finished first.
+ sessions[0].run(train_ops[0])
+ sessions[2].run(train_ops[2])
+
+ # The global step should have been updated since we only need to collect 2
+ # gradients. The variables should now have the new values after the average
+ # of the gradients from worker 0/2 are applied.
+ self.assertAllEqual(1, global_step.eval(session=sessions[1]))
+ self.assertAllClose(0-(0.1+0.5)/2*2.0, var_0_g_1.eval(session=sessions[1]))
+ self.assertAllClose(1-(0.9+1.3)/2*2.0, var_1_g_1.eval(session=sessions[1]))
+
+ # Worker 1 finished later and its gradients will now be dropped as it is
+ # stale.
+ sessions[1].run(train_ops[1])
+
+ # As shown in the previous test, the local_step for all workers should be
+ # still 0 so their next computation will also be dropped.
+ sessions[0].run(train_ops[0])
+ sessions[1].run(train_ops[1])
+ sessions[2].run(train_ops[2])
+
+ # Although the global step should still be 1 as explained above, the local
+ # step should now be updated to 1. Just check worker 1 as an example.
+ self.assertAllEqual(1, global_step.eval(session=sessions[1]))
+ self.assertAllEqual(1, local_step_1.eval(session=sessions[1]))
+
+ thread_0 = self.checkedThread(target=self._run,
+ args=(train_ops[0], sessions[0]))
+ thread_1 = self.checkedThread(target=self._run,
+ args=(train_ops[1], sessions[1]))
+
+ # Lets worker 0 execute first.
+ # It will wait as we need 2 workers to finish this step and the global step
+ # should be still 1.
+ thread_0.start()
+ self.assertAllEqual(1, global_step.eval(session=sessions[1]))
+
+ # Starts worker 1.
+ thread_1.start()
+ thread_1.join()
+
+ # The global step should now be 2 and the gradients should have been
+ # applied again.
+ self.assertAllEqual(2, global_step.eval(session=sessions[1]))
+ self.assertAllClose(-0.6 -(0.1 + 0.3) / 2 * 2.0,
+ var_0_g_1.eval(session=sessions[1]))
+ self.assertAllClose(-1.2 - (0.9 + 1.1) / 2 * 2.0,
+ var_1_g_1.eval(session=sessions[1]))
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index a814eb99ce..284cc43bc4 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -182,6 +182,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer
from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
from tensorflow.python.training.proximal_gradient_descent import ProximalGradientDescentOptimizer
from tensorflow.python.training.sync_replicas_optimizer import SyncReplicasOptimizer
+from tensorflow.python.training.sync_replicas_optimizer import SyncReplicasOptimizerV2
# Utility classes for training.
from tensorflow.python.training.coordinator import Coordinator