aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/gradient_descent_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/gradient_descent_test.py')
-rw-r--r--tensorflow/python/training/gradient_descent_test.py105
1 files changed, 105 insertions, 0 deletions
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
new file mode 100644
index 0000000000..d5b0cae401
--- /dev/null
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -0,0 +1,105 @@
+"""Functional test for GradientDescent."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class GradientDescentOptimizerTest(tf.test.TestCase):
+
+ def testBasic(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ sgd_op = tf.train.GradientDescentOptimizer(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllClose([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], var0.eval())
+ self.assertAllClose([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
+
+ def testFloat64(self):
+ with self.test_session():
+ opt = tf.train.GradientDescentOptimizer(3.0)
+
+ # compute_gradients.
+ values = [1.0, 3.0]
+ good_vars = [tf.Variable([v]) for v in values]
+ bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
+ opt.compute_gradients, bad_loss, good_vars)
+ bad_vars = [
+ tf.Variable(np.array([v], np.float64), name="bad_var")
+ for v in values]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
+ bad_vars)
+ opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
+
+ # apply_gradients.
+ bad_grads = [
+ tf.constant([0.1], dtype=np.float64, name="bad_grad"),
+ tf.constant([0.01])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
+ opt.apply_gradients, zip(bad_grads, good_vars))
+ good_grads = [tf.constant([0.01]), tf.constant([0.02])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.apply_gradients, zip(good_grads, bad_vars))
+ opt.apply_gradients(zip(good_grads, good_vars))
+
+ def testWithGlobalStep(self):
+ with self.test_session():
+ global_step = tf.Variable(0, trainable=False)
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ sgd_op = tf.train.GradientDescentOptimizer(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ tf.initialize_all_variables().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params and global_step
+ self.assertAllClose([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], var0.eval())
+ self.assertAllClose([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
+ self.assertAllClose(1, global_step.eval())
+
+ def testSparseBasic(self):
+ with self.test_session():
+ var0 = tf.Variable([[1.0], [2.0]])
+ var1 = tf.Variable([[3.0], [4.0]])
+ grads0 = tf.IndexedSlices(tf.constant([0.1], shape=[1, 1]),
+ tf.constant([0]),
+ tf.constant([2, 1]))
+ grads1 = tf.IndexedSlices(tf.constant([0.01], shape=[1, 1]),
+ tf.constant([1]),
+ tf.constant([2, 1]))
+ sgd_op = tf.train.GradientDescentOptimizer(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([[1.0], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0]], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllClose([[1.0 - 3.0 * 0.1], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0 - 3.0 * 0.01]], var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()