diff options
Diffstat (limited to 'tensorflow/python/training/optimizer_test.py')
-rw-r--r-- | tensorflow/python/training/optimizer_test.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py index 13e8cb9b25..ab4eecf7be 100644 --- a/tensorflow/python/training/optimizer_test.py +++ b/tensorflow/python/training/optimizer_test.py @@ -113,6 +113,33 @@ class OptimizerTest(tf.test.TestCase): # var1 has no gradient sgd_op.minimize(cost, global_step, [var1]) + def testGradientsAsVariables(self): + for dtype in [tf.half, tf.float32, tf.float64]: + with self.test_session() as sess: + var0 = tf.Variable([1.0, 2.0], dtype=dtype) + var1 = tf.Variable([3.0, 4.0], dtype=dtype) + cost = 5 * var0 + 3 * var1 + global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step') + sgd_op = tf.train.GradientDescentOptimizer(3.0) + grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1]) + # Convert gradients to tf.Variables + converted_grads = [tf.Variable(tf.zeros([2], dtype)) for i in grads_and_vars] + convert_ops = [tf.assign(converted_grads[i], gv[0]) for i,gv in enumerate(grads_and_vars)] + + converted_grads_and_vars = list(zip(converted_grads, [var0, var1])) + opt_op = sgd_op.apply_gradients(converted_grads_and_vars, global_step) + + tf.initialize_all_variables().run() + # Run convert_ops to achieve the gradietns converting + sess.run(convert_ops) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Run 1 step of sgd through optimizer + opt_op.run() + # Validate updated params + self.assertAllClose([-14., -13.], var0.eval()) + self.assertAllClose([-6., -5.], var1.eval()) if __name__ == '__main__': tf.test.main() |