From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- .../python/kernel_tests/variable_scope_test.py | 160 +++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 tensorflow/python/kernel_tests/variable_scope_test.py (limited to 'tensorflow/python/kernel_tests/variable_scope_test.py') diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py new file mode 100644 index 0000000000..bb538198ea --- /dev/null +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -0,0 +1,160 @@ +"""Tests for variable store.""" +import tensorflow.python.platform + +import tensorflow as tf + +from tensorflow.python.ops import variable_scope + + +class VariableStoreTest(tf.test.TestCase): + + def testGetVar(self): + vs = variable_scope._get_default_variable_store() + v = vs.get_variable("v", [1]) + v1 = vs.get_variable("v", [1]) + assert v == v1 + + def testNameExists(self): + vs = variable_scope._get_default_variable_store() + # No check by default, so we can both create and get existing names. + v = vs.get_variable("v", [1]) + v1 = vs.get_variable("v", [1]) + assert v == v1 + # When reuse is False, we fail when variables are already there. + vs.get_variable("w", [1], reuse=False) # That's ok. + with self.assertRaises(ValueError): + vs.get_variable("v", [1], reuse=False) # That fails. + # When reuse is True, we fail when variables are new. + vs.get_variable("v", [1], reuse=True) # That's ok. + with self.assertRaises(ValueError): + vs.get_variable("u", [1], reuse=True) # That fails. + + def testNamelessStore(self): + vs = variable_scope._get_default_variable_store() + vs.get_variable("v1", [2]) + vs.get_variable("v2", [2]) + expected_names = ["%s:0" % name for name in ["v1", "v2"]] + self.assertEqual(set(expected_names), + set([v.name for v in vs._vars.values()])) + + def testVarScopeIntializer(self): + with self.test_session() as sess: + init = tf.constant_initializer(0.3) + with variable_scope.variable_scope("tower") as tower: + with variable_scope.variable_scope("foo", initializer=init): + v = variable_scope.get_variable("v", []) + sess.run(tf.initialize_variables([v])) + self.assertAllClose(v.eval(), 0.3) + with variable_scope.variable_scope(tower, initializer=init): + w = variable_scope.get_variable("w", []) + sess.run(tf.initialize_variables([w])) + self.assertAllClose(w.eval(), 0.3) + + def testGetVariableScope(self): + # Test the get_variable_scope() function and setting properties of result. + with self.test_session() as sess: + init = tf.constant_initializer(0.3) + with variable_scope.variable_scope("foo"): + new_init1 = variable_scope.get_variable_scope().initializer + self.assertEqual(new_init1, None) + # Check that we can set initializer like this. + variable_scope.get_variable_scope().set_initializer(init) + v = variable_scope.get_variable("v", []) + sess.run(tf.initialize_variables([v])) + self.assertAllClose(v.eval(), 0.3) + # Check that we can set reuse. + variable_scope.get_variable_scope().reuse_variables() + with self.assertRaises(ValueError): # Fail, w does not exist yet. + variable_scope.get_variable("w", [1]) + # Check that the set initializer goes away. + new_init = variable_scope.get_variable_scope().initializer + self.assertEqual(new_init, None) + + def testVarScope(self): + with self.test_session(): + with variable_scope.variable_scope("tower") as tower: + self.assertEqual(tower.name, "tower") + with tf.name_scope("scope") as sc: + self.assertEqual(sc, "tower/scope/") + + with variable_scope.variable_scope("foo"): + with variable_scope.variable_scope("bar") as bar: + self.assertEqual(bar.name, "foo/bar") + with tf.name_scope("scope") as sc: + self.assertEqual(sc, "foo/bar/scope/") + + with variable_scope.variable_scope("foo"): + with variable_scope.variable_scope(tower, reuse=True) as tower_shared: + self.assertEqual(tower_shared.name, "tower") + with tf.name_scope("scope") as sc: + self.assertEqual(sc, "foo_1/scope/") + + def testVarScopeNameScope(self): + with self.test_session(): + with tf.name_scope("scope1"): + with variable_scope.variable_scope("tower") as tower: + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "scope1/tower/scope2/") + with variable_scope.variable_scope("tower"): # Re-enter adds suffix. + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "scope1/tower_1/scope2/") + + with tf.name_scope("scope3"): + with variable_scope.variable_scope("tower"): + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "scope3/tower/scope2/") + with variable_scope.variable_scope(tower): + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "scope3/scope2/") + + def testVarScopeGetVar(self): + with self.test_session(): + with variable_scope.variable_scope("root"): + with variable_scope.variable_scope("towerA") as tower_a: + va = variable_scope.get_variable("v", [1]) + self.assertEqual(va.name, "root/towerA/v:0") + + with variable_scope.variable_scope(tower_a, reuse=True): + va2 = variable_scope.get_variable("v", [1]) + self.assertEqual(va2, va) + + with variable_scope.variable_scope("towerB"): + vb = variable_scope.get_variable("v", [1]) + self.assertEqual(vb.name, "root/towerB/v:0") + + with self.assertRaises(ValueError) as exc: + with variable_scope.variable_scope("towerA"): + va2 = variable_scope.get_variable("v", [1]) + self.assertEqual(exc.exception.message[:12], "Over-sharing") + + with variable_scope.variable_scope("towerA", reuse=True): + va2 = variable_scope.get_variable("v", [1]) + self.assertEqual(va2, va) + + with variable_scope.variable_scope("foo"): + with variable_scope.variable_scope("bar"): + v = variable_scope.get_variable("v", [1]) + self.assertEqual(v.name, "root/foo/bar/v:0") + with variable_scope.variable_scope(tower_a, reuse=True): + va3 = variable_scope.get_variable("v", [1]) + self.assertEqual(va, va3) + + with self.assertRaises(ValueError) as exc: + with variable_scope.variable_scope(tower_a, reuse=True): + with variable_scope.variable_scope("baz"): + variable_scope.get_variable("v", [1]) + self.assertEqual(exc.exception.message[:13], "Under-sharing") + + with self.assertRaises(ValueError) as exc: + with variable_scope.variable_scope(tower_a, reuse=True): + variable_scope.get_variable("v", [2]) # Different shape. + self.assertEqual("shape" in exc.exception.message, True) + + with self.assertRaises(ValueError) as exc: + with variable_scope.variable_scope(tower_a, reuse=True): + variable_scope.get_variable("v", [1], dtype=tf.int32) + self.assertEqual("dtype" in exc.exception.message, True) + + +if __name__ == "__main__": + tf.test.main() -- cgit v1.2.3