aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework/python/ops/variables_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/framework/python/ops/variables_test.py')
-rw-r--r--tensorflow/contrib/framework/python/ops/variables_test.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index d6e1d03a56..eb0a2c2d8e 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -1053,5 +1053,28 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
self.assertEqual(init_value0, var0.eval())
self.assertEqual(init_value1, var1.eval())
+class ZeroInitializerOpTest(tf.test.TestCase):
+
+ def _testZeroInitializer(self, shape, initializer, use_init):
+ var = tf.Variable(initializer)
+ var_zero = tf.contrib.framework.zero_initializer(var)
+ with self.test_session() as sess:
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ var.eval()
+ if use_init:
+ sess.run(var.initializer)
+ with self.assertRaisesOpError("input is already initialized"):
+ var_zero.eval()
+ self.assertAllClose(np.ones(shape), var.eval())
+ else:
+ var_zero.eval()
+ self.assertAllClose(np.zeros(shape), var.eval())
+
+ def testZeroInitializer(self):
+ for dtype in (tf.int32, tf.int64, tf.float32, tf.float64):
+ for use_init in (False, True):
+ self._testZeroInitializer(
+ [10, 20], tf.ones([10, 20], dtype = dtype), use_init)
+
if __name__ == '__main__':
tf.test.main()