diff options
Diffstat (limited to 'tensorflow/contrib/framework/python/ops/variables_test.py')
-rw-r--r-- | tensorflow/contrib/framework/python/ops/variables_test.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index d6e1d03a56..eb0a2c2d8e 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -1053,5 +1053,28 @@ class AssignFromCheckpointFnTest(tf.test.TestCase): self.assertEqual(init_value0, var0.eval()) self.assertEqual(init_value1, var1.eval()) +class ZeroInitializerOpTest(tf.test.TestCase): + + def _testZeroInitializer(self, shape, initializer, use_init): + var = tf.Variable(initializer) + var_zero = tf.contrib.framework.zero_initializer(var) + with self.test_session() as sess: + with self.assertRaisesOpError("Attempting to use uninitialized value"): + var.eval() + if use_init: + sess.run(var.initializer) + with self.assertRaisesOpError("input is already initialized"): + var_zero.eval() + self.assertAllClose(np.ones(shape), var.eval()) + else: + var_zero.eval() + self.assertAllClose(np.zeros(shape), var.eval()) + + def testZeroInitializer(self): + for dtype in (tf.int32, tf.int64, tf.float32, tf.float64): + for use_init in (False, True): + self._testZeroInitializer( + [10, 20], tf.ones([10, 20], dtype = dtype), use_init) + if __name__ == '__main__': tf.test.main() |