diff options
author | 2016-06-23 15:36:34 -0800 | |
---|---|---|
committer | 2016-06-23 16:47:39 -0700 | |
commit | bcb8f33b788bd02f1b6bdf813129d5668c6c4e5e (patch) | |
tree | 7a366b20c15e58bda5f3afc585e8890b8e3219e9 /tensorflow/models/image/cifar10/cifar10.py | |
parent | f40302caddb44249c665b9a66e27a33f06167264 (diff) |
Added an option to train cifar10 using fp16.
Also updated the weight initialization to ensure that the initial weights are likely to be greater than the smallest positive value representable using a 16-bit float (~0.00006).
Change: 125733195
Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10.py')
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10.py | 41 |
1 file changed, 30 insertions, 11 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index 8941388002..4cd7e006c1 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -53,6 +53,8 @@ tf.app.flags.DEFINE_integer('batch_size', 128,
                             """Number of images to process in a batch.""")
 tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
                            """Path to the CIFAR-10 data directory.""")
+tf.app.flags.DEFINE_boolean('use_fp16', False,
+                            """Train the model using fp16.""")
 
 # Global constants describing the CIFAR-10 data set.
 IMAGE_SIZE = cifar10_input.IMAGE_SIZE
@@ -105,7 +107,8 @@ def _variable_on_cpu(name, shape, initializer):
     Variable Tensor
   """
   with tf.device('/cpu:0'):
-    var = tf.get_variable(name, shape, initializer=initializer)
+    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
+    var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
   return var
 
 
@@ -125,8 +128,11 @@ def _variable_with_weight_decay(name, shape, stddev, wd):
   Returns:
     Variable Tensor
   """
-  var = _variable_on_cpu(name, shape,
-                         tf.truncated_normal_initializer(stddev=stddev))
+  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
+  var = _variable_on_cpu(
+      name,
+      shape,
+      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
   if wd is not None:
     weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
     tf.add_to_collection('losses', weight_decay)
@@ -146,8 +152,12 @@ def distorted_inputs():
   if not FLAGS.data_dir:
     raise ValueError('Please supply a data_dir')
   data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
-  return cifar10_input.distorted_inputs(data_dir=data_dir,
-                                        batch_size=FLAGS.batch_size)
+  images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
+                                                  batch_size=FLAGS.batch_size)
+  if FLAGS.use_fp16:
+    images = tf.cast(images, tf.float16)
+    labels = tf.cast(labels, tf.float16)
+  return images, labels
 
 
 def inputs(eval_data):
@@ -166,8 +176,13 @@ def inputs(eval_data):
   if not FLAGS.data_dir:
     raise ValueError('Please supply a data_dir')
   data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
-  return cifar10_input.inputs(eval_data=eval_data, data_dir=data_dir,
-                              batch_size=FLAGS.batch_size)
+  images, labels = cifar10_input.inputs(eval_data=eval_data,
+                                        data_dir=data_dir,
+                                        batch_size=FLAGS.batch_size)
+  if FLAGS.use_fp16:
+    images = tf.cast(images, tf.float16)
+    labels = tf.cast(labels, tf.float16)
+  return images, labels
 
 
 def inference(images):
@@ -186,8 +201,10 @@ def inference(images):
   #
   # conv1
   with tf.variable_scope('conv1') as scope:
-    kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64],
-                                         stddev=1e-4, wd=0.0)
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[5, 5, 3, 64],
+                                         stddev=5e-2,
+                                         wd=0.0)
     conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
     biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
     bias = tf.nn.bias_add(conv, biases)
@@ -203,8 +220,10 @@ def inference(images):
 
   # conv2
   with tf.variable_scope('conv2') as scope:
-    kernel = _variable_with_weight_decay('weights', shape=[5, 5, 64, 64],
-                                         stddev=1e-4, wd=0.0)
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[5, 5, 64, 64],
+                                         stddev=5e-2,
+                                         wd=0.0)
     conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
     biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
     bias = tf.nn.bias_add(conv, biases)