aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/variable_scope_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/variable_scope_test.py')
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 0c524a7f80..58772d9a23 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -22,6 +22,7 @@ import numpy
import tensorflow as tf
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
@@ -95,6 +96,21 @@ class VariableScopeTest(tf.test.TestCase):
with self.assertRaises(TypeError):
tf.get_variable("x", initializer={})
+ def testInitFromNonInitializer(self):
+ with self.test_session() as sess:
+ # Test various dtypes with zeros initializer as following:
+ types = [tf.int8, tf.uint8, tf.int16, tf.uint16, tf.int32, tf.int64,
+ tf.bool]
+
+ # Use different varibale_name to distinguish various dtypes
+ for (i, dtype) in enumerate(types):
+ x = tf.get_variable(name='x%d' % i, shape=(3, 4), dtype=dtype)
+ y = tf.get_variable(name='y%d' % i, shape=(3, 4), dtype=dtype,
+ initializer=init_ops.zeros_initializer(dtype=dtype))
+
+ tf.global_variables_initializer().run()
+ self.assertAllEqual(x.eval(), y.eval())
+
def testVarScopeCachingDevice(self):
with self.test_session():
caching_device = "/job:moo"
@@ -672,6 +688,27 @@ def axis0_into3_partitioner(shape=None, **unused_kwargs):
class VariableScopeWithPartitioningTest(tf.test.TestCase):
+ def testInitFromNonInitializer(self):
+ with self.test_session() as sess:
+ # Test various dtypes with zeros initializer as following:
+ types = [tf.int8, tf.uint8, tf.int16, tf.uint16, tf.int32, tf.int64,
+ tf.bool]
+
+ # Use different varibale_name to distinguish various dtypes
+ for (i, dtype) in enumerate(types):
+ x = tf.get_variable(name='x%d' % i, shape=(3, 4), dtype=dtype,
+ partitioner=axis0_into2_partitioner)
+ y = tf.get_variable(name='y%d' % i, shape=(6, 4), dtype=dtype,
+ partitioner=axis0_into2_partitioner,
+ initializer=init_ops.zeros_initializer(dtype=dtype))
+
+ tf.global_variables_initializer().run()
+ # x and y would become var list after partition
+ val_x = sess.run(list(x))
+ val_y = sess.run(list(y))
+
+ self.assertAllEqual(val_x, val_y)
+
def testResultNameMatchesRequested(self):
with tf.variable_scope("scope0", partitioner=axis0_into2_partitioner):
v = tf.get_variable("name0", shape=(3, 1, 1))