From 4903865bb5f2a42cfd5b2ecfb11c37bed28edcc8 Mon Sep 17 00:00:00 2001
From: Mustafa Ispir
Date: Thu, 16 Feb 2017 11:41:46 -0800
Subject: Integrate SyncReplicasOptimizer with Estimators.

There is a hidden dependency between when apply_gradients() and
get_chief_queue_runner() are called. This CL postpones creation of the
queue until the initialization of the Session. In Estimator, the Session
is created after the graph and training op are formed, that is, after
apply_gradients() has already been called.
Change: 147746938
---
 .../python/training/sync_replicas_optimizer.py    | 53 ++++++++++++++--------
 .../training/sync_replicas_optimizer_test.py      | 25 ++++
 2 files changed, 59 insertions(+), 19 deletions(-)

diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py
index 928aff64f9..3cee0b2f59 100644
--- a/tensorflow/python/training/sync_replicas_optimizer.py
+++ b/tensorflow/python/training/sync_replicas_optimizer.py
@@ -123,6 +123,13 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
     while not mon_sess.should_stop():
       mon_sess.run(training_op)
   ```
+
+  To use SyncReplicasOptimizer with an `Estimator`, pass the
+  sync_replicas_hook created above to the fit call:
+  ```
+  my_estimator = DNNClassifier(..., optimizer=opt)
+  my_estimator.fit(..., hooks=[sync_replicas_hook])
+  ```
   """
 
   def __init__(self,
@@ -418,34 +425,42 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
 
   def make_session_run_hook(self, is_chief, num_tokens=-1):
     """Creates a hook to handle SyncReplicasHook ops such as initialization."""
-    if is_chief:
-      return _SyncReplicasOptimizerHook(self.chief_init_op,
-                                        self.ready_for_local_init_op,
-                                        self.get_chief_queue_runner(),
-                                        self.get_init_tokens_op(num_tokens))
-
-    return _SyncReplicasOptimizerHook(self.local_step_init_op,
-                                      self.ready_for_local_init_op, None, None)
+    return _SyncReplicasOptimizerHook(self, is_chief, num_tokens)
 
 
 class _SyncReplicasOptimizerHook(session_run_hook.SessionRunHook):
   """A SessionRunHook handles ops related to SyncReplicasOptimizer."""
 
-  def __init__(self, local_init_op, ready_for_local_init_op, q_runner,
-               init_tokens_op):
+  def __init__(self, sync_optimizer, is_chief, num_tokens):
     """Creates hook to handle SyncReplicaOptimizer initialization ops.
 
     Args:
-      local_init_op: Either `SyncReplicasOptimizer.chief_init_op` or
-        `SyncReplicasOptimizer.local_step_init_op`.
-      ready_for_local_init_op: `SyncReplicasOptimizer.ready_for_local_init_op`
-      q_runner: Either `SyncReplicasOptimizer.get_chief_queue_runner` or `None`
-      init_tokens_op: `SyncReplicasOptimizer.get_init_tokens_op` or None
+      sync_optimizer: `SyncReplicasOptimizer` which this hook will initialize.
+      is_chief: `Bool`, whether this is a chief replica or not.
+      num_tokens: Number of tokens to add to the queue.
""" - self._local_init_op = local_init_op - self._ready_for_local_init_op = ready_for_local_init_op - self._q_runner = q_runner - self._init_tokens_op = init_tokens_op + self._sync_optimizer = sync_optimizer + self._is_chief = is_chief + self._num_tokens = num_tokens + + def begin(self): + if self._sync_optimizer._gradients_applied is False: # pylint: disable=protected-access + raise ValueError( + "SyncReplicasOptimizer.apply_gradient should be called before using " + "the hook.") + if self._is_chief: + self._local_init_op = self._sync_optimizer.chief_init_op + self._ready_for_local_init_op = ( + self._sync_optimizer.ready_for_local_init_op) + self._q_runner = self._sync_optimizer.get_chief_queue_runner() + self._init_tokens_op = self._sync_optimizer.get_init_tokens_op( + self._num_tokens) + else: + self._local_init_op = self._sync_optimizer.local_step_init_op + self._ready_for_local_init_op = ( + self._sync_optimizer.ready_for_local_init_op) + self._q_runner = None + self._init_tokens_op = None def after_create_session(self, session, coord): """Runs SyncReplicasOptimizer initialization ops.""" diff --git a/tensorflow/python/training/sync_replicas_optimizer_test.py b/tensorflow/python/training/sync_replicas_optimizer_test.py index 6da18391db..32cae70460 100644 --- a/tensorflow/python/training/sync_replicas_optimizer_test.py +++ b/tensorflow/python/training/sync_replicas_optimizer_test.py @@ -277,5 +277,30 @@ class SyncReplicasOptimizerTest(test.TestCase): sessions[1].run(var_1_g_1)) +class SyncReplicasOptimizerHookTest(test.TestCase): + + def testErrorIfUsedBeforeMinimizeCalled(self): + opt = training.SyncReplicasOptimizer( + opt=gradient_descent.GradientDescentOptimizer(1.0), + replicas_to_aggregate=1, + total_num_replicas=1) + hook = opt.make_session_run_hook(True) + with self.assertRaisesRegexp(ValueError, + "apply_gradient should be called"): + hook.begin() + + def testCanCreatedBeforeMinimizeCalled(self): + """This behavior is required to be integrated with Estimators.""" + opt = training.SyncReplicasOptimizer( + opt=gradient_descent.GradientDescentOptimizer(1.0), + replicas_to_aggregate=1, + total_num_replicas=1) + hook = opt.make_session_run_hook(True) + v = variables.Variable([0.]) + global_step = variables.Variable(0, name="global_step", trainable=False) + opt.minimize(v, global_step=global_step) + hook.begin() + + if __name__ == "__main__": test.main() -- cgit v1.2.3