aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/variables_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/variables_test.py')
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 82f010cf5b..23a7f3717c 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -197,6 +197,17 @@ class VariablesTestCase(tf.test.TestCase):
self.assertAllClose(3.0, var_y.eval())
self.assertAllClose(5.0, tf.add(var_x, var_y).eval())
+ def testZeroSizeVarSameAsConst(self):
+ with self.test_session():
+ zero_size_var = tf.Variable(tf.zeros([0, 2]))
+ zero_size_const = tf.ones([2, 0])
+ variable_mul = tf.matmul(zero_size_const, zero_size_var)
+ const_mul = tf.matmul(zero_size_const, zero_size_const, transpose_b=True)
+ tf.initialize_all_variables().run()
+ variable_output = variable_mul.eval()
+ self.assertAllClose(const_mul.eval(), variable_output)
+ self.assertAllClose([[0., 0.], [0., 0.]], variable_output)
+
def testCachingDevice(self):
with self.test_session():
var = tf.Variable(2.0)
@@ -387,6 +398,23 @@ class IsInitializedTest(tf.test.TestCase):
v.initializer.run()
self.assertEqual(0, sess.run(uninited).size)
+ def testZeroSizeVarInitialized(self):
+ with tf.Graph().as_default(), self.test_session() as sess:
+ v = tf.Variable(tf.zeros([0, 2]), name="v")
+ uninited = tf.report_uninitialized_variables()
+ v.initializer.run() # not strictly necessary
+ self.assertEqual(0, sess.run(uninited).size)
+
+ def testTrainingWIthZeroSizeVar(self):
+ with tf.Graph().as_default(), self.test_session() as sess:
+ a = tf.Variable(tf.zeros([0, 2]))
+ b = tf.Variable(tf.ones([2, 2]))
+ objective = tf.reduce_sum(b + tf.matmul(a, a, transpose_a=True))
+ tf.initialize_all_variables().run()
+ do_opt = tf.train.GradientDescentOptimizer(0.1).minimize(objective)
+ sess.run([do_opt])
+ self.assertAllClose([[0.9, 0.9], [0.9, 0.9]], b.eval())
+
class ObsoleteIsInitializedTest(tf.test.TestCase):