diff options
author | Ali Yahya <alive@google.com> | 2017-08-29 16:15:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-29 16:19:03 -0700 |
commit | 5a76b8dce388a0a03849979a7eb655d1780cfd88 (patch) | |
tree | 6f9b9bc331e5049fd940ff6114c91d98e4f8881f | |
parent | 0492dccf024be59f4f38f7650deb9570a2f0c2db (diff) |
Modified variable scopes to work with Eager mode.
PiperOrigin-RevId: 166920017
-rw-r--r-- | tensorflow/python/framework/test_util.py | 4 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/resource_variable_ops_test.py | 4 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/variable_scope_test.py | 584 | ||||
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 5 | ||||
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 56 |
6 files changed, 359 insertions, 296 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index be51a38daf..c65816a543 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -277,7 +277,7 @@ def enable_c_api(fn): def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, use_gpu=False, force_gpu=False, - reset_test=False): + reset_test=True): """Runs the test in both graph and eager modes. Args: @@ -305,6 +305,8 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, f(self) if reset_test: + # This decorator runs the wrapped test twice. + # Reset the test environment between runs. self.tearDown() self.setUp() diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 2db686333d..bdd9a6dc22 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -932,12 +932,14 @@ tf_py_test( "//tensorflow/python:control_flow_ops", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:variable_scope", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:state_ops", "//tensorflow/python:variables", + "//tensorflow/python/eager:context", ], tags = ["no_windows"], ) diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 98fc568515..c31732d807 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -53,7 +53,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): 0, dtype=dtypes.int32)).run() - def testReadVariableDtypeMismatch(self): + def testReadVariableDtypeMismatchEager(self): with context.eager_mode(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1], name="foo") @@ -62,7 +62,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): "Expected float got int32."): _ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32) - def testAssignVariableDtypeMismatch(self): + def testAssignVariableDtypeMismatchEager(self): with context.eager_mode(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1], name="foo") diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 67932d0823..cdac12f05a 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -20,10 +20,12 @@ from __future__ import print_function import numpy +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops @@ -37,32 +39,38 @@ from tensorflow.python.platform import test class VariableScopeTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes() def testGetVar(self): vs = variable_scope._get_default_variable_store() v = vs.get_variable("v", [1]) v1 = vs.get_variable("v", [1]) self.assertEqual(v, v1) + @test_util.run_in_graph_and_eager_modes() def testResource(self): vs = variable_scope._get_default_variable_store() v1 = vs.get_variable("v", [1], use_resource=True) self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable)) + @test_util.run_in_graph_and_eager_modes() def testNameExists(self): vs = variable_scope._get_default_variable_store() # No check by default, so we can both create and get existing names. v = vs.get_variable("v", [1]) v1 = vs.get_variable("v", [1]) self.assertEqual(v, v1) - # When reuse is False, we fail when variables are already there. - vs.get_variable("w", [1], reuse=False) # That's ok. - with self.assertRaises(ValueError): - vs.get_variable("v", [1], reuse=False) # That fails. - # When reuse is True, we fail when variables are new. - vs.get_variable("v", [1], reuse=True) # That's ok. - with self.assertRaises(ValueError): - vs.get_variable("u", [1], reuse=True) # That fails. + if context.in_graph_mode(): + # When reuse is False, we fail when variables are already there. + vs.get_variable("w", [1], reuse=False) # That's ok. + with self.assertRaises(ValueError): + vs.get_variable("v", [1], reuse=False) # That fails. + # When reuse is True, we fail when variables are new. + vs.get_variable("v", [1], reuse=True) # That's ok. + with self.assertRaises(ValueError): + vs.get_variable("u", [1], reuse=True) # That fails. + + @test_util.run_in_graph_and_eager_modes() def testNamelessStore(self): vs = variable_scope._get_default_variable_store() vs.get_variable("v1", [2]) @@ -71,22 +79,23 @@ class VariableScopeTest(test.TestCase): self.assertEqual( set(expected_names), set([v.name for v in vs._vars.values()])) + @test_util.run_in_graph_and_eager_modes() def testVarScopeInitializer(self): - with self.test_session() as sess: - init = init_ops.constant_initializer(0.3) - with variable_scope.variable_scope("tower") as tower: - with variable_scope.variable_scope("foo", initializer=init): - v = variable_scope.get_variable("v", []) - sess.run(variables_lib.initialize_variables([v])) - self.assertAllClose(v.eval(), 0.3) - with variable_scope.variable_scope(tower, initializer=init): - w = variable_scope.get_variable("w", []) - sess.run(variables_lib.initialize_variables([w])) - self.assertAllClose(w.eval(), 0.3) + init = init_ops.constant_initializer(0.3) + with variable_scope.variable_scope("tower0") as tower: + with variable_scope.variable_scope("foo", initializer=init): + v = variable_scope.get_variable("v", []) + self.evaluate(variables_lib.variables_initializer([v])) + self.assertAllClose(self.evaluate(v.value()), 0.3) + with variable_scope.variable_scope(tower, initializer=init): + w = variable_scope.get_variable("w", []) + self.evaluate(variables_lib.variables_initializer([w])) + self.assertAllClose(self.evaluate(w.value()), 0.3) + @test_util.run_in_graph_and_eager_modes() def testVarScopeConstraint(self): constraint = lambda x: 0. * x - with variable_scope.variable_scope("tower") as tower: + with variable_scope.variable_scope("tower1") as tower: with variable_scope.variable_scope("foo", constraint=constraint): v = variable_scope.get_variable("v", []) self.assertEqual(v.constraint, constraint) @@ -94,51 +103,56 @@ class VariableScopeTest(test.TestCase): w = variable_scope.get_variable("w", []) self.assertEqual(w.constraint, constraint) + @test_util.run_in_graph_and_eager_modes() def testVarScopeDType(self): - with self.test_session(): - with variable_scope.variable_scope("tower") as tower: - with variable_scope.variable_scope("foo", dtype=dtypes.float16): - v = variable_scope.get_variable("v", []) - self.assertEqual(v.dtype.base_dtype, dtypes.float16) - with variable_scope.variable_scope(tower, dtype=dtypes.float16): - w = variable_scope.get_variable("w", []) - self.assertEqual(w.dtype.base_dtype, dtypes.float16) + with variable_scope.variable_scope("tower2") as tower: + with variable_scope.variable_scope("foo", dtype=dtypes.float16): + v = variable_scope.get_variable("v", []) + self.assertEqual(v.dtype.base_dtype, dtypes.float16) + with variable_scope.variable_scope(tower, dtype=dtypes.float16): + w = variable_scope.get_variable("w", []) + self.assertEqual(w.dtype.base_dtype, dtypes.float16) + @test_util.run_in_graph_and_eager_modes() def testInitFromNonTensorValue(self): - with self.test_session() as sess: - v = variable_scope.get_variable("v", initializer=4, dtype=dtypes.int32) - sess.run(variables_lib.initialize_variables([v])) - self.assertAllClose(v.eval(), 4) + v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32) + self.evaluate(variables_lib.variables_initializer([v])) + self.assertAllClose(self.evaluate(v.value()), 4) - w = variable_scope.get_variable( - "w", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64) - sess.run(variables_lib.initialize_variables([w])) - self.assertAllClose(w.eval(), [1, 2, 3]) + w = variable_scope.get_variable( + "w4", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64) + self.evaluate(variables_lib.variables_initializer([w])) + self.assertAllClose(self.evaluate(w.value()), [1, 2, 3]) + if context.in_graph_mode(): with self.assertRaises(TypeError): - variable_scope.get_variable("x", initializer={}) + variable_scope.get_variable("x4", initializer={}) + else: + with self.assertRaises(errors.InvalidArgumentError): + variable_scope.get_variable("x4", initializer={}) + @test_util.run_in_graph_and_eager_modes() def testInitFromNonInitializer(self): - with self.test_session(): - # Test various dtypes with zeros initializer as following: - types = [ - dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32, - dtypes.int64, dtypes.bool - ] - - # Use different variable_name to distinguish various dtypes - for (i, dtype) in enumerate(types): - x = variable_scope.get_variable( - name="x%d" % i, shape=(3, 4), dtype=dtype) - y = variable_scope.get_variable( - name="y%d" % i, - shape=(3, 4), - dtype=dtype, - initializer=init_ops.zeros_initializer(dtype=dtype)) - - variables_lib.global_variables_initializer().run() - self.assertAllEqual(x.eval(), y.eval()) - + # Test various dtypes with zeros initializer as following: + types = [ + dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32, + dtypes.int64, dtypes.bool + ] + + # Use different variable_name to distinguish various dtypes + for (i, dtype) in enumerate(types): + x = variable_scope.get_variable( + name="xx%d" % i, shape=(3, 4), dtype=dtype) + y = variable_scope.get_variable( + name="yy%d" % i, + shape=(3, 4), + dtype=dtype, + initializer=init_ops.zeros_initializer(dtype=dtype)) + + self.evaluate(variables_lib.global_variables_initializer()) + self.assertAllEqual(self.evaluate(x.value()), self.evaluate(y.value())) + + # TODO(alive): support variable partitioning/caching in eager mode. def testVarScopeCachingDevice(self): with self.test_session(): caching_device = "/job:moo" @@ -172,74 +186,74 @@ class VariableScopeTest(test.TestCase): v_tower = variable_scope.get_variable("v", []) self.assertFalse(v_tower.value().device.startswith(caching_device)) + @test_util.run_in_graph_and_eager_modes() def testVarScopeRegularizer(self): - with self.test_session() as sess: - init = init_ops.constant_initializer(0.3) - - def regularizer1(v): - return math_ops.reduce_mean(v) + 0.1 - - def regularizer2(v): - return math_ops.reduce_mean(v) + 0.2 - - with variable_scope.variable_scope( - "tower", regularizer=regularizer1) as tower: - with variable_scope.variable_scope("foo", initializer=init): - v = variable_scope.get_variable("v", []) - sess.run(variables_lib.initialize_variables([v])) - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(1, len(losses)) - self.assertAllClose(losses[0].eval(), 0.4) - with variable_scope.variable_scope(tower, initializer=init) as vs: - u = variable_scope.get_variable("u", []) - vs.set_regularizer(regularizer2) - w = variable_scope.get_variable("w", []) - # Next 3 variable not regularized to test disabling regularization. - x = variable_scope.get_variable( - "x", [], regularizer=variable_scope.no_regularizer) - with variable_scope.variable_scope( - "baz", regularizer=variable_scope.no_regularizer): - y = variable_scope.get_variable("y", []) - vs.set_regularizer(variable_scope.no_regularizer) - z = variable_scope.get_variable("z", []) - # Check results. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(3, len(losses)) - sess.run(variables_lib.initialize_variables([u, w, x, y, z])) - self.assertAllClose(losses[0].eval(), 0.4) - self.assertAllClose(losses[1].eval(), 0.4) - self.assertAllClose(losses[2].eval(), 0.5) - with variable_scope.variable_scope("foo", reuse=True): - v = variable_scope.get_variable("v", - []) # "v" is alredy there, reused - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(3, len(losses)) # No new loss added. + init = init_ops.constant_initializer(0.3) - def testInitializeFromValue(self): - with self.test_session() as sess: - init = constant_op.constant(0.1) - w = variable_scope.get_variable("v", initializer=init) - sess.run(variables_lib.initialize_variables([w])) - self.assertAllClose(w.eval(), 0.1) + def regularizer1(v): + return math_ops.reduce_mean(v) + 0.1 - with self.assertRaisesRegexp(ValueError, "shape"): - # We disallow explicit shape specification when initializer is constant. - variable_scope.get_variable("u", [1], initializer=init) + def regularizer2(v): + return math_ops.reduce_mean(v) + 0.2 + with variable_scope.variable_scope( + "tower3", regularizer=regularizer1) as tower: with variable_scope.variable_scope("foo", initializer=init): - # Constant initializer can be passed through scopes if needed. - v = variable_scope.get_variable("v") - sess.run(variables_lib.initialize_variables([v])) - self.assertAllClose(v.eval(), 0.1) - - # Check that non-float32 initializer creates a non-float32 variable. - init = constant_op.constant(1, dtype=dtypes.int32) - t = variable_scope.get_variable("t", initializer=init) - self.assertEqual(t.dtype.base_dtype, dtypes.int32) - - # Raise error if `initializer` dtype and `dtype` are not identical. - with self.assertRaisesRegexp(ValueError, "don't match"): - variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64) + v = variable_scope.get_variable("v", []) + self.evaluate(variables_lib.variables_initializer([v])) + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(1, len(losses)) + self.assertAllClose(self.evaluate(losses[0]), 0.4) + with variable_scope.variable_scope(tower, initializer=init) as vs: + u = variable_scope.get_variable("u", []) + vs.set_regularizer(regularizer2) + w = variable_scope.get_variable("w", []) + # Next 3 variable not regularized to test disabling regularization. + x = variable_scope.get_variable( + "x", [], regularizer=variable_scope.no_regularizer) + with variable_scope.variable_scope( + "baz", regularizer=variable_scope.no_regularizer): + y = variable_scope.get_variable("y", []) + vs.set_regularizer(variable_scope.no_regularizer) + z = variable_scope.get_variable("z", []) + # Check results. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(3, len(losses)) + self.evaluate(variables_lib.variables_initializer([u, w, x, y, z])) + self.assertAllClose(self.evaluate(losses[0]), 0.4) + self.assertAllClose(self.evaluate(losses[1]), 0.4) + self.assertAllClose(self.evaluate(losses[2]), 0.5) + with variable_scope.variable_scope("foo", reuse=True): + v = variable_scope.get_variable("v", + []) # "v" is alredy there, reused + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(3, len(losses)) # No new loss added. + + @test_util.run_in_graph_and_eager_modes() + def testInitializeFromValue(self): + init = constant_op.constant(0.1) + w = variable_scope.get_variable("v", initializer=init) + self.evaluate(variables_lib.variables_initializer([w])) + self.assertAllClose(self.evaluate(w.value()), 0.1) + + with self.assertRaisesRegexp(ValueError, "shape"): + # We disallow explicit shape specification when initializer is constant. + variable_scope.get_variable("u", [1], initializer=init) + + with variable_scope.variable_scope("foo", initializer=init): + # Constant initializer can be passed through scopes if needed. + v = variable_scope.get_variable("v") + self.evaluate(variables_lib.variables_initializer([v])) + self.assertAllClose(self.evaluate(v.value()), 0.1) + + # Check that non-float32 initializer creates a non-float32 variable. + init = constant_op.constant(1, dtype=dtypes.int32) + t = variable_scope.get_variable("t", initializer=init) + self.assertEqual(t.dtype.base_dtype, dtypes.int32) + + # Raise error if `initializer` dtype and `dtype` are not identical. + with self.assertRaisesRegexp(ValueError, "don't match"): + variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64) def testControlDeps(self): with self.test_session() as sess: @@ -250,16 +264,16 @@ class VariableScopeTest(test.TestCase): "v1", [1], initializer=init_ops.constant_initializer(1)) add = v1 + v0 # v0 should be uninitialized. - with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"): + with self.assertRaisesRegexp(errors.OpError, "uninitialized"): sess.run(v0) # We should be able to initialize and run v1 without initializing # v0, even if the variable was created with a control dep on v0. sess.run(v1.initializer) self.assertEqual(1, sess.run(v1)) # v0 should still be uninitialized. - with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"): + with self.assertRaisesRegexp(errors.OpError, "uninitialized"): sess.run(v0) - with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"): + with self.assertRaisesRegexp(errors.OpError, "uninitialized"): sess.run(add) # If we initialize v0 we should be able to run 'add'. sess.run(v0.initializer) @@ -295,82 +309,85 @@ class VariableScopeTest(test.TestCase): sess.run(v2.initializer) self.assertEqual([2], sess.run(v2)) # v0 should still be uninitialized. - with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"): + with self.assertRaisesRegexp(errors.OpError, "uninitialized"): sess.run(v0) # We should not be able to run 'add' yet. - with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"): + with self.assertRaisesRegexp(errors.OpError, "uninitialized"): sess.run(add) # If we initialize v0 we should be able to run 'add'. sess.run(v0.initializer) sess.run(add) + @test_util.run_in_graph_and_eager_modes() def testGetVariableScope(self): # Test the get_variable_scope() function and setting properties of result. - with self.test_session() as sess: - init = init_ops.constant_initializer(0.3) - with variable_scope.variable_scope("foo"): - new_init1 = variable_scope.get_variable_scope().initializer - self.assertEqual(new_init1, None) - # Check that we can set initializer like this. - variable_scope.get_variable_scope().set_initializer(init) - v = variable_scope.get_variable("v", []) - sess.run(variables_lib.initialize_variables([v])) - self.assertAllClose(v.eval(), 0.3) + init = init_ops.constant_initializer(0.3) + with variable_scope.variable_scope("bar"): + new_init1 = variable_scope.get_variable_scope().initializer + self.assertEqual(new_init1, None) + # Check that we can set initializer like this. + variable_scope.get_variable_scope().set_initializer(init) + v = variable_scope.get_variable("v", []) + self.evaluate(variables_lib.variables_initializer([v])) + self.assertAllClose(self.evaluate(v.value()), 0.3) + if context.in_graph_mode(): # Check that we can set reuse. variable_scope.get_variable_scope().reuse_variables() with self.assertRaises(ValueError): # Fail, w does not exist yet. variable_scope.get_variable("w", [1]) - # Check that the set initializer goes away. - new_init = variable_scope.get_variable_scope().initializer - self.assertEqual(new_init, None) + # Check that the set initializer goes away. + new_init = variable_scope.get_variable_scope().initializer + self.assertEqual(new_init, None) + @test_util.run_in_graph_and_eager_modes() def testVarScope(self): - with self.test_session(): - with variable_scope.variable_scope("tower") as tower: - self.assertEqual(tower.name, "tower") + with variable_scope.variable_scope("tower4") as tower: + self.assertEqual(tower.name, "tower4") + with ops.name_scope("scope") as sc: + self.assertEqual(sc, "tower4/scope/") + + with variable_scope.variable_scope("tower5"): + with variable_scope.variable_scope("bar") as bar: + self.assertEqual(bar.name, "tower5/bar") with ops.name_scope("scope") as sc: - self.assertEqual(sc, "tower/scope/") + self.assertEqual(sc, "tower5/bar/scope/") - with variable_scope.variable_scope("foo"): - with variable_scope.variable_scope("bar") as bar: - self.assertEqual(bar.name, "foo/bar") - with ops.name_scope("scope") as sc: - self.assertEqual(sc, "foo/bar/scope/") - - with variable_scope.variable_scope("foo"): - with variable_scope.variable_scope(tower, reuse=True) as tower_shared: - self.assertEqual(tower_shared.name, "tower") - with ops.name_scope("scope") as sc: - self.assertEqual(sc, "foo_1/tower/scope/") + with variable_scope.variable_scope("tower6"): + with variable_scope.variable_scope(tower, reuse=True) as tower_shared: + self.assertEqual(tower_shared.name, "tower4") + with ops.name_scope("scope") as sc: + self.assertEqual(sc, "tower6/tower4/scope/") + @test_util.run_in_graph_and_eager_modes() def testVarScopeNameScope(self): - with self.test_session(): - with ops.name_scope("scope1"): - with variable_scope.variable_scope("tower") as tower: - with ops.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope1/tower/scope2/") + with ops.name_scope("testVarScopeNameScope1"): + with variable_scope.variable_scope("tower") as tower: + with ops.name_scope("scope2") as sc2: + self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/") + if context.in_graph_mode(): with variable_scope.variable_scope( tower): # Re-entering acts like another "tower". with ops.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope1/tower_1/scope2/") + self.assertEqual(sc2, "testVarScopeNameScope1/tower_1/scope2/") with variable_scope.variable_scope( "tower"): # Re-entering by string acts the same. with ops.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope1/tower_2/scope2/") + self.assertEqual(sc2, "testVarScopeNameScope1/tower_2/scope2/") - with ops.name_scope("scope3"): - with variable_scope.variable_scope("tower"): - with ops.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope3/tower/scope2/") + with ops.name_scope("testVarScopeNameScope2"): + with variable_scope.variable_scope("tower"): + with ops.name_scope("scope2") as sc2: + self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/") + if context.in_graph_mode(): with variable_scope.variable_scope(tower): with ops.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope3/tower_1/scope2/") + self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/") - root_var_scope = variable_scope.get_variable_scope() - with ops.name_scope("scope4"): - with variable_scope.variable_scope(root_var_scope): - with ops.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope4/scope2/") + root_var_scope = variable_scope.get_variable_scope() + with ops.name_scope("testVarScopeNameScope3"): + with variable_scope.variable_scope(root_var_scope): + with ops.name_scope("scope2") as sc2: + self.assertEqual(sc2, "testVarScopeNameScope3/scope2/") def testVarScopeOriginalNameScope(self): with self.test_session(): @@ -422,51 +439,46 @@ class VariableScopeTest(test.TestCase): with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse: self.assertFalse(jump_no_reuse.reuse) + @test_util.run_in_graph_and_eager_modes() def testVarScopeGetOrCreateReuse(self): - x = array_ops.placeholder(dtypes.float32) - - with variable_scope.variable_scope("bar", - reuse=variable_scope.AUTO_REUSE): - v_assign = state_ops.assign(variable_scope.get_variable("var", []), x) - - with variable_scope.variable_scope("bar", - reuse=variable_scope.AUTO_REUSE): - v = variable_scope.get_variable("var", []) - - with self.test_session() as sess: - def test_value(value): - sess.run(v_assign, feed_dict={x: value}) - self.assertEqual(value, v.eval()) - - test_value(42) # Variable is created. - test_value(13) # Variable is reused hereafter. - test_value(17) + def test_value(value): + x = constant_op.constant(value) + with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar", + reuse=variable_scope.AUTO_REUSE): + _ = state_ops.assign(variable_scope.get_variable("var", []), x) + with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar", + reuse=variable_scope.AUTO_REUSE): + _ = variable_scope.get_variable("var", []) + self.assertEqual(value, self.evaluate(x)) + test_value(42.) # Variable is created. + test_value(13.) # Variable is reused hereafter. + test_value(17.) def testVarOpScope(self): with self.test_session(): - with ops.name_scope("scope1"): + with ops.name_scope("testVarOpScope1"): with variable_scope.variable_scope("tower", "default", []): self.assertEqual( variable_scope.get_variable("w", []).name, "tower/w:0") - with ops.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope1/tower/scope2/") + with ops.name_scope("testVarOpScope2") as sc2: + self.assertEqual(sc2, "testVarOpScope1/tower/testVarOpScope2/") with variable_scope.variable_scope("tower", "default", []): with self.assertRaises(ValueError): variable_scope.get_variable("w", []) - with ops.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope1/tower_1/scope2/") + with ops.name_scope("testVarOpScope2") as sc2: + self.assertEqual(sc2, "testVarOpScope1/tower_1/testVarOpScope2/") - with ops.name_scope("scope2"): + with ops.name_scope("testVarOpScope2"): with variable_scope.variable_scope(None, "default", []): self.assertEqual( variable_scope.get_variable("w", []).name, "default/w:0") - with ops.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope2/default/scope2/") + with ops.name_scope("testVarOpScope2") as sc2: + self.assertEqual(sc2, "testVarOpScope2/default/testVarOpScope2/") with variable_scope.variable_scope(None, "default", []): self.assertEqual( variable_scope.get_variable("w", []).name, "default_1/w:0") - with ops.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope2/default_1/scope2/") + with ops.name_scope("testVarOpScope2") as sc2: + self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/") def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self): with self.test_session(): @@ -714,27 +726,27 @@ class VariableScopeTest(test.TestCase): with ops.name_scope("scope2") as sc2: self.assertEqual(sc2, "outer_1/default/scope2/") + @test_util.run_in_graph_and_eager_modes() def testGetLocalVar(self): - with self.test_session(): - # Check that local variable respects naming. - with variable_scope.variable_scope("outer") as outer: - with variable_scope.variable_scope(outer, "default", []): - local_var = variable_scope.get_local_variable( - "w", [], collections=["foo"]) - self.assertEqual(local_var.name, "outer/w:0") - - # Since variable is local, it should be in the local variable collection - # but not the trainable collection. - self.assertIn(local_var, - ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) - self.assertIn(local_var, ops.get_collection("foo")) - self.assertNotIn(local_var, - ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) - - # Check that local variable respects `reuse`. - with variable_scope.variable_scope(outer, "default", reuse=True): - self.assertEqual( - variable_scope.get_local_variable("w", []).name, "outer/w:0") + # Check that local variable respects naming. + with variable_scope.variable_scope("outer") as outer: + with variable_scope.variable_scope(outer, "default", []): + local_var = variable_scope.get_local_variable( + "w", [], collections=["foo"]) + self.assertEqual(local_var.name, "outer/w:0") + + # Since variable is local, it should be in the local variable collection + # but not the trainable collection. + self.assertIn(local_var, + ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) + self.assertIn(local_var, ops.get_collection("foo")) + self.assertNotIn(local_var, + ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + + # Check that local variable respects `reuse`. + with variable_scope.variable_scope(outer, "default", reuse=True): + self.assertEqual( + variable_scope.get_local_variable("w", []).name, "outer/w:0") def testGetVarWithDevice(self): g = ops.Graph() @@ -753,69 +765,93 @@ class VariableScopeTest(test.TestCase): self.assertEqual(varname_type[0], ("x", dtypes.float32)) self.assertEqual(varname_type[1], ("y", dtypes.int64)) + @test_util.run_in_graph_and_eager_modes() def testGetCollection(self): - with self.test_session(): - _ = variable_scope.get_variable("a", []) - _ = variable_scope.get_variable("b", [], trainable=False) - with variable_scope.variable_scope("foo_") as scope1: - _ = variable_scope.get_variable("a", []) - _ = variable_scope.get_variable("b", [], trainable=False) - self.assertEqual([ - v.name - for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - ], ["foo_/a:0"]) - self.assertEqual([ - v.name - for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - ], ["foo_/a:0", "foo_/b:0"]) - with variable_scope.variable_scope("foo") as scope2: - _ = variable_scope.get_variable("a", []) - _ = variable_scope.get_variable("b", [], trainable=False) - self.assertEqual([ - v.name - for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - ], ["foo/a:0"]) - self.assertEqual([ - v.name - for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - ], ["foo/a:0", "foo/b:0"]) - scope = variable_scope.get_variable_scope() + _ = variable_scope.get_variable("testGetCollection_a", []) + _ = variable_scope.get_variable("testGetCollection_b", [], trainable=False) + with variable_scope.variable_scope("testGetCollection_foo_") as scope1: + _ = variable_scope.get_variable("testGetCollection_a", []) + _ = variable_scope.get_variable("testGetCollection_b", [], + trainable=False) self.assertEqual([ - v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - ], ["a:0", "b:0", "foo_/a:0", "foo_/b:0", "foo/a:0", "foo/b:0"]) + v.name + for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + ], ["testGetCollection_foo_/testGetCollection_a:0"]) self.assertEqual([ v.name - for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - ], ["a:0", "foo_/a:0", "foo/a:0"]) - + for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + ], [ + "testGetCollection_foo_/testGetCollection_a:0", + "testGetCollection_foo_/testGetCollection_b:0" + ]) + with variable_scope.variable_scope("testGetCollection_foo") as scope2: + _ = variable_scope.get_variable("testGetCollection_a", []) + _ = variable_scope.get_variable("testGetCollection_b", [], + trainable=False) + self.assertEqual([ + v.name + for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + ], ["testGetCollection_foo/testGetCollection_a:0"]) + self.assertEqual([ + v.name + for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + ], [ + "testGetCollection_foo/testGetCollection_a:0", + "testGetCollection_foo/testGetCollection_b:0" + ]) + scope = variable_scope.get_variable_scope() + self.assertEqual([ + v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + ], [ + "testGetCollection_a:0", "testGetCollection_b:0", + "testGetCollection_foo_/testGetCollection_a:0", + "testGetCollection_foo_/testGetCollection_b:0", + "testGetCollection_foo/testGetCollection_a:0", + "testGetCollection_foo/testGetCollection_b:0" + ]) + self.assertEqual([ + v.name + for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + ], [ + "testGetCollection_a:0", + "testGetCollection_foo_/testGetCollection_a:0", + "testGetCollection_foo/testGetCollection_a:0" + ]) + + @test_util.run_in_graph_and_eager_modes() def testGetTrainableVariables(self): - with self.test_session(): - _ = variable_scope.get_variable("a", []) - with variable_scope.variable_scope("foo") as scope: - _ = variable_scope.get_variable("b", []) - _ = variable_scope.get_variable("c", [], trainable=False) - self.assertEqual([v.name - for v in scope.trainable_variables()], ["foo/b:0"]) - + _ = variable_scope.get_variable("testGetTrainableVariables_a", []) + with variable_scope.variable_scope( + "testGetTrainableVariables_foo") as scope: + _ = variable_scope.get_variable("testGetTrainableVariables_b", []) + _ = variable_scope.get_variable("testGetTrainableVariables_c", [], + trainable=False) + self.assertEqual([v.name + for v in scope.trainable_variables()], + ["testGetTrainableVariables_foo/" + "testGetTrainableVariables_b:0"]) + + @test_util.run_in_graph_and_eager_modes() def testGetGlobalVariables(self): - with self.test_session(): - _ = variable_scope.get_variable("a", []) - with variable_scope.variable_scope("foo") as scope: - _ = variable_scope.get_variable("b", []) - self.assertEqual([v.name - for v in scope.global_variables()], ["foo/b:0"]) - + _ = variable_scope.get_variable("testGetGlobalVariables_a", []) + with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope: + _ = variable_scope.get_variable("testGetGlobalVariables_b", []) + self.assertEqual([v.name + for v in scope.global_variables()], + ["testGetGlobalVariables_foo/" + "testGetGlobalVariables_b:0"]) + + @test_util.run_in_graph_and_eager_modes() def testGetLocalVariables(self): - with self.test_session(): + _ = variable_scope.get_variable( + "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) + with variable_scope.variable_scope("foo") as scope: _ = variable_scope.get_variable( - "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) - with variable_scope.variable_scope("foo") as scope: - _ = variable_scope.get_variable( - "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) - _ = variable_scope.get_variable( - "c", []) - self.assertEqual([v.name - for v in scope.local_variables()], ["foo/b:0"]) + "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) + _ = variable_scope.get_variable( + "c", []) + self.assertEqual([v.name + for v in scope.local_variables()], ["foo/b:0"]) def testGetVariableWithRefDtype(self): v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index e43cf4bc04..1d747f8400 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -255,8 +255,9 @@ class ResourceVariable(variables.Variable): context.get_default_context().device_name) else: initial_value = initial_value() - initial_value = ops.convert_to_tensor( - initial_value, name="initial_value", dtype=dtype) + with ops.name_scope("Initializer"): + initial_value = ops.convert_to_tensor( + initial_value, name="initial_value", dtype=dtype) self._handle = gen_resource_variable_ops.var_handle_op( shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index b7913890e4..9093c12968 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -259,7 +259,8 @@ class _VariableStore(object): applying it on a newly created variable will be added to the collection GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation - of variables. + of variables. In Eager mode, this argument is always forced to be + tf.AUTO_REUSE. trainable: If `True` also add the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). collections: List of graph collections keys to add the `Variable` to. @@ -278,6 +279,7 @@ class _VariableStore(object): use_resource: If False, creates a regular Variable. If True, creates instead an experimental ResourceVariable which has well-defined semantics. Defaults to False (will later change to True). + In Eager mode, this argument is always forced to be true. custom_getter: Callable that takes as a first argument the true getter, and allows overwriting the internal get_variable method. The signature of `custom_getter` should match that of this method, @@ -311,6 +313,10 @@ class _VariableStore(object): raise ValueError( "Passed a custom_getter which is not callable: %s" % custom_getter) + if context.in_eager_mode(): + reuse = AUTO_REUSE + use_resource = True + # If a *_ref type is passed in an error would be triggered further down the # stack. We prevent this using base_dtype to get a non-ref version of the # type, before doing anything else. When _ref types are removed in favor of @@ -498,6 +504,9 @@ class _VariableStore(object): when violating reuse during variable creation, or if an existing sharded variable exists for the given name but with different sharding. """ + if context.in_eager_mode(): + raise NotImplementedError("Partitioned variables are not yet supported " + "in Eager mode.") initializing_from_value = initializer is not None and isinstance( initializer, ops.Tensor) @@ -792,14 +801,19 @@ class _VariableStore(object): # Run the regularizer if requested and save the resulting loss. if regularizer: - with ops.colocate_with(v.op): + with ops.colocate_with(v): with ops.name_scope(name + "/Regularizer/"): loss = regularizer(v) if loss is not None: + if context.in_graph_mode(): + v_name = v.name + loss_name = loss.name + else: + v_name = "v_%s" % type(v) + loss_name = "loss_%s" % type(loss) logging.vlog(1, "Applied regularizer to %s and added the result %s " - "to REGULARIZATION_LOSSES.", v.name, loss.name) + "to REGULARIZATION_LOSSES.", v_name, loss_name) ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss) - return v # Initialize variable when no initializer provided @@ -853,7 +867,8 @@ class VariableScope(object): initializer: default initializer passed to get_variable. regularizer: default regularizer passed to get_variable. reuse: Boolean, None, or tf.AUTO_REUSE, setting the reuse in - get_variable. + get_variable. In Eager mode, this argument is always forced to be + tf.AUTO_REUSE. caching_device: string, callable, or None: the caching device passed to get_variable. partitioner: callable or `None`: the partitioner passed to `get_variable`. @@ -862,7 +877,8 @@ class VariableScope(object): dtype: default type passed to get_variable (defaults to DT_FLOAT). use_resource: if False, create a normal Variable; if True create an experimental ResourceVariable with well-defined semantics. Defaults - to False (will later change to True). + to False (will later change to True). In Eager mode, this argument is + always forced to be True. constraint: An optional projection function to be applied to the variable after being updated by an `Optimizer` (e.g. used to implement norm constraints or value constraints for layer weights). The function must @@ -903,6 +919,7 @@ class VariableScope(object): if self._partitioner is not None: raise NotImplementedError("Partitioned variables are not yet supported " "in Eager mode.") + self._reuse = AUTO_REUSE self._use_resource = True @property @@ -963,6 +980,8 @@ class VariableScope(object): def set_use_resource(self, use_resource): """Sets whether to use ResourceVariables for this scope.""" + if context.in_eager_mode() and not use_resource: + raise ValueError("In eager mode, use_resource cannot be set to false.") self._use_resource = use_resource def set_regularizer(self, regularizer): @@ -1029,8 +1048,14 @@ class VariableScope(object): partitioner = self._partitioner if custom_getter is None: custom_getter = self._custom_getter - if reuse is None: - reuse = self._reuse + if context.in_graph_mode(): + if reuse is None: + reuse = self._reuse + if use_resource is None: + use_resource = self._use_resource + else: + reuse = AUTO_REUSE + use_resource = True full_name = self.name + "/" + name if self.name else name # Variable names only depend on variable_scope (full_name here), @@ -1050,12 +1075,6 @@ class VariableScope(object): constraint = self._constraint if dtype is None: dtype = self._dtype - if context.in_graph_mode(): - if use_resource is None: - use_resource = self._use_resource - else: - use_resource = True - return var_store.get_variable( full_name, shape=shape, dtype=dtype, initializer=initializer, regularizer=regularizer, reuse=reuse, trainable=trainable, @@ -1232,7 +1251,8 @@ Args: must be known. use_resource: If False, creates a regular Variable. If true, creates an experimental ResourceVariable instead with well-defined semantics. - Defaults to False (will later change to True). + Defaults to False (will later change to True). In Eager mode, this argument + is always forced to be True. custom_getter: Callable that takes as a first argument the true getter, and allows overwriting the internal get_variable method. The signature of `custom_getter` should match that of this method, @@ -1661,12 +1681,14 @@ def variable_scope(name_or_scope, reuse: `True`, None, or tf.AUTO_REUSE; if `True`, we go into reuse mode for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create variables if they do not exist, and return them otherwise; if None, we - inherit the parent scope's reuse flag. + inherit the parent scope's reuse flag. In Eager mode, this argument is + always forced to be tf.AUTO_REUSE. dtype: type of variables created in this scope (defaults to the type in the passed scope, or inherited from parent scope). use_resource: If False, all variables will be regular Variables. If True, experimental ResourceVariables with well-defined semantics will be used - instead. Defaults to False (will later change to True). + instead. Defaults to False (will later change to True). In Eager mode, + this argument is always forced to be True. constraint: An optional projection function to be applied to the variable after being updated by an `Optimizer` (e.g. used to implement norm constraints or value constraints for layer weights). The function must |