diff options
author | Mustafa Ispir <ispir@google.com> | 2016-11-23 12:59:05 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-23 13:05:33 -0800 |
commit | 45009de073a203f54bff7cec0ea83f88c48237b9 (patch) | |
tree | 07309a0565e872efca9d6ff3f43418c6ece4a6a4 | |
parent | b793cfd8ed0675f77a710bd3b98001d15974ee25 (diff) |
Introduce MonitoredSession into the examples. cl #1
Change: 140063357
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10_train.py | 82 |
1 files changed, 34 insertions, 48 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py
index 7e54c74c72..45c0bbd9f0 100644
--- a/tensorflow/models/image/cifar10/cifar10_train.py
+++ b/tensorflow/models/image/cifar10/cifar10_train.py
@@ -37,11 +37,8 @@ from __future__ import division
 from __future__ import print_function
 
 from datetime import datetime
-import os.path
 import time
 
-import numpy as np
-from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
 from tensorflow.models.image.cifar10 import cifar10
@@ -60,7 +57,7 @@ tf.app.flags.DEFINE_boolean('log_device_placement', False,
 def train():
   """Train CIFAR-10 for a number of steps."""
   with tf.Graph().as_default():
-    global_step = tf.Variable(0, trainable=False)
+    global_step = tf.contrib.framework.get_or_create_global_step()
 
     # Get images and labels for CIFAR-10.
     images, labels = cifar10.distorted_inputs()
@@ -76,50 +73,39 @@ def train():
     # updates the model parameters.
     train_op = cifar10.train(loss, global_step)
 
-    # Create a saver.
-    saver = tf.train.Saver(tf.all_variables())
-
-    # Build the summary operation based on the TF collection of Summaries.
-    summary_op = tf.merge_all_summaries()
-
-    # Build an initialization operation to run below.
-    init = tf.global_variables_initializer()
-
-    # Start running operations on the Graph.
-    sess = tf.Session(config=tf.ConfigProto(
-        log_device_placement=FLAGS.log_device_placement))
-    sess.run(init)
-
-    # Start the queue runners.
-    tf.train.start_queue_runners(sess=sess)
-
-    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
-
-    for step in xrange(FLAGS.max_steps):
-      start_time = time.time()
-      _, loss_value = sess.run([train_op, loss])
-      duration = time.time() - start_time
-
-      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
-
-      if step % 10 == 0:
-        num_examples_per_step = FLAGS.batch_size
-        examples_per_sec = num_examples_per_step / duration
-        sec_per_batch = float(duration)
-
-        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
-                      'sec/batch)')
-        print (format_str % (datetime.now(), step, loss_value,
-                             examples_per_sec, sec_per_batch))
-
-      if step % 100 == 0:
-        summary_str = sess.run(summary_op)
-        summary_writer.add_summary(summary_str, step)
-
-      # Save the model checkpoint periodically.
-      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
-        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
-        saver.save(sess, checkpoint_path, global_step=step)
+    class _LoggerHook(tf.train.SessionRunHook):
+      """Logs loss and runtime."""
+
+      def begin(self):
+        self._step = -1
+
+      def before_run(self, run_context):
+        self._step += 1
+        self._start_time = time.time()
+        return tf.train.SessionRunArgs(loss)  # Asks for loss value.
+
+      def after_run(self, run_context, run_values):
+        duration = time.time() - self._start_time
+        loss_value = run_values.results
+        if self._step % 10 == 0:
+          num_examples_per_step = FLAGS.batch_size
+          examples_per_sec = num_examples_per_step / duration
+          sec_per_batch = float(duration)
+
+          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+                        'sec/batch)')
+          print (format_str % (datetime.now(), self._step, loss_value,
+                               examples_per_sec, sec_per_batch))
+
+    with tf.train.MonitoredTrainingSession(
+        checkpoint_dir=FLAGS.train_dir,
+        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
+               tf.train.NanTensorHook(loss),
+               _LoggerHook()],
+        config=tf.ConfigProto(
+            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
+      while not mon_sess.should_stop():
+        mon_sess.run(train_op)
 
 
 def main(argv=None):  # pylint: disable=unused-argument