Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10.py')
-rw-r--r--  tensorflow/models/image/cifar10/cifar10.py | 41
1 file changed, 30 insertions(+), 11 deletions(-)
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index 8941388002..4cd7e006c1 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -53,6 +53,8 @@ tf.app.flags.DEFINE_integer('batch_size', 128,
                             """Number of images to process in a batch.""")
 tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
                            """Path to the CIFAR-10 data directory.""")
+tf.app.flags.DEFINE_boolean('use_fp16', False,
+                            """Train the model using fp16.""")
 
 # Global constants describing the CIFAR-10 data set.
 IMAGE_SIZE = cifar10_input.IMAGE_SIZE
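For reference, the new flag plugs into the tf.app.flags machinery this file already uses; a minimal standalone sketch of the dtype selection it enables (the helper name below is illustrative, the patch inlines the expression at each call site):

import tensorflow as tf

tf.app.flags.DEFINE_boolean('use_fp16', False,
                            """Train the model using fp16.""")
FLAGS = tf.app.flags.FLAGS

def _training_dtype():
  # Illustrative helper: float16 when --use_fp16 is passed, float32 otherwise.
  return tf.float16 if FLAGS.use_fp16 else tf.float32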
@@ -105,7 +107,8 @@ def _variable_on_cpu(name, shape, initializer):
     Variable Tensor
   """
   with tf.device('/cpu:0'):
-    var = tf.get_variable(name, shape, initializer=initializer)
+    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
+    var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
   return var
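Assembled from the hunk above (signature taken from the hunk header), the patched helper reads:

def _variable_on_cpu(name, shape, initializer):
  with tf.device('/cpu:0'):
    # The variable's dtype now follows the --use_fp16 flag instead of
    # defaulting to float32.
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
  return var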
@@ -125,8 +128,11 @@ def _variable_with_weight_decay(name, shape, stddev, wd):
   Returns:
     Variable Tensor
   """
-  var = _variable_on_cpu(name, shape,
-                         tf.truncated_normal_initializer(stddev=stddev))
+  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
+  var = _variable_on_cpu(
+      name,
+      shape,
+      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
   if wd is not None:
     weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
     tf.add_to_collection('losses', weight_decay)
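Likewise for the weight-decay helper; a sketch of the patched body as shown in this hunk, with the trailing return implied by the docstring's Returns section:

def _variable_with_weight_decay(name, shape, stddev, wd):
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  var = _variable_on_cpu(
      name,
      shape,
      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
  if wd is not None:
    # tf.mul is the pre-1.0 name for what later became tf.multiply.
    weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var  # implied by the 'Returns: Variable Tensor' docstring above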
@@ -146,8 +152,12 @@ def distorted_inputs():
   if not FLAGS.data_dir:
     raise ValueError('Please supply a data_dir')
   data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
-  return cifar10_input.distorted_inputs(data_dir=data_dir,
-                                        batch_size=FLAGS.batch_size)
+  images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
+                                                  batch_size=FLAGS.batch_size)
+  if FLAGS.use_fp16:
+    images = tf.cast(images, tf.float16)
+    labels = tf.cast(labels, tf.float16)
+  return images, labels
 
 
 def inputs(eval_data):
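Both input functions now end with the same cast-after-load step; a hypothetical refactor (not part of this diff) that factors it out would look like:

def _maybe_cast_to_fp16(images, labels):
  # Hypothetical helper: cast the pipeline outputs only when --use_fp16 is set.
  if FLAGS.use_fp16:
    images = tf.cast(images, tf.float16)
    labels = tf.cast(labels, tf.float16)
  return images, labels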
@@ -166,8 +176,13 @@ def inputs(eval_data):
   if not FLAGS.data_dir:
     raise ValueError('Please supply a data_dir')
   data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
-  return cifar10_input.inputs(eval_data=eval_data, data_dir=data_dir,
-                              batch_size=FLAGS.batch_size)
+  images, labels = cifar10_input.inputs(eval_data=eval_data,
+                                        data_dir=data_dir,
+                                        batch_size=FLAGS.batch_size)
+  if FLAGS.use_fp16:
+    images = tf.cast(images, tf.float16)
+    labels = tf.cast(labels, tf.float16)
+  return images, labels
 
 
 def inference(images):
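A quick, illustrative way to check the new behavior from an interactive session (assumes the CIFAR-10 binaries have already been downloaded into FLAGS.data_dir):

FLAGS.use_fp16 = True              # same effect as passing --use_fp16=True
images, labels = inputs(eval_data=True)
print(images.dtype, labels.dtype)  # both should report float16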
@@ -186,8 +201,10 @@ def inference(images):
   #
   # conv1
   with tf.variable_scope('conv1') as scope:
-    kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64],
-                                         stddev=1e-4, wd=0.0)
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[5, 5, 3, 64],
+                                         stddev=5e-2,
+                                         wd=0.0)
     conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
     biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
     bias = tf.nn.bias_add(conv, biases)
@@ -203,8 +220,10 @@ def inference(images):
 
   # conv2
   with tf.variable_scope('conv2') as scope:
-    kernel = _variable_with_weight_decay('weights', shape=[5, 5, 64, 64],
-                                         stddev=1e-4, wd=0.0)
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[5, 5, 64, 64],
+                                         stddev=5e-2,
+                                         wd=0.0)
     conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
     biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
     bias = tf.nn.bias_add(conv, biases)
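Condensing the pattern this patch establishes (flag, dtype-aware variable creation, and a conv layer with the new 5e-2 initialization) into one self-contained, illustrative sketch; names mirror cifar10.py but the wiring is abbreviated and is not a drop-in replacement:

import tensorflow as tf

tf.app.flags.DEFINE_boolean('use_fp16', False,
                            """Train the model using fp16.""")
FLAGS = tf.app.flags.FLAGS

def _variable_on_cpu(name, shape, initializer):
  # Variables live on the CPU; their dtype follows the flag.
  with tf.device('/cpu:0'):
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    return tf.get_variable(name, shape, initializer=initializer, dtype=dtype)

def first_conv(images):
  # One conv block in the style of inference(): 5x5x3 input patch, 64 filters.
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  with tf.variable_scope('conv1'):
    kernel = _variable_on_cpu(
        'weights', [5, 5, 3, 64],
        tf.truncated_normal_initializer(stddev=5e-2, dtype=dtype))
    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
    return tf.nn.relu(tf.nn.bias_add(conv, biases))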