diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/variable_scope_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/variable_scope_test.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 0c524a7f80..58772d9a23 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -22,6 +22,7 @@ import numpy import tensorflow as tf from tensorflow.python.framework import dtypes +from tensorflow.python.ops import init_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope @@ -95,6 +96,21 @@ class VariableScopeTest(tf.test.TestCase): with self.assertRaises(TypeError): tf.get_variable("x", initializer={}) + def testInitFromNonInitializer(self): + with self.test_session() as sess: + # Test various dtypes with zeros initializer as following: + types = [tf.int8, tf.uint8, tf.int16, tf.uint16, tf.int32, tf.int64, + tf.bool] + + # Use different varibale_name to distinguish various dtypes + for (i, dtype) in enumerate(types): + x = tf.get_variable(name='x%d' % i, shape=(3, 4), dtype=dtype) + y = tf.get_variable(name='y%d' % i, shape=(3, 4), dtype=dtype, + initializer=init_ops.zeros_initializer(dtype=dtype)) + + tf.global_variables_initializer().run() + self.assertAllEqual(x.eval(), y.eval()) + def testVarScopeCachingDevice(self): with self.test_session(): caching_device = "/job:moo" @@ -672,6 +688,27 @@ def axis0_into3_partitioner(shape=None, **unused_kwargs): class VariableScopeWithPartitioningTest(tf.test.TestCase): + def testInitFromNonInitializer(self): + with self.test_session() as sess: + # Test various dtypes with zeros initializer as following: + types = [tf.int8, tf.uint8, tf.int16, tf.uint16, tf.int32, tf.int64, + tf.bool] + + # Use different varibale_name to distinguish various dtypes + for (i, dtype) in enumerate(types): + x = tf.get_variable(name='x%d' % i, shape=(3, 4), dtype=dtype, + partitioner=axis0_into2_partitioner) + y = tf.get_variable(name='y%d' % i, shape=(6, 4), dtype=dtype, + partitioner=axis0_into2_partitioner, + initializer=init_ops.zeros_initializer(dtype=dtype)) + + tf.global_variables_initializer().run() + # x and y would become var list after partition + val_x = sess.run(list(x)) + val_y = sess.run(list(y)) + + self.assertAllEqual(val_x, val_y) + def testResultNameMatchesRequested(self): with tf.variable_scope("scope0", partitioner=axis0_into2_partitioner): v = tf.get_variable("name0", shape=(3, 1, 1)) |