author    Mustafa Ispir <ispir@google.com>    2017-02-16 11:41:46 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-02-16 12:08:38 -0800
commit    4903865bb5f2a42cfd5b2ecfb11c37bed28edcc8 (patch)
tree      81b0699c18be342a433c2805e2dfa8ef084cc7e0
parent    6640d3f3de88a3f3ade8ec6e5e4540e545024f87 (diff)
Integrate SyncReplicasOptimizer with Estimators.
There is a hidden ordering dependency between when apply_gradients() and get_chief_queue_runner() are called: the queue runner must be created after the gradients have been applied. This CL postpones creation of the queue runner to the initialization of the Session. In Estimator, the Session is created after the graph and training op have been formed, which means after apply_gradients() has already been called.

Change: 147746938
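The new ordering can be exercised directly. Below is a minimal sketch, adapted from the test added in this change (it assumes a TF 1.x environment where `tf.train.SyncReplicasOptimizer` is available):
```
import tensorflow as tf

# Wrap a base optimizer; gradient updates are aggregated across replicas.
opt = tf.train.SyncReplicasOptimizer(
    opt=tf.train.GradientDescentOptimizer(1.0),
    replicas_to_aggregate=1,
    total_num_replicas=1)

# The hook can now be created *before* apply_gradients: it only stores a
# reference to the optimizer and resolves the init ops lazily in begin().
hook = opt.make_session_run_hook(is_chief=True)

v = tf.Variable([0.])
global_step = tf.Variable(0, name="global_step", trainable=False)
opt.minimize(v, global_step=global_step)  # apply_gradients happens here

hook.begin()  # chief_init_op / queue runner are looked up only at this point
```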
-rw-r--r--  tensorflow/python/training/sync_replicas_optimizer.py       | 53
-rw-r--r--  tensorflow/python/training/sync_replicas_optimizer_test.py  | 25
2 files changed, 59 insertions(+), 19 deletions(-)
diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py
index 928aff64f9..3cee0b2f59 100644
--- a/tensorflow/python/training/sync_replicas_optimizer.py
+++ b/tensorflow/python/training/sync_replicas_optimizer.py
@@ -123,6 +123,13 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
while not mon_sess.should_stop():
mon_sess.run(training_op)
```
+
+  To use SyncReplicasOptimizer with an `Estimator`, you need to pass
+  sync_replicas_hook while calling fit:
+  ```
+  sync_replicas_hook = opt.make_session_run_hook(is_chief)
+  my_estimator = DNNClassifier(..., optimizer=opt)
+  my_estimator.fit(..., hooks=[sync_replicas_hook])
+  ```
"""
def __init__(self,
@@ -418,34 +425,42 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
def make_session_run_hook(self, is_chief, num_tokens=-1):
"""Creates a hook to handle SyncReplicasHook ops such as initialization."""
- if is_chief:
- return _SyncReplicasOptimizerHook(self.chief_init_op,
- self.ready_for_local_init_op,
- self.get_chief_queue_runner(),
- self.get_init_tokens_op(num_tokens))
-
- return _SyncReplicasOptimizerHook(self.local_step_init_op,
- self.ready_for_local_init_op, None, None)
+ return _SyncReplicasOptimizerHook(self, is_chief, num_tokens)
class _SyncReplicasOptimizerHook(session_run_hook.SessionRunHook):
"""A SessionRunHook handles ops related to SyncReplicasOptimizer."""
- def __init__(self, local_init_op, ready_for_local_init_op, q_runner,
- init_tokens_op):
+ def __init__(self, sync_optimizer, is_chief, num_tokens):
"""Creates hook to handle SyncReplicaOptimizer initialization ops.
Args:
- local_init_op: Either `SyncReplicasOptimizer.chief_init_op` or
- `SyncReplicasOptimizer.local_step_init_op`.
- ready_for_local_init_op: `SyncReplicasOptimizer.ready_for_local_init_op`
- q_runner: Either `SyncReplicasOptimizer.get_chief_queue_runner` or `None`
- init_tokens_op: `SyncReplicasOptimizer.get_init_tokens_op` or None
+ sync_optimizer: `SyncReplicasOptimizer` which this hook will initialize.
+      is_chief: `Bool`, whether this is a chief replica or not.
+ num_tokens: Number of tokens to add to the queue.
"""
- self._local_init_op = local_init_op
- self._ready_for_local_init_op = ready_for_local_init_op
- self._q_runner = q_runner
- self._init_tokens_op = init_tokens_op
+ self._sync_optimizer = sync_optimizer
+ self._is_chief = is_chief
+ self._num_tokens = num_tokens
+
+ def begin(self):
+ if self._sync_optimizer._gradients_applied is False: # pylint: disable=protected-access
+ raise ValueError(
+ "SyncReplicasOptimizer.apply_gradient should be called before using "
+ "the hook.")
+ if self._is_chief:
+ self._local_init_op = self._sync_optimizer.chief_init_op
+ self._ready_for_local_init_op = (
+ self._sync_optimizer.ready_for_local_init_op)
+ self._q_runner = self._sync_optimizer.get_chief_queue_runner()
+ self._init_tokens_op = self._sync_optimizer.get_init_tokens_op(
+ self._num_tokens)
+ else:
+ self._local_init_op = self._sync_optimizer.local_step_init_op
+ self._ready_for_local_init_op = (
+ self._sync_optimizer.ready_for_local_init_op)
+ self._q_runner = None
+ self._init_tokens_op = None
def after_create_session(self, session, coord):
"""Runs SyncReplicasOptimizer initialization ops."""
diff --git a/tensorflow/python/training/sync_replicas_optimizer_test.py b/tensorflow/python/training/sync_replicas_optimizer_test.py
index 6da18391db..32cae70460 100644
--- a/tensorflow/python/training/sync_replicas_optimizer_test.py
+++ b/tensorflow/python/training/sync_replicas_optimizer_test.py
@@ -277,5 +277,30 @@ class SyncReplicasOptimizerTest(test.TestCase):
sessions[1].run(var_1_g_1))
+class SyncReplicasOptimizerHookTest(test.TestCase):
+
+ def testErrorIfUsedBeforeMinimizeCalled(self):
+ opt = training.SyncReplicasOptimizer(
+ opt=gradient_descent.GradientDescentOptimizer(1.0),
+ replicas_to_aggregate=1,
+ total_num_replicas=1)
+ hook = opt.make_session_run_hook(True)
+ with self.assertRaisesRegexp(ValueError,
+ "apply_gradient should be called"):
+ hook.begin()
+
+ def testCanCreatedBeforeMinimizeCalled(self):
+ """This behavior is required to be integrated with Estimators."""
+ opt = training.SyncReplicasOptimizer(
+ opt=gradient_descent.GradientDescentOptimizer(1.0),
+ replicas_to_aggregate=1,
+ total_num_replicas=1)
+ hook = opt.make_session_run_hook(True)
+ v = variables.Variable([0.])
+ global_step = variables.Variable(0, name="global_step", trainable=False)
+ opt.minimize(v, global_step=global_step)
+ hook.begin()
+
+
if __name__ == "__main__":
test.main()