author    Mustafa Ispir <ispir@google.com>               2017-02-06 14:34:09 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-02-06 14:45:09 -0800
commit    7d85717cb148ddb5fa4c2ee1686dc78e26866079 (patch)
tree      02b83e63952763c3479d764779721b74319fa479 /tensorflow/g3doc
parent    e7702046a3a5ca1906455730d9b5f270fb170af3 (diff)
Use MonitoredTrainingSession in distributed how-to.
Change: 146712044
Diffstat (limited to 'tensorflow/g3doc')
-rw-r--r--  tensorflow/g3doc/how_tos/distributed/index.md | 37
1 file changed, 14 insertions(+), 23 deletions(-)
diff --git a/tensorflow/g3doc/how_tos/distributed/index.md b/tensorflow/g3doc/how_tos/distributed/index.md
index 880976ca8d..961b142170 100644
--- a/tensorflow/g3doc/how_tos/distributed/index.md
+++ b/tensorflow/g3doc/how_tos/distributed/index.md
@@ -213,37 +213,28 @@ def main(_):
# Build model...
loss = ...
- global_step = tf.Variable(0)
+ global_step = tf.contrib.framework.get_or_create_global_step()
train_op = tf.train.AdagradOptimizer(0.01).minimize(
loss, global_step=global_step)
- saver = tf.train.Saver()
- summary_op = tf.summary.merge_all()
- init_op = tf.global_variables_initializer()
-
- # Create a "supervisor", which oversees the training process.
- sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
- logdir="/tmp/train_logs",
- init_op=init_op,
- summary_op=summary_op,
- saver=saver,
- global_step=global_step,
- save_model_secs=600)
-
- # The supervisor takes care of session initialization, restoring from
- # a checkpoint, and closing when done or an error occurs.
- with sv.managed_session(server.target) as sess:
- # Loop until the supervisor shuts down or 1000000 steps have completed.
- step = 0
- while not sv.should_stop() and step < 1000000:
+ # The StopAtStepHook handles stopping after running a given number of steps.
+ hooks = [tf.train.StopAtStepHook(last_step=1000000)]
+
+ # The MonitoredTrainingSession takes care of session initialization,
+ # restoring from a checkpoint, saving to a checkpoint, and closing when done
+ # or an error occurs.
+ with tf.train.MonitoredTrainingSession(master=server.target,
+ is_chief=(FLAGS.task_index == 0),
+ checkpoint_dir="/tmp/train_logs",
+ hooks=hooks) as mon_sess:
+ while not mon_sess.should_stop():
# Run a training step asynchronously.
# See `tf.train.SyncReplicasOptimizer` for additional details on how to
# perform *synchronous* training.
- _, step = sess.run([train_op, global_step])
+ # mon_sess.run handles AbortedError in case of a preempted PS.
+ mon_sess.run(train_op)
- # Ask for all the services to stop.
- sv.stop()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
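
The hunk above shows only fragments of the updated how-to. For reference, below is a minimal self-contained sketch of the new MonitoredTrainingSession pattern from this commit. It is an illustration, not part of the diff: it substitutes a trivial stand-in model for the how-to's loss and runs locally, so the master, is_chief, and cluster flags from the distributed example are left at their defaults, and it stops after 100 steps instead of 1000000 (TF 1.x-era APIs, matching the commit).

import tensorflow as tf

def main(_):
  # Stand-in model: drive a two-element variable toward zero.
  weights = tf.Variable([1.0, 2.0])
  loss = tf.reduce_sum(tf.square(weights))

  global_step = tf.contrib.framework.get_or_create_global_step()
  train_op = tf.train.AdagradOptimizer(0.01).minimize(
      loss, global_step=global_step)

  # The StopAtStepHook ends training once global_step reaches last_step.
  hooks = [tf.train.StopAtStepHook(last_step=100)]

  # MonitoredTrainingSession initializes variables, restores from and saves
  # to checkpoint_dir, and closes the session when done or on error.
  with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/train_logs",
                                         hooks=hooks) as mon_sess:
    while not mon_sess.should_stop():
      mon_sess.run(train_op)

if __name__ == "__main__":
  tf.app.run(main=main)

In the distributed setting of the how-to, the same block is used with master=server.target and is_chief=(FLAGS.task_index == 0), which is exactly what the + lines in the diff add.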