aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Mustafa Ispir <ispir@google.com> 2016-11-23 12:59:05 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-11-23 13:05:33 -0800
commit: 45009de073a203f54bff7cec0ea83f88c48237b9 (patch)
tree: 07309a0565e872efca9d6ff3f43418c6ece4a6a4
parent: b793cfd8ed0675f77a710bd3b98001d15974ee25 (diff)
Introduce MonitoredSession into the examples. cl #1
Change: 140063357
-rw-r--r-- tensorflow/models/image/cifar10/cifar10_train.py | 82
1 files changed, 34 insertions, 48 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py
index 7e54c74c72..45c0bbd9f0 100644
--- a/tensorflow/models/image/cifar10/cifar10_train.py
+++ b/tensorflow/models/image/cifar10/cifar10_train.py
@@ -37,11 +37,8 @@ from __future__ import division
from __future__ import print_function
from datetime import datetime
-import os.path
import time
-import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
@@ -60,7 +57,7 @@ tf.app.flags.DEFINE_boolean('log_device_placement', False,
def train():
"""Train CIFAR-10 for a number of steps."""
with tf.Graph().as_default():
- global_step = tf.Variable(0, trainable=False)
+ global_step = tf.contrib.framework.get_or_create_global_step()
# Get images and labels for CIFAR-10.
images, labels = cifar10.distorted_inputs()
@@ -76,50 +73,39 @@ def train():
# updates the model parameters.
train_op = cifar10.train(loss, global_step)
- # Create a saver.
- saver = tf.train.Saver(tf.all_variables())
-
- # Build the summary operation based on the TF collection of Summaries.
- summary_op = tf.merge_all_summaries()
-
- # Build an initialization operation to run below.
- init = tf.global_variables_initializer()
-
- # Start running operations on the Graph.
- sess = tf.Session(config=tf.ConfigProto(
- log_device_placement=FLAGS.log_device_placement))
- sess.run(init)
-
- # Start the queue runners.
- tf.train.start_queue_runners(sess=sess)
-
- summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
-
- for step in xrange(FLAGS.max_steps):
- start_time = time.time()
- _, loss_value = sess.run([train_op, loss])
- duration = time.time() - start_time
-
- assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
-
- if step % 10 == 0:
- num_examples_per_step = FLAGS.batch_size
- examples_per_sec = num_examples_per_step / duration
- sec_per_batch = float(duration)
-
- format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
- 'sec/batch)')
- print (format_str % (datetime.now(), step, loss_value,
- examples_per_sec, sec_per_batch))
-
- if step % 100 == 0:
- summary_str = sess.run(summary_op)
- summary_writer.add_summary(summary_str, step)
-
- # Save the model checkpoint periodically.
- if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
- checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
- saver.save(sess, checkpoint_path, global_step=step)
+ class _LoggerHook(tf.train.SessionRunHook):
+ """Logs loss and runtime."""
+
+ def begin(self):
+ self._step = -1
+
+ def before_run(self, run_context):
+ self._step += 1
+ self._start_time = time.time()
+ return tf.train.SessionRunArgs(loss) # Asks for loss value.
+
+ def after_run(self, run_context, run_values):
+ duration = time.time() - self._start_time
+ loss_value = run_values.results
+ if self._step % 10 == 0:
+ num_examples_per_step = FLAGS.batch_size
+ examples_per_sec = num_examples_per_step / duration
+ sec_per_batch = float(duration)
+
+ format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+ 'sec/batch)')
+ print (format_str % (datetime.now(), self._step, loss_value,
+ examples_per_sec, sec_per_batch))
+
+ with tf.train.MonitoredTrainingSession(
+ checkpoint_dir=FLAGS.train_dir,
+ hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
+ tf.train.NanTensorHook(loss),
+ _LoggerHook()],
+ config=tf.ConfigProto(
+ log_device_placement=FLAGS.log_device_placement)) as mon_sess:
+ while not mon_sess.should_stop():
+ mon_sess.run(train_op)
def main(argv=None): # pylint: disable=unused-argument