author    Benoit Steiner <benoit.steiner.goog@gmail.com>    2016-06-06 16:41:30 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-06-06 17:49:11 -0700
commit    730d267164366ff44a6dc8302dfc3b5339791f0b (patch)
tree      b6eb936e237e44ab8ce324b37a51a76f9463af9b
parent    a00e5709b06050c57d431f8a9abf157f13a52ce3 (diff)
Added an option to train the example mnist model using 16 bit floats
Change: 124198415
-rw-r--r--  tensorflow/models/image/mnist/convolutional.py  48
1 file changed, 29 insertions(+), 19 deletions(-)
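
With this patch applied, half precision can be switched on from the command
line; assuming the usual tf.app.flags boolean syntax, an invocation such as
python tensorflow/models/image/mnist/convolutional.py --use_fp16=true builds
the whole graph with tf.float16 activations, weights, and placeholders, while
the default remains tf.float32.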
diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py
index 95e5347c62..1893e68121 100644
--- a/tensorflow/models/image/mnist/convolutional.py
+++ b/tensorflow/models/image/mnist/convolutional.py
@@ -48,9 +48,19 @@ EVAL_FREQUENCY = 100  # Number of steps between evaluations.
 
 
 tf.app.flags.DEFINE_boolean("self_test", False, "True if running a self test.")
+tf.app.flags.DEFINE_boolean('use_fp16', False,
+                            "Use half floats instead of full floats if True.")
 FLAGS = tf.app.flags.FLAGS
 
 
+def data_type():
+  """Return the type of the activations, weights, and placeholder variables."""
+  if FLAGS.use_fp16:
+    return tf.float16
+  else:
+    return tf.float32
+
+
 def maybe_download(filename):
   """Download the data from Yann's website, unless it's already here."""
   if not tf.gfile.Exists(WORK_DIRECTORY):
@@ -142,11 +152,11 @@ def main(argv=None):  # pylint: disable=unused-argument
   # These placeholder nodes will be fed a batch of training data at each
   # training step using the {feed_dict} argument to the Run() call below.
   train_data_node = tf.placeholder(
-      tf.float32,
+      data_type(),
       shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
   train_labels_node = tf.placeholder(tf.int64, shape=(BATCH_SIZE,))
   eval_data = tf.placeholder(
-      tf.float32,
+      data_type(),
       shape=(EVAL_BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
 
   # The variables below hold all the trainable weights. They are passed an
@@ -155,24 +165,24 @@ def main(argv=None):  # pylint: disable=unused-argument
   conv1_weights = tf.Variable(
       tf.truncated_normal([5, 5, NUM_CHANNELS, 32],  # 5x5 filter, depth 32.
                           stddev=0.1,
-                          seed=SEED))
-  conv1_biases = tf.Variable(tf.zeros([32]))
-  conv2_weights = tf.Variable(
-      tf.truncated_normal([5, 5, 32, 64],
-                          stddev=0.1,
-                          seed=SEED))
-  conv2_biases = tf.Variable(tf.constant(0.1, shape=[64]))
+                          seed=SEED, dtype=data_type()))
+  conv1_biases = tf.Variable(tf.zeros([32], dtype=data_type()))
+  conv2_weights = tf.Variable(tf.truncated_normal(
+      [5, 5, 32, 64], stddev=0.1,
+      seed=SEED, dtype=data_type()))
+  conv2_biases = tf.Variable(tf.constant(0.1, shape=[64], dtype=data_type()))
   fc1_weights = tf.Variable(  # fully connected, depth 512.
-      tf.truncated_normal(
-          [IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512],
-          stddev=0.1,
-          seed=SEED))
-  fc1_biases = tf.Variable(tf.constant(0.1, shape=[512]))
-  fc2_weights = tf.Variable(
-      tf.truncated_normal([512, NUM_LABELS],
+      tf.truncated_normal([IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512],
                           stddev=0.1,
-                          seed=SEED))
-  fc2_biases = tf.Variable(tf.constant(0.1, shape=[NUM_LABELS]))
+                          seed=SEED,
+                          dtype=data_type()))
+  fc1_biases = tf.Variable(tf.constant(0.1, shape=[512], dtype=data_type()))
+  fc2_weights = tf.Variable(tf.truncated_normal([512, NUM_LABELS],
+                                                stddev=0.1,
+                                                seed=SEED,
+                                                dtype=data_type()))
+  fc2_biases = tf.Variable(tf.constant(
+      0.1, shape=[NUM_LABELS], dtype=data_type()))
 
   # We will replicate the model structure for the training subgraph, as well
   # as the evaluation subgraphs, while sharing the trainable parameters.
@@ -230,7 +240,7 @@ def main(argv=None):  # pylint: disable=unused-argument
 
   # Optimizer: set up a variable that's incremented once per batch and
   # controls the learning rate decay.
-  batch = tf.Variable(0)
+  batch = tf.Variable(0, dtype=data_type())
   # Decay once per epoch, using an exponential schedule starting at 0.01.
   learning_rate = tf.train.exponential_decay(
       0.01,                # Base learning rate.
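
For reference, here is the dtype-selection pattern the commit introduces,
collapsed into one place. This is a minimal sketch, assuming the TF 0.x
tf.app.flags API used in this file; the conv_weights/conv_biases names are
illustrative, not from the patch:

    import tensorflow as tf

    # Same flag as in the patch: selects half floats for the whole graph.
    tf.app.flags.DEFINE_boolean('use_fp16', False,
                                "Use half floats instead of full floats if True.")
    FLAGS = tf.app.flags.FLAGS


    def data_type():
      """Return the type of the activations, weights, and placeholder variables."""
      return tf.float16 if FLAGS.use_fp16 else tf.float32


    # Every trainable tensor is then created with this dtype, e.g. a 5x5
    # convolution kernel and its bias vector (hypothetical names):
    conv_weights = tf.Variable(
        tf.truncated_normal([5, 5, 1, 32], stddev=0.1, dtype=data_type()))
    conv_biases = tf.Variable(tf.zeros([32], dtype=data_type()))

Note that the patch also gives the step counter batch the selected dtype, so
the exponential learning-rate decay is computed in half precision when
use_fp16 is set.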