about summary refs log tree commit diff homepage
path: root/tensorflow/models/image/cifar10/cifar10.py
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2016-06-23 15:36:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-23 16:47:39 -0700
commitbcb8f33b788bd02f1b6bdf813129d5668c6c4e5e (patch)
tree7a366b20c15e58bda5f3afc585e8890b8e3219e9 /tensorflow/models/image/cifar10/cifar10.py
parentf40302caddb44249c665b9a66e27a33f06167264 (diff)
Added an option to train cifar10 using fp16.
Also updated the weight initialization to ensure that the initial weights are likely to be greater than the smallest positive value representable using a 16-bit float (~0.00006). Change: 125733195
Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10.py')
-rw-r--r--tensorflow/models/image/cifar10/cifar10.py41
1 file changed, 30 insertions, 11 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index 8941388002..4cd7e006c1 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -53,6 +53,8 @@ tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
"""Path to the CIFAR-10 data directory.""")
+tf.app.flags.DEFINE_boolean('use_fp16', False,
+ """Train the model using fp16.""")
# Global constants describing the CIFAR-10 data set.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
@@ -105,7 +107,8 @@ def _variable_on_cpu(name, shape, initializer):
Variable Tensor
"""
with tf.device('/cpu:0'):
- var = tf.get_variable(name, shape, initializer=initializer)
+ dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
+ var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
return var
@@ -125,8 +128,11 @@ def _variable_with_weight_decay(name, shape, stddev, wd):
Returns:
Variable Tensor
"""
- var = _variable_on_cpu(name, shape,
- tf.truncated_normal_initializer(stddev=stddev))
+ dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
+ var = _variable_on_cpu(
+ name,
+ shape,
+ tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
if wd is not None:
weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
tf.add_to_collection('losses', weight_decay)
@@ -146,8 +152,12 @@ def distorted_inputs():
if not FLAGS.data_dir:
raise ValueError('Please supply a data_dir')
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
- return cifar10_input.distorted_inputs(data_dir=data_dir,
- batch_size=FLAGS.batch_size)
+ images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
+ batch_size=FLAGS.batch_size)
+ if FLAGS.use_fp16:
+ images = tf.cast(images, tf.float16)
+ labels = tf.cast(labels, tf.float16)
+ return images, labels
def inputs(eval_data):
@@ -166,8 +176,13 @@ def inputs(eval_data):
if not FLAGS.data_dir:
raise ValueError('Please supply a data_dir')
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
- return cifar10_input.inputs(eval_data=eval_data, data_dir=data_dir,
- batch_size=FLAGS.batch_size)
+ images, labels = cifar10_input.inputs(eval_data=eval_data,
+ data_dir=data_dir,
+ batch_size=FLAGS.batch_size)
+ if FLAGS.use_fp16:
+ images = tf.cast(images, tf.float16)
+ labels = tf.cast(labels, tf.float16)
+ return images, labels
def inference(images):
@@ -186,8 +201,10 @@ def inference(images):
#
# conv1
with tf.variable_scope('conv1') as scope:
- kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64],
- stddev=1e-4, wd=0.0)
+ kernel = _variable_with_weight_decay('weights',
+ shape=[5, 5, 3, 64],
+ stddev=5e-2,
+ wd=0.0)
conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
bias = tf.nn.bias_add(conv, biases)
@@ -203,8 +220,10 @@ def inference(images):
# conv2
with tf.variable_scope('conv2') as scope:
- kernel = _variable_with_weight_decay('weights', shape=[5, 5, 64, 64],
- stddev=1e-4, wd=0.0)
+ kernel = _variable_with_weight_decay('weights',
+ shape=[5, 5, 64, 64],
+ stddev=5e-2,
+ wd=0.0)
conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
bias = tf.nn.bias_add(conv, biases)