diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/variables_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/variables_test.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 82f010cf5b..23a7f3717c 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -197,6 +197,17 @@ class VariablesTestCase(tf.test.TestCase): self.assertAllClose(3.0, var_y.eval()) self.assertAllClose(5.0, tf.add(var_x, var_y).eval()) + def testZeroSizeVarSameAsConst(self): + with self.test_session(): + zero_size_var = tf.Variable(tf.zeros([0, 2])) + zero_size_const = tf.ones([2, 0]) + variable_mul = tf.matmul(zero_size_const, zero_size_var) + const_mul = tf.matmul(zero_size_const, zero_size_const, transpose_b=True) + tf.initialize_all_variables().run() + variable_output = variable_mul.eval() + self.assertAllClose(const_mul.eval(), variable_output) + self.assertAllClose([[0., 0.], [0., 0.]], variable_output) + def testCachingDevice(self): with self.test_session(): var = tf.Variable(2.0) @@ -387,6 +398,23 @@ class IsInitializedTest(tf.test.TestCase): v.initializer.run() self.assertEqual(0, sess.run(uninited).size) + def testZeroSizeVarInitialized(self): + with tf.Graph().as_default(), self.test_session() as sess: + v = tf.Variable(tf.zeros([0, 2]), name="v") + uninited = tf.report_uninitialized_variables() + v.initializer.run() # not strictly necessary + self.assertEqual(0, sess.run(uninited).size) + + def testTrainingWIthZeroSizeVar(self): + with tf.Graph().as_default(), self.test_session() as sess: + a = tf.Variable(tf.zeros([0, 2])) + b = tf.Variable(tf.ones([2, 2])) + objective = tf.reduce_sum(b + tf.matmul(a, a, transpose_a=True)) + tf.initialize_all_variables().run() + do_opt = tf.train.GradientDescentOptimizer(0.1).minimize(objective) + sess.run([do_opt]) + self.assertAllClose([[0.9, 0.9], [0.9, 0.9]], b.eval()) + class ObsoleteIsInitializedTest(tf.test.TestCase): |