author      2016-04-20 09:34:53 -0800
committer   2016-04-20 10:42:25 -0700
commit      36c0475865ec103b5a3c2de4a69b17eddf7a9903 (patch)
tree        812d2621008955ddb9ac7f140968f8de1574717e /tensorflow/python/training/optimizer_test.py
parent      670a906496f8c5bcd4222f5a54b29f9ec13871f1 (diff)
Enable fp16 support for all optimizers, and add unit tests for all that
have it (i.e., everything except sync_replicas_optimizer). This is mostly a
matter of adding the right casts for all the constant tensors, so that
clients do not need to set types for them explicitly.
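
For illustration, here is a minimal sketch of the cast pattern described above. The helper name _cast_hyper is hypothetical, not taken from this change; the idea is simply that a Python-float hyperparameter is converted to a tensor and then cast to the dtype of the variable being updated:

    import tensorflow as tf

    def _cast_hyper(value, var):
      # Hypothetical helper (a sketch, not the code from this change):
      # convert a Python-float hyperparameter (e.g. a learning rate) to a
      # tensor, then cast it to the variable's dtype, so fp16 variables get
      # fp16 update math without the caller setting types explicitly.
      t = tf.convert_to_tensor(value, name='hyper')
      return tf.cast(t, var.dtype.base_dtype)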
Also add a helper function assertAllCloseAccordingToType() that compares
with a larger epsilon for fp16 (adapted from training_ops_test.py and
somewhat improved).
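
A minimal sketch of what such a helper could look like (the tolerance values below are illustrative assumptions, not the ones used in the actual helper):

    import numpy as np

    def assert_all_close_according_to_type(expected, actual,
                                           rtol=1e-6, atol=1e-6,
                                           half_rtol=1e-3, half_atol=1e-3):
      # Sketch: widen the tolerances when either side is fp16, since half
      # precision only carries about 3 decimal digits of significand.
      expected = np.asarray(expected)
      actual = np.asarray(actual)
      if expected.dtype == np.float16 or actual.dtype == np.float16:
        rtol, atol = max(rtol, half_rtol), max(atol, half_atol)
      np.testing.assert_allclose(expected, actual, rtol=rtol, atol=atol)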
Change: 120351448
Diffstat (limited to 'tensorflow/python/training/optimizer_test.py')
-rw-r--r--  tensorflow/python/training/optimizer_test.py  108
1 file changed, 57 insertions(+), 51 deletions(-)
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index 174df4d654..e602de17b2 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -23,65 +23,71 @@ import tensorflow as tf
 class OptimizerTest(tf.test.TestCase):
 
   def testBasic(self):
-    with self.test_session():
-      var0 = tf.Variable([1.0, 2.0])
-      var1 = tf.Variable([3.0, 4.0])
-      cost = 5 * var0 + 3 * var1
-      global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
-      sgd_op = tf.train.GradientDescentOptimizer(3.0)
-      opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+    for dtype in [tf.half, tf.float32]:
+      with self.test_session():
+        var0 = tf.Variable([1.0, 2.0], dtype=dtype)
+        var1 = tf.Variable([3.0, 4.0], dtype=dtype)
+        cost = 5 * var0 + 3 * var1
+        global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+        sgd_op = tf.train.GradientDescentOptimizer(3.0)
+        opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
 
-      tf.initialize_all_variables().run()
-      # Fetch params to validate initial values
-      self.assertAllClose([1.0, 2.0], var0.eval())
-      self.assertAllClose([3.0, 4.0], var1.eval())
-      # Run 1 step of sgd through optimizer
-      opt_op.run()
-      # Validate updated params
-      self.assertAllClose([-14., -13.], var0.eval())
-      self.assertAllClose([-6., -5.], var1.eval())
+        tf.initialize_all_variables().run()
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+        # Run 1 step of sgd through optimizer
+        opt_op.run()
+        # Validate updated params
+        self.assertAllClose([-14., -13.], var0.eval())
+        self.assertAllClose([-6., -5.], var1.eval())
 
   def testAggregationMethod(self):
-    with self.test_session():
-      var0 = tf.Variable([1.0, 2.0])
-      var1 = tf.Variable([3.0, 4.0])
-      cost = 5 * var0 + 3 * var1
-      global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
-      sgd_op = tf.train.GradientDescentOptimizer(3.0)
-      opt_op = sgd_op.minimize(
-          cost, global_step, [var0, var1], aggregation_method=
-          tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
+    for dtype in [tf.half, tf.float32]:
+      with self.test_session():
+        var0 = tf.Variable([1.0, 2.0], dtype=dtype)
+        var1 = tf.Variable([3.0, 4.0], dtype=dtype)
+        cost = 5 * var0 + 3 * var1
+        global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+        sgd_op = tf.train.GradientDescentOptimizer(3.0)
+        opt_op = sgd_op.minimize(
+            cost,
+            global_step,
+            [var0, var1],
+            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
 
-      tf.initialize_all_variables().run()
-      # Fetch params to validate initial values
-      self.assertAllClose([1.0, 2.0], var0.eval())
-      self.assertAllClose([3.0, 4.0], var1.eval())
-      # Run 1 step of sgd through optimizer
-      opt_op.run()
-      # Validate updated params
-      self.assertAllClose([-14., -13.], var0.eval())
-      self.assertAllClose([-6., -5.], var1.eval())
+        tf.initialize_all_variables().run()
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+        # Run 1 step of sgd through optimizer
+        opt_op.run()
+        # Validate updated params
+        self.assertAllClose([-14., -13.], var0.eval())
+        self.assertAllClose([-6., -5.], var1.eval())
 
   def testNoVariables(self):
-    with self.test_session():
-      var0 = tf.Variable([1.0, 2.0], trainable=False)
-      var1 = tf.Variable([3.0, 4.0], trainable=False)
-      cost = 5 * var0 + var1
-      sgd_op = tf.train.GradientDescentOptimizer(3.0)
-      with self.assertRaisesRegexp(ValueError, 'No variables'):
-        sgd_op.minimize(cost)
+    for dtype in [tf.half, tf.float32]:
+      with self.test_session():
+        var0 = tf.Variable([1.0, 2.0], dtype=dtype, trainable=False)
+        var1 = tf.Variable([3.0, 4.0], dtype=dtype, trainable=False)
+        cost = 5 * var0 + var1
+        sgd_op = tf.train.GradientDescentOptimizer(3.0)
+        with self.assertRaisesRegexp(ValueError, 'No variables'):
+          sgd_op.minimize(cost)
 
   def testNoGradients(self):
-    with self.test_session():
-      var0 = tf.Variable([1.0, 2.0])
-      var1 = tf.Variable([3.0, 4.0])
-      cost = 5 * var0
-      global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
-      sgd_op = tf.train.GradientDescentOptimizer(3.0)
-      with self.assertRaisesRegexp(ValueError, 'No gradients'):
-        # var1 has no gradient
-        sgd_op.minimize(cost, global_step, [var1])
+    for dtype in [tf.half, tf.float32]:
+      with self.test_session():
+        var0 = tf.Variable([1.0, 2.0], dtype=dtype)
+        var1 = tf.Variable([3.0, 4.0], dtype=dtype)
+        cost = 5 * var0
+        global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+        sgd_op = tf.train.GradientDescentOptimizer(3.0)
+        with self.assertRaisesRegexp(ValueError, 'No gradients'):
+          # var1 has no gradient
+          sgd_op.minimize(cost, global_step, [var1])
 
-if __name__ == "__main__":
+if __name__ == '__main__':
   tf.test.main()