diff options
 tensorflow/models/image/cifar10/cifar10.py    | 41 +++++++++++++++++++--------
 tensorflow/python/training/moving_averages.py |  6 ++--
 2 files changed, 34 insertions(+), 13 deletions(-)
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py index 8941388002..4cd7e006c1 100644 --- a/tensorflow/models/image/cifar10/cifar10.py +++ b/tensorflow/models/image/cifar10/cifar10.py @@ -53,6 +53,8 @@ tf.app.flags.DEFINE_integer('batch_size', 128, """Number of images to process in a batch.""") tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data', """Path to the CIFAR-10 data directory.""") +tf.app.flags.DEFINE_boolean('use_fp16', False, + """Train the model using fp16.""") # Global constants describing the CIFAR-10 data set. IMAGE_SIZE = cifar10_input.IMAGE_SIZE @@ -105,7 +107,8 @@ def _variable_on_cpu(name, shape, initializer): Variable Tensor """ with tf.device('/cpu:0'): - var = tf.get_variable(name, shape, initializer=initializer) + dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 + var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype) return var @@ -125,8 +128,11 @@ def _variable_with_weight_decay(name, shape, stddev, wd): Returns: Variable Tensor """ - var = _variable_on_cpu(name, shape, - tf.truncated_normal_initializer(stddev=stddev)) + dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 + var = _variable_on_cpu( + name, + shape, + tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)) if wd is not None: weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss') tf.add_to_collection('losses', weight_decay) @@ -146,8 +152,12 @@ def distorted_inputs(): if not FLAGS.data_dir: raise ValueError('Please supply a data_dir') data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin') - return cifar10_input.distorted_inputs(data_dir=data_dir, - batch_size=FLAGS.batch_size) + images, labels = cifar10_input.distorted_inputs(data_dir=data_dir, + batch_size=FLAGS.batch_size) + if FLAGS.use_fp16: + images = tf.cast(images, tf.float16) + labels = tf.cast(labels, tf.float16) + return images, labels def inputs(eval_data): @@ -166,8 +176,13 @@ def inputs(eval_data): if not 
FLAGS.data_dir: raise ValueError('Please supply a data_dir') data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin') - return cifar10_input.inputs(eval_data=eval_data, data_dir=data_dir, - batch_size=FLAGS.batch_size) + images, labels = cifar10_input.inputs(eval_data=eval_data, + data_dir=data_dir, + batch_size=FLAGS.batch_size) + if FLAGS.use_fp16: + images = tf.cast(images, tf.float16) + labels = tf.cast(labels, tf.float16) + return images, labels def inference(images): @@ -186,8 +201,10 @@ def inference(images): # # conv1 with tf.variable_scope('conv1') as scope: - kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64], - stddev=1e-4, wd=0.0) + kernel = _variable_with_weight_decay('weights', + shape=[5, 5, 3, 64], + stddev=5e-2, + wd=0.0) conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME') biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0)) bias = tf.nn.bias_add(conv, biases) @@ -203,8 +220,10 @@ def inference(images): # conv2 with tf.variable_scope('conv2') as scope: - kernel = _variable_with_weight_decay('weights', shape=[5, 5, 64, 64], - stddev=1e-4, wd=0.0) + kernel = _variable_with_weight_decay('weights', + shape=[5, 5, 64, 64], + stddev=5e-2, + wd=0.0) conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME') biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1)) bias = tf.nn.bias_add(conv, biases) diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 6ecd20ad48..9cc1d917ab 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -266,8 +266,10 @@ class ExponentialMovingAverage(object): if var_list is None: var_list = variables.trainable_variables() for var in var_list: - if var.dtype.base_dtype not in [dtypes.float32, dtypes.float64]: - raise TypeError("The variables must be float or double: %s" % var.name) + if var.dtype.base_dtype not in [dtypes.float16, 
dtypes.float32, + dtypes.float64]: + raise TypeError("The variables must be half, float, or double: %s" + var.name) if var in self._averages: raise ValueError("Moving average already computed for: %s" % var.name)