diff options
author | Mustafa Ispir <ispir@google.com> | 2017-02-06 14:34:09 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-06 14:45:09 -0800 |
commit | 7d85717cb148ddb5fa4c2ee1686dc78e26866079 (patch) | |
tree | 02b83e63952763c3479d764779721b74319fa479 /tensorflow/g3doc | |
parent | e7702046a3a5ca1906455730d9b5f270fb170af3 (diff) |
Use MonitoredTrainingSession in distributed how-to.
Change: 146712044
Diffstat (limited to 'tensorflow/g3doc')
-rw-r--r-- | tensorflow/g3doc/how_tos/distributed/index.md | 37 |
1 file changed, 14 insertions, 23 deletions
```diff
diff --git a/tensorflow/g3doc/how_tos/distributed/index.md b/tensorflow/g3doc/how_tos/distributed/index.md
index 880976ca8d..961b142170 100644
--- a/tensorflow/g3doc/how_tos/distributed/index.md
+++ b/tensorflow/g3doc/how_tos/distributed/index.md
@@ -213,37 +213,28 @@ def main(_):
 
       # Build model...
       loss = ...
-      global_step = tf.Variable(0)
+      global_step = tf.contrib.framework.get_or_create_global_step()
 
       train_op = tf.train.AdagradOptimizer(0.01).minimize(
           loss, global_step=global_step)
 
-      saver = tf.train.Saver()
-      summary_op = tf.summary.merge_all()
-      init_op = tf.global_variables_initializer()
-
-    # Create a "supervisor", which oversees the training process.
-    sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
-                             logdir="/tmp/train_logs",
-                             init_op=init_op,
-                             summary_op=summary_op,
-                             saver=saver,
-                             global_step=global_step,
-                             save_model_secs=600)
-
-    # The supervisor takes care of session initialization, restoring from
-    # a checkpoint, and closing when done or an error occurs.
-    with sv.managed_session(server.target) as sess:
-      # Loop until the supervisor shuts down or 1000000 steps have completed.
-      step = 0
-      while not sv.should_stop() and step < 1000000:
+    # The StopAtStepHook handles stopping after running given steps.
+    hooks=[tf.train.StopAtStepHook(last_step=1000000)]
+
+    # The MonitoredTrainingSession takes care of session initialization,
+    # restoring from a checkpoint, saving to a checkpoint, and closing when done
+    # or an error occurs.
+    with tf.train.MonitoredTrainingSession(master=server.target,
+                                           is_chief=(FLAGS.task_index == 0),
+                                           checkpoint_dir="/tmp/train_logs",
+                                           hooks=hooks) as mon_sess:
+      while not mon_sess.should_stop():
         # Run a training step asynchronously.
         # See `tf.train.SyncReplicasOptimizer` for additional details on how to
         # perform *synchronous* training.
-        _, step = sess.run([train_op, global_step])
+        # mon_sess.run handles AbortedError in case of preempted PS.
+        mon_sess.run(train_op)
 
-    # Ask for all the services to stop.
-    sv.stop()
 
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
```