Diffstat (limited to 'tensorflow/python/training/optimizer_test.py')
 tensorflow/python/training/optimizer_test.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index 13e8cb9b25..ab4eecf7be 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -113,6 +113,33 @@ class OptimizerTest(tf.test.TestCase):
     # var1 has no gradient
     sgd_op.minimize(cost, global_step, [var1])
+  def testGradientsAsVariables(self):
+    for dtype in [tf.half, tf.float32, tf.float64]:
+      with self.test_session() as sess:
+        var0 = tf.Variable([1.0, 2.0], dtype=dtype)
+        var1 = tf.Variable([3.0, 4.0], dtype=dtype)
+        cost = 5 * var0 + 3 * var1
+        global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+        sgd_op = tf.train.GradientDescentOptimizer(3.0)
+        grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1])
+        # Convert gradients to tf.Variables
+        converted_grads = [tf.Variable(tf.zeros([2], dtype)) for _ in grads_and_vars]
+        convert_ops = [tf.assign(converted_grads[i], gv[0]) for i, gv in enumerate(grads_and_vars)]
+
+        converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
+        opt_op = sgd_op.apply_gradients(converted_grads_and_vars, global_step)
+
+        tf.initialize_all_variables().run()
+        # Run convert_ops to copy the computed gradients into the Variables
+        sess.run(convert_ops)
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+        # Run 1 step of sgd through optimizer
+        opt_op.run()
+        # Validate updated params
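+        # With learning rate 3.0: var0 = [1., 2.] - 3.0 * [5., 5.],
+        # var1 = [3., 4.] - 3.0 * [3., 3.]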
+        self.assertAllClose([-14., -13.], var0.eval())
+        self.assertAllClose([-6., -5.], var1.eval())
if __name__ == '__main__':
tf.test.main()
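
For reference, the pattern the new test exercises (compute gradients, copy them into Variables with tf.assign, then hand those Variables to apply_gradients) can be sketched outside a test harness. Below is a minimal sketch against the same-era API shown in the diff (tf.Session, tf.initialize_all_variables); the names grad_vars, copy_ops, and train_op are illustrative, not part of the patch:

import tensorflow as tf

var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([3.0, 4.0])
cost = 5 * var0 + 3 * var1

opt = tf.train.GradientDescentOptimizer(3.0)
grads_and_vars = opt.compute_gradients(cost, [var0, var1])

# One Variable per gradient, plus an assign op that copies the
# computed gradient tensor into it.
grad_vars = [tf.Variable(tf.zeros([2])) for _ in grads_and_vars]
copy_ops = [tf.assign(gv, grad)
            for gv, (grad, _) in zip(grad_vars, grads_and_vars)]

# apply_gradients() only needs (gradient, variable) pairs; it does not
# care that the gradients here are themselves Variables.
train_op = opt.apply_gradients(list(zip(grad_vars, [var0, var1])))

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  sess.run(copy_ops)   # snapshot the current gradients into the Variables
  sess.run(train_op)   # apply the snapshotted gradients
  print(sess.run([var0, var1]))  # [-14., -13.] and [-6., -5.]

Because the gradients are snapshotted before being applied, the update uses whatever values the copy ops last stored, which is exactly the convert_ops/opt_op ordering the test verifies.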