author     Ali Yahya <alive@google.com>                       2017-08-29 16:15:05 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-08-29 16:19:03 -0700
commit     5a76b8dce388a0a03849979a7eb655d1780cfd88 (patch)
tree       6f9b9bc331e5049fd940ff6114c91d98e4f8881f
parent     0492dccf024be59f4f38f7650deb9570a2f0c2db (diff)
Modified variable scopes to work with Eager mode.
PiperOrigin-RevId: 166920017
 tensorflow/python/framework/test_util.py                      |   4
 tensorflow/python/kernel_tests/BUILD                          |   2
 tensorflow/python/kernel_tests/resource_variable_ops_test.py  |   4
 tensorflow/python/kernel_tests/variable_scope_test.py         | 584
 tensorflow/python/ops/resource_variable_ops.py                |   5
 tensorflow/python/ops/variable_scope.py                       |  56
 6 files changed, 359 insertions(+), 296 deletions(-)
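
For orientation, a minimal sketch of what this change enables, using the internal context.eager_mode() helper that the updated tests below rely on. Whether two lookups return the identical object is an assumption based on the shared default variable store, not documented behavior:

    # Sketch only; mirrors the patterns exercised in variable_scope_test.py.
    from tensorflow.python.eager import context
    from tensorflow.python.ops import variable_scope

    with context.eager_mode():
      # In eager mode, get_variable behaves as if reuse=tf.AUTO_REUSE:
      # the first call creates the variable, the second call with the
      # same name returns the existing one instead of raising.
      with variable_scope.variable_scope("tower"):
        v1 = variable_scope.get_variable("v", shape=[2])
      with variable_scope.variable_scope("tower"):
        v2 = variable_scope.get_variable("v", shape=[2])
      assert v1 is v2  # Assumed: the default variable store is shared.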
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index be51a38daf..c65816a543 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -277,7 +277,7 @@ def enable_c_api(fn):
def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
use_gpu=False, force_gpu=False,
- reset_test=False):
+ reset_test=True):
"""Runs the test in both graph and eager modes.
Args:
@@ -305,6 +305,8 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
f(self)
if reset_test:
+ # This decorator runs the wrapped test twice.
+ # Reset the test environment between runs.
self.tearDown()
self.setUp()
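
Because the decorator runs the wrapped body twice (once per mode) and now resets the fixture in between by default, a typical test can be written once against self.evaluate(). A hedged usage sketch; the SquareTest class and its names are illustrative, not part of this patch:

    # Usage sketch for run_in_graph_and_eager_modes; names are hypothetical.
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import test_util
    from tensorflow.python.platform import test

    class SquareTest(test_util.TensorFlowTestCase):

      @test_util.run_in_graph_and_eager_modes()
      def testSquare(self):
        # self.evaluate() runs a Session in graph mode and returns the
        # value directly in eager mode, so one body covers both passes.
        x = constant_op.constant(3.0)
        self.assertAllClose(self.evaluate(x * x), 9.0)

    if __name__ == "__main__":
      test.main()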
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 2db686333d..bdd9a6dc22 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -932,12 +932,14 @@ tf_py_test(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
],
tags = ["no_windows"],
)
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 98fc568515..c31732d807 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -53,7 +53,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
0,
dtype=dtypes.int32)).run()
- def testReadVariableDtypeMismatch(self):
+ def testReadVariableDtypeMismatchEager(self):
with context.eager_mode():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1], name="foo")
@@ -62,7 +62,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
"Expected float got int32."):
_ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32)
- def testAssignVariableDtypeMismatch(self):
+ def testAssignVariableDtypeMismatchEager(self):
with context.eager_mode():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1], name="foo")
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 67932d0823..cdac12f05a 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -20,10 +20,12 @@ from __future__ import print_function
import numpy
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
@@ -37,32 +39,38 @@ from tensorflow.python.platform import test
class VariableScopeTest(test.TestCase):
+ @test_util.run_in_graph_and_eager_modes()
def testGetVar(self):
vs = variable_scope._get_default_variable_store()
v = vs.get_variable("v", [1])
v1 = vs.get_variable("v", [1])
self.assertEqual(v, v1)
+ @test_util.run_in_graph_and_eager_modes()
def testResource(self):
vs = variable_scope._get_default_variable_store()
v1 = vs.get_variable("v", [1], use_resource=True)
self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable))
+ @test_util.run_in_graph_and_eager_modes()
def testNameExists(self):
vs = variable_scope._get_default_variable_store()
# No check by default, so we can both create and get existing names.
v = vs.get_variable("v", [1])
v1 = vs.get_variable("v", [1])
self.assertEqual(v, v1)
- # When reuse is False, we fail when variables are already there.
- vs.get_variable("w", [1], reuse=False) # That's ok.
- with self.assertRaises(ValueError):
- vs.get_variable("v", [1], reuse=False) # That fails.
- # When reuse is True, we fail when variables are new.
- vs.get_variable("v", [1], reuse=True) # That's ok.
- with self.assertRaises(ValueError):
- vs.get_variable("u", [1], reuse=True) # That fails.
+ if context.in_graph_mode():
+ # When reuse is False, we fail when variables are already there.
+ vs.get_variable("w", [1], reuse=False) # That's ok.
+ with self.assertRaises(ValueError):
+ vs.get_variable("v", [1], reuse=False) # That fails.
+ # When reuse is True, we fail when variables are new.
+ vs.get_variable("v", [1], reuse=True) # That's ok.
+ with self.assertRaises(ValueError):
+ vs.get_variable("u", [1], reuse=True) # That fails.
+
+ @test_util.run_in_graph_and_eager_modes()
def testNamelessStore(self):
vs = variable_scope._get_default_variable_store()
vs.get_variable("v1", [2])
@@ -71,22 +79,23 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(
set(expected_names), set([v.name for v in vs._vars.values()]))
+ @test_util.run_in_graph_and_eager_modes()
def testVarScopeInitializer(self):
- with self.test_session() as sess:
- init = init_ops.constant_initializer(0.3)
- with variable_scope.variable_scope("tower") as tower:
- with variable_scope.variable_scope("foo", initializer=init):
- v = variable_scope.get_variable("v", [])
- sess.run(variables_lib.initialize_variables([v]))
- self.assertAllClose(v.eval(), 0.3)
- with variable_scope.variable_scope(tower, initializer=init):
- w = variable_scope.get_variable("w", [])
- sess.run(variables_lib.initialize_variables([w]))
- self.assertAllClose(w.eval(), 0.3)
+ init = init_ops.constant_initializer(0.3)
+ with variable_scope.variable_scope("tower0") as tower:
+ with variable_scope.variable_scope("foo", initializer=init):
+ v = variable_scope.get_variable("v", [])
+ self.evaluate(variables_lib.variables_initializer([v]))
+ self.assertAllClose(self.evaluate(v.value()), 0.3)
+ with variable_scope.variable_scope(tower, initializer=init):
+ w = variable_scope.get_variable("w", [])
+ self.evaluate(variables_lib.variables_initializer([w]))
+ self.assertAllClose(self.evaluate(w.value()), 0.3)
+ @test_util.run_in_graph_and_eager_modes()
def testVarScopeConstraint(self):
constraint = lambda x: 0. * x
- with variable_scope.variable_scope("tower") as tower:
+ with variable_scope.variable_scope("tower1") as tower:
with variable_scope.variable_scope("foo", constraint=constraint):
v = variable_scope.get_variable("v", [])
self.assertEqual(v.constraint, constraint)
@@ -94,51 +103,56 @@ class VariableScopeTest(test.TestCase):
w = variable_scope.get_variable("w", [])
self.assertEqual(w.constraint, constraint)
+ @test_util.run_in_graph_and_eager_modes()
def testVarScopeDType(self):
- with self.test_session():
- with variable_scope.variable_scope("tower") as tower:
- with variable_scope.variable_scope("foo", dtype=dtypes.float16):
- v = variable_scope.get_variable("v", [])
- self.assertEqual(v.dtype.base_dtype, dtypes.float16)
- with variable_scope.variable_scope(tower, dtype=dtypes.float16):
- w = variable_scope.get_variable("w", [])
- self.assertEqual(w.dtype.base_dtype, dtypes.float16)
+ with variable_scope.variable_scope("tower2") as tower:
+ with variable_scope.variable_scope("foo", dtype=dtypes.float16):
+ v = variable_scope.get_variable("v", [])
+ self.assertEqual(v.dtype.base_dtype, dtypes.float16)
+ with variable_scope.variable_scope(tower, dtype=dtypes.float16):
+ w = variable_scope.get_variable("w", [])
+ self.assertEqual(w.dtype.base_dtype, dtypes.float16)
+ @test_util.run_in_graph_and_eager_modes()
def testInitFromNonTensorValue(self):
- with self.test_session() as sess:
- v = variable_scope.get_variable("v", initializer=4, dtype=dtypes.int32)
- sess.run(variables_lib.initialize_variables([v]))
- self.assertAllClose(v.eval(), 4)
+ v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32)
+ self.evaluate(variables_lib.variables_initializer([v]))
+ self.assertAllClose(self.evaluate(v.value()), 4)
- w = variable_scope.get_variable(
- "w", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64)
- sess.run(variables_lib.initialize_variables([w]))
- self.assertAllClose(w.eval(), [1, 2, 3])
+ w = variable_scope.get_variable(
+ "w4", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64)
+ self.evaluate(variables_lib.variables_initializer([w]))
+ self.assertAllClose(self.evaluate(w.value()), [1, 2, 3])
+ if context.in_graph_mode():
with self.assertRaises(TypeError):
- variable_scope.get_variable("x", initializer={})
+ variable_scope.get_variable("x4", initializer={})
+ else:
+ with self.assertRaises(errors.InvalidArgumentError):
+ variable_scope.get_variable("x4", initializer={})
+ @test_util.run_in_graph_and_eager_modes()
def testInitFromNonInitializer(self):
- with self.test_session():
- # Test various dtypes with zeros initializer as following:
- types = [
- dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32,
- dtypes.int64, dtypes.bool
- ]
-
- # Use different variable_name to distinguish various dtypes
- for (i, dtype) in enumerate(types):
- x = variable_scope.get_variable(
- name="x%d" % i, shape=(3, 4), dtype=dtype)
- y = variable_scope.get_variable(
- name="y%d" % i,
- shape=(3, 4),
- dtype=dtype,
- initializer=init_ops.zeros_initializer(dtype=dtype))
-
- variables_lib.global_variables_initializer().run()
- self.assertAllEqual(x.eval(), y.eval())
-
+ # Test various dtypes with the zeros initializer as follows:
+ types = [
+ dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32,
+ dtypes.int64, dtypes.bool
+ ]
+
+ # Use a different variable name for each dtype to keep them distinguishable
+ for (i, dtype) in enumerate(types):
+ x = variable_scope.get_variable(
+ name="xx%d" % i, shape=(3, 4), dtype=dtype)
+ y = variable_scope.get_variable(
+ name="yy%d" % i,
+ shape=(3, 4),
+ dtype=dtype,
+ initializer=init_ops.zeros_initializer(dtype=dtype))
+
+ self.evaluate(variables_lib.global_variables_initializer())
+ self.assertAllEqual(self.evaluate(x.value()), self.evaluate(y.value()))
+
+ # TODO(alive): support variable partitioning/caching in eager mode.
def testVarScopeCachingDevice(self):
with self.test_session():
caching_device = "/job:moo"
@@ -172,74 +186,74 @@ class VariableScopeTest(test.TestCase):
v_tower = variable_scope.get_variable("v", [])
self.assertFalse(v_tower.value().device.startswith(caching_device))
+ @test_util.run_in_graph_and_eager_modes()
def testVarScopeRegularizer(self):
- with self.test_session() as sess:
- init = init_ops.constant_initializer(0.3)
-
- def regularizer1(v):
- return math_ops.reduce_mean(v) + 0.1
-
- def regularizer2(v):
- return math_ops.reduce_mean(v) + 0.2
-
- with variable_scope.variable_scope(
- "tower", regularizer=regularizer1) as tower:
- with variable_scope.variable_scope("foo", initializer=init):
- v = variable_scope.get_variable("v", [])
- sess.run(variables_lib.initialize_variables([v]))
- losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
- self.assertEqual(1, len(losses))
- self.assertAllClose(losses[0].eval(), 0.4)
- with variable_scope.variable_scope(tower, initializer=init) as vs:
- u = variable_scope.get_variable("u", [])
- vs.set_regularizer(regularizer2)
- w = variable_scope.get_variable("w", [])
- # Next 3 variable not regularized to test disabling regularization.
- x = variable_scope.get_variable(
- "x", [], regularizer=variable_scope.no_regularizer)
- with variable_scope.variable_scope(
- "baz", regularizer=variable_scope.no_regularizer):
- y = variable_scope.get_variable("y", [])
- vs.set_regularizer(variable_scope.no_regularizer)
- z = variable_scope.get_variable("z", [])
- # Check results.
- losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
- self.assertEqual(3, len(losses))
- sess.run(variables_lib.initialize_variables([u, w, x, y, z]))
- self.assertAllClose(losses[0].eval(), 0.4)
- self.assertAllClose(losses[1].eval(), 0.4)
- self.assertAllClose(losses[2].eval(), 0.5)
- with variable_scope.variable_scope("foo", reuse=True):
- v = variable_scope.get_variable("v",
- []) # "v" is alredy there, reused
- losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
- self.assertEqual(3, len(losses)) # No new loss added.
+ init = init_ops.constant_initializer(0.3)
- def testInitializeFromValue(self):
- with self.test_session() as sess:
- init = constant_op.constant(0.1)
- w = variable_scope.get_variable("v", initializer=init)
- sess.run(variables_lib.initialize_variables([w]))
- self.assertAllClose(w.eval(), 0.1)
+ def regularizer1(v):
+ return math_ops.reduce_mean(v) + 0.1
- with self.assertRaisesRegexp(ValueError, "shape"):
- # We disallow explicit shape specification when initializer is constant.
- variable_scope.get_variable("u", [1], initializer=init)
+ def regularizer2(v):
+ return math_ops.reduce_mean(v) + 0.2
+ with variable_scope.variable_scope(
+ "tower3", regularizer=regularizer1) as tower:
with variable_scope.variable_scope("foo", initializer=init):
- # Constant initializer can be passed through scopes if needed.
- v = variable_scope.get_variable("v")
- sess.run(variables_lib.initialize_variables([v]))
- self.assertAllClose(v.eval(), 0.1)
-
- # Check that non-float32 initializer creates a non-float32 variable.
- init = constant_op.constant(1, dtype=dtypes.int32)
- t = variable_scope.get_variable("t", initializer=init)
- self.assertEqual(t.dtype.base_dtype, dtypes.int32)
-
- # Raise error if `initializer` dtype and `dtype` are not identical.
- with self.assertRaisesRegexp(ValueError, "don't match"):
- variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
+ v = variable_scope.get_variable("v", [])
+ self.evaluate(variables_lib.variables_initializer([v]))
+ losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(1, len(losses))
+ self.assertAllClose(self.evaluate(losses[0]), 0.4)
+ with variable_scope.variable_scope(tower, initializer=init) as vs:
+ u = variable_scope.get_variable("u", [])
+ vs.set_regularizer(regularizer2)
+ w = variable_scope.get_variable("w", [])
+ # The next 3 variables are not regularized, to test disabling regularization.
+ x = variable_scope.get_variable(
+ "x", [], regularizer=variable_scope.no_regularizer)
+ with variable_scope.variable_scope(
+ "baz", regularizer=variable_scope.no_regularizer):
+ y = variable_scope.get_variable("y", [])
+ vs.set_regularizer(variable_scope.no_regularizer)
+ z = variable_scope.get_variable("z", [])
+ # Check results.
+ losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(3, len(losses))
+ self.evaluate(variables_lib.variables_initializer([u, w, x, y, z]))
+ self.assertAllClose(self.evaluate(losses[0]), 0.4)
+ self.assertAllClose(self.evaluate(losses[1]), 0.4)
+ self.assertAllClose(self.evaluate(losses[2]), 0.5)
+ with variable_scope.variable_scope("foo", reuse=True):
+ v = variable_scope.get_variable("v",
+ []) # "v" is already there, reused
+ losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(3, len(losses)) # No new loss added.
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testInitializeFromValue(self):
+ init = constant_op.constant(0.1)
+ w = variable_scope.get_variable("v", initializer=init)
+ self.evaluate(variables_lib.variables_initializer([w]))
+ self.assertAllClose(self.evaluate(w.value()), 0.1)
+
+ with self.assertRaisesRegexp(ValueError, "shape"):
+ # We disallow explicit shape specification when initializer is constant.
+ variable_scope.get_variable("u", [1], initializer=init)
+
+ with variable_scope.variable_scope("foo", initializer=init):
+ # Constant initializer can be passed through scopes if needed.
+ v = variable_scope.get_variable("v")
+ self.evaluate(variables_lib.variables_initializer([v]))
+ self.assertAllClose(self.evaluate(v.value()), 0.1)
+
+ # Check that non-float32 initializer creates a non-float32 variable.
+ init = constant_op.constant(1, dtype=dtypes.int32)
+ t = variable_scope.get_variable("t", initializer=init)
+ self.assertEqual(t.dtype.base_dtype, dtypes.int32)
+
+ # Raise error if `initializer` dtype and `dtype` are not identical.
+ with self.assertRaisesRegexp(ValueError, "don't match"):
+ variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
def testControlDeps(self):
with self.test_session() as sess:
@@ -250,16 +264,16 @@ class VariableScopeTest(test.TestCase):
"v1", [1], initializer=init_ops.constant_initializer(1))
add = v1 + v0
# v0 should be uninitialized.
- with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
+ with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
sess.run(v0)
# We should be able to initialize and run v1 without initializing
# v0, even if the variable was created with a control dep on v0.
sess.run(v1.initializer)
self.assertEqual(1, sess.run(v1))
# v0 should still be uninitialized.
- with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
+ with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
sess.run(v0)
- with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
+ with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
sess.run(add)
# If we initialize v0 we should be able to run 'add'.
sess.run(v0.initializer)
@@ -295,82 +309,85 @@ class VariableScopeTest(test.TestCase):
sess.run(v2.initializer)
self.assertEqual([2], sess.run(v2))
# v0 should still be uninitialized.
- with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
+ with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
sess.run(v0)
# We should not be able to run 'add' yet.
- with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
+ with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
sess.run(add)
# If we initialize v0 we should be able to run 'add'.
sess.run(v0.initializer)
sess.run(add)
+ @test_util.run_in_graph_and_eager_modes()
def testGetVariableScope(self):
# Test the get_variable_scope() function and setting properties of result.
- with self.test_session() as sess:
- init = init_ops.constant_initializer(0.3)
- with variable_scope.variable_scope("foo"):
- new_init1 = variable_scope.get_variable_scope().initializer
- self.assertEqual(new_init1, None)
- # Check that we can set initializer like this.
- variable_scope.get_variable_scope().set_initializer(init)
- v = variable_scope.get_variable("v", [])
- sess.run(variables_lib.initialize_variables([v]))
- self.assertAllClose(v.eval(), 0.3)
+ init = init_ops.constant_initializer(0.3)
+ with variable_scope.variable_scope("bar"):
+ new_init1 = variable_scope.get_variable_scope().initializer
+ self.assertEqual(new_init1, None)
+ # Check that we can set initializer like this.
+ variable_scope.get_variable_scope().set_initializer(init)
+ v = variable_scope.get_variable("v", [])
+ self.evaluate(variables_lib.variables_initializer([v]))
+ self.assertAllClose(self.evaluate(v.value()), 0.3)
+ if context.in_graph_mode():
# Check that we can set reuse.
variable_scope.get_variable_scope().reuse_variables()
with self.assertRaises(ValueError): # Fail, w does not exist yet.
variable_scope.get_variable("w", [1])
- # Check that the set initializer goes away.
- new_init = variable_scope.get_variable_scope().initializer
- self.assertEqual(new_init, None)
+ # Check that the set initializer goes away.
+ new_init = variable_scope.get_variable_scope().initializer
+ self.assertEqual(new_init, None)
+ @test_util.run_in_graph_and_eager_modes()
def testVarScope(self):
- with self.test_session():
- with variable_scope.variable_scope("tower") as tower:
- self.assertEqual(tower.name, "tower")
+ with variable_scope.variable_scope("tower4") as tower:
+ self.assertEqual(tower.name, "tower4")
+ with ops.name_scope("scope") as sc:
+ self.assertEqual(sc, "tower4/scope/")
+
+ with variable_scope.variable_scope("tower5"):
+ with variable_scope.variable_scope("bar") as bar:
+ self.assertEqual(bar.name, "tower5/bar")
with ops.name_scope("scope") as sc:
- self.assertEqual(sc, "tower/scope/")
+ self.assertEqual(sc, "tower5/bar/scope/")
- with variable_scope.variable_scope("foo"):
- with variable_scope.variable_scope("bar") as bar:
- self.assertEqual(bar.name, "foo/bar")
- with ops.name_scope("scope") as sc:
- self.assertEqual(sc, "foo/bar/scope/")
-
- with variable_scope.variable_scope("foo"):
- with variable_scope.variable_scope(tower, reuse=True) as tower_shared:
- self.assertEqual(tower_shared.name, "tower")
- with ops.name_scope("scope") as sc:
- self.assertEqual(sc, "foo_1/tower/scope/")
+ with variable_scope.variable_scope("tower6"):
+ with variable_scope.variable_scope(tower, reuse=True) as tower_shared:
+ self.assertEqual(tower_shared.name, "tower4")
+ with ops.name_scope("scope") as sc:
+ self.assertEqual(sc, "tower6/tower4/scope/")
+ @test_util.run_in_graph_and_eager_modes()
def testVarScopeNameScope(self):
- with self.test_session():
- with ops.name_scope("scope1"):
- with variable_scope.variable_scope("tower") as tower:
- with ops.name_scope("scope2") as sc2:
- self.assertEqual(sc2, "scope1/tower/scope2/")
+ with ops.name_scope("testVarScopeNameScope1"):
+ with variable_scope.variable_scope("tower") as tower:
+ with ops.name_scope("scope2") as sc2:
+ self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/")
+ if context.in_graph_mode():
with variable_scope.variable_scope(
tower): # Re-entering acts like another "tower".
with ops.name_scope("scope2") as sc2:
- self.assertEqual(sc2, "scope1/tower_1/scope2/")
+ self.assertEqual(sc2, "testVarScopeNameScope1/tower_1/scope2/")
with variable_scope.variable_scope(
"tower"): # Re-entering by string acts the same.
with ops.name_scope("scope2") as sc2:
- self.assertEqual(sc2, "scope1/tower_2/scope2/")
+ self.assertEqual(sc2, "testVarScopeNameScope1/tower_2/scope2/")
- with ops.name_scope("scope3"):
- with variable_scope.variable_scope("tower"):
- with ops.name_scope("scope2") as sc2:
- self.assertEqual(sc2, "scope3/tower/scope2/")
+ with ops.name_scope("testVarScopeNameScope2"):
+ with variable_scope.variable_scope("tower"):
+ with ops.name_scope("scope2") as sc2:
+ self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/")
+ if context.in_graph_mode():
with variable_scope.variable_scope(tower):
with ops.name_scope("scope2") as sc2:
- self.assertEqual(sc2, "scope3/tower_1/scope2/")
+ self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/")
- root_var_scope = variable_scope.get_variable_scope()
- with ops.name_scope("scope4"):
- with variable_scope.variable_scope(root_var_scope):
- with ops.name_scope("scope2") as sc2:
- self.assertEqual(sc2, "scope4/scope2/")
+ root_var_scope = variable_scope.get_variable_scope()
+ with ops.name_scope("testVarScopeNameScope3"):
+ with variable_scope.variable_scope(root_var_scope):
+ with ops.name_scope("scope2") as sc2:
+ self.assertEqual(sc2, "testVarScopeNameScope3/scope2/")
def testVarScopeOriginalNameScope(self):
with self.test_session():
@@ -422,51 +439,46 @@ class VariableScopeTest(test.TestCase):
with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse:
self.assertFalse(jump_no_reuse.reuse)
+ @test_util.run_in_graph_and_eager_modes()
def testVarScopeGetOrCreateReuse(self):
- x = array_ops.placeholder(dtypes.float32)
-
- with variable_scope.variable_scope("bar",
- reuse=variable_scope.AUTO_REUSE):
- v_assign = state_ops.assign(variable_scope.get_variable("var", []), x)
-
- with variable_scope.variable_scope("bar",
- reuse=variable_scope.AUTO_REUSE):
- v = variable_scope.get_variable("var", [])
-
- with self.test_session() as sess:
- def test_value(value):
- sess.run(v_assign, feed_dict={x: value})
- self.assertEqual(value, v.eval())
-
- test_value(42) # Variable is created.
- test_value(13) # Variable is reused hereafter.
- test_value(17)
+ def test_value(value):
+ x = constant_op.constant(value)
+ with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar",
+ reuse=variable_scope.AUTO_REUSE):
+ _ = state_ops.assign(variable_scope.get_variable("var", []), x)
+ with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar",
+ reuse=variable_scope.AUTO_REUSE):
+ _ = variable_scope.get_variable("var", [])
+ self.assertEqual(value, self.evaluate(x))
+ test_value(42.) # Variable is created.
+ test_value(13.) # Variable is reused hereafter.
+ test_value(17.)
def testVarOpScope(self):
with self.test_session():
- with ops.name_scope("scope1"):
+ with ops.name_scope("testVarOpScope1"):
with variable_scope.variable_scope("tower", "default", []):
self.assertEqual(
variable_scope.get_variable("w", []).name, "tower/w:0")
- with ops.name_scope("scope2") as sc2:
- self.assertEqual(sc2, "scope1/tower/scope2/")
+ with ops.name_scope("testVarOpScope2") as sc2:
+ self.assertEqual(sc2, "testVarOpScope1/tower/testVarOpScope2/")
with variable_scope.variable_scope("tower", "default", []):
with self.assertRaises(ValueError):
variable_scope.get_variable("w", [])
- with ops.name_scope("scope2") as sc2:
- self.assertEqual(sc2, "scope1/tower_1/scope2/")
+ with ops.name_scope("testVarOpScope2") as sc2:
+ self.assertEqual(sc2, "testVarOpScope1/tower_1/testVarOpScope2/")
- with ops.name_scope("scope2"):
+ with ops.name_scope("testVarOpScope2"):
with variable_scope.variable_scope(None, "default", []):
self.assertEqual(
variable_scope.get_variable("w", []).name, "default/w:0")
- with ops.name_scope("scope2") as sc2:
- self.assertEqual(sc2, "scope2/default/scope2/")
+ with ops.name_scope("testVarOpScope2") as sc2:
+ self.assertEqual(sc2, "testVarOpScope2/default/testVarOpScope2/")
with variable_scope.variable_scope(None, "default", []):
self.assertEqual(
variable_scope.get_variable("w", []).name, "default_1/w:0")
- with ops.name_scope("scope2") as sc2:
- self.assertEqual(sc2, "scope2/default_1/scope2/")
+ with ops.name_scope("testVarOpScope2") as sc2:
+ self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/")
def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self):
with self.test_session():
@@ -714,27 +726,27 @@ class VariableScopeTest(test.TestCase):
with ops.name_scope("scope2") as sc2:
self.assertEqual(sc2, "outer_1/default/scope2/")
+ @test_util.run_in_graph_and_eager_modes()
def testGetLocalVar(self):
- with self.test_session():
- # Check that local variable respects naming.
- with variable_scope.variable_scope("outer") as outer:
- with variable_scope.variable_scope(outer, "default", []):
- local_var = variable_scope.get_local_variable(
- "w", [], collections=["foo"])
- self.assertEqual(local_var.name, "outer/w:0")
-
- # Since variable is local, it should be in the local variable collection
- # but not the trainable collection.
- self.assertIn(local_var,
- ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
- self.assertIn(local_var, ops.get_collection("foo"))
- self.assertNotIn(local_var,
- ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
-
- # Check that local variable respects `reuse`.
- with variable_scope.variable_scope(outer, "default", reuse=True):
- self.assertEqual(
- variable_scope.get_local_variable("w", []).name, "outer/w:0")
+ # Check that local variable respects naming.
+ with variable_scope.variable_scope("outer") as outer:
+ with variable_scope.variable_scope(outer, "default", []):
+ local_var = variable_scope.get_local_variable(
+ "w", [], collections=["foo"])
+ self.assertEqual(local_var.name, "outer/w:0")
+
+ # Since variable is local, it should be in the local variable collection
+ # but not the trainable collection.
+ self.assertIn(local_var,
+ ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
+ self.assertIn(local_var, ops.get_collection("foo"))
+ self.assertNotIn(local_var,
+ ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+
+ # Check that local variable respects `reuse`.
+ with variable_scope.variable_scope(outer, "default", reuse=True):
+ self.assertEqual(
+ variable_scope.get_local_variable("w", []).name, "outer/w:0")
def testGetVarWithDevice(self):
g = ops.Graph()
@@ -753,69 +765,93 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(varname_type[0], ("x", dtypes.float32))
self.assertEqual(varname_type[1], ("y", dtypes.int64))
+ @test_util.run_in_graph_and_eager_modes()
def testGetCollection(self):
- with self.test_session():
- _ = variable_scope.get_variable("a", [])
- _ = variable_scope.get_variable("b", [], trainable=False)
- with variable_scope.variable_scope("foo_") as scope1:
- _ = variable_scope.get_variable("a", [])
- _ = variable_scope.get_variable("b", [], trainable=False)
- self.assertEqual([
- v.name
- for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- ], ["foo_/a:0"])
- self.assertEqual([
- v.name
- for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- ], ["foo_/a:0", "foo_/b:0"])
- with variable_scope.variable_scope("foo") as scope2:
- _ = variable_scope.get_variable("a", [])
- _ = variable_scope.get_variable("b", [], trainable=False)
- self.assertEqual([
- v.name
- for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- ], ["foo/a:0"])
- self.assertEqual([
- v.name
- for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- ], ["foo/a:0", "foo/b:0"])
- scope = variable_scope.get_variable_scope()
+ _ = variable_scope.get_variable("testGetCollection_a", [])
+ _ = variable_scope.get_variable("testGetCollection_b", [], trainable=False)
+ with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
+ _ = variable_scope.get_variable("testGetCollection_a", [])
+ _ = variable_scope.get_variable("testGetCollection_b", [],
+ trainable=False)
self.assertEqual([
- v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- ], ["a:0", "b:0", "foo_/a:0", "foo_/b:0", "foo/a:0", "foo/b:0"])
+ v.name
+ for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ ], ["testGetCollection_foo_/testGetCollection_a:0"])
self.assertEqual([
v.name
- for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- ], ["a:0", "foo_/a:0", "foo/a:0"])
-
+ for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ ], [
+ "testGetCollection_foo_/testGetCollection_a:0",
+ "testGetCollection_foo_/testGetCollection_b:0"
+ ])
+ with variable_scope.variable_scope("testGetCollection_foo") as scope2:
+ _ = variable_scope.get_variable("testGetCollection_a", [])
+ _ = variable_scope.get_variable("testGetCollection_b", [],
+ trainable=False)
+ self.assertEqual([
+ v.name
+ for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ ], ["testGetCollection_foo/testGetCollection_a:0"])
+ self.assertEqual([
+ v.name
+ for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ ], [
+ "testGetCollection_foo/testGetCollection_a:0",
+ "testGetCollection_foo/testGetCollection_b:0"
+ ])
+ scope = variable_scope.get_variable_scope()
+ self.assertEqual([
+ v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ ], [
+ "testGetCollection_a:0", "testGetCollection_b:0",
+ "testGetCollection_foo_/testGetCollection_a:0",
+ "testGetCollection_foo_/testGetCollection_b:0",
+ "testGetCollection_foo/testGetCollection_a:0",
+ "testGetCollection_foo/testGetCollection_b:0"
+ ])
+ self.assertEqual([
+ v.name
+ for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ ], [
+ "testGetCollection_a:0",
+ "testGetCollection_foo_/testGetCollection_a:0",
+ "testGetCollection_foo/testGetCollection_a:0"
+ ])
+
+ @test_util.run_in_graph_and_eager_modes()
def testGetTrainableVariables(self):
- with self.test_session():
- _ = variable_scope.get_variable("a", [])
- with variable_scope.variable_scope("foo") as scope:
- _ = variable_scope.get_variable("b", [])
- _ = variable_scope.get_variable("c", [], trainable=False)
- self.assertEqual([v.name
- for v in scope.trainable_variables()], ["foo/b:0"])
-
+ _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
+ with variable_scope.variable_scope(
+ "testGetTrainableVariables_foo") as scope:
+ _ = variable_scope.get_variable("testGetTrainableVariables_b", [])
+ _ = variable_scope.get_variable("testGetTrainableVariables_c", [],
+ trainable=False)
+ self.assertEqual([v.name
+ for v in scope.trainable_variables()],
+ ["testGetTrainableVariables_foo/"
+ "testGetTrainableVariables_b:0"])
+
+ @test_util.run_in_graph_and_eager_modes()
def testGetGlobalVariables(self):
- with self.test_session():
- _ = variable_scope.get_variable("a", [])
- with variable_scope.variable_scope("foo") as scope:
- _ = variable_scope.get_variable("b", [])
- self.assertEqual([v.name
- for v in scope.global_variables()], ["foo/b:0"])
-
+ _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
+ with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
+ _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
+ self.assertEqual([v.name
+ for v in scope.global_variables()],
+ ["testGetGlobalVariables_foo/"
+ "testGetGlobalVariables_b:0"])
+
+ @test_util.run_in_graph_and_eager_modes()
def testGetLocalVariables(self):
- with self.test_session():
+ _ = variable_scope.get_variable(
+ "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ with variable_scope.variable_scope("foo") as scope:
_ = variable_scope.get_variable(
- "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
- with variable_scope.variable_scope("foo") as scope:
- _ = variable_scope.get_variable(
- "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
- _ = variable_scope.get_variable(
- "c", [])
- self.assertEqual([v.name
- for v in scope.local_variables()], ["foo/b:0"])
+ "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ _ = variable_scope.get_variable(
+ "c", [])
+ self.assertEqual([v.name
+ for v in scope.local_variables()], ["foo/b:0"])
def testGetVariableWithRefDtype(self):
v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32)
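
The tests above repeatedly guard graph-only expectations (explicit reuse, name-scope uniquification, uninitialized-variable errors) behind context.in_graph_mode(). A condensed sketch of that pattern, with a hypothetical scope name:

    # Pattern sketch: branch on execution mode for graph-only behavior.
    from tensorflow.python.eager import context
    from tensorflow.python.ops import variable_scope

    with variable_scope.variable_scope("outer"):
      v = variable_scope.get_variable("v", shape=[])
    if context.in_graph_mode():
      # Graph mode: a second request for "outer/v" must opt into reuse,
      # otherwise get_variable raises ValueError.
      with variable_scope.variable_scope("outer", reuse=True):
        v2 = variable_scope.get_variable("v", shape=[])
    else:
      # Eager mode: reuse is forced to AUTO_REUSE, so a plain call succeeds.
      with variable_scope.variable_scope("outer"):
        v2 = variable_scope.get_variable("v", shape=[])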
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index e43cf4bc04..1d747f8400 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -255,8 +255,9 @@ class ResourceVariable(variables.Variable):
context.get_default_context().device_name)
else:
initial_value = initial_value()
- initial_value = ops.convert_to_tensor(
- initial_value, name="initial_value", dtype=dtype)
+ with ops.name_scope("Initializer"):
+ initial_value = ops.convert_to_tensor(
+ initial_value, name="initial_value", dtype=dtype)
self._handle = gen_resource_variable_ops.var_handle_op(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
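
The name_scope("Initializer") wrapper groups the ops that materialize initial_value under the variable's name in graph mode. A rough sketch of how to observe this; the exact generated op names are an assumption, not verified against this revision:

    # Graph-mode sketch; an op name like "v/Initializer/initial_value"
    # is an assumption about what the new name scope produces.
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import resource_variable_ops

    g = ops.Graph()
    with g.as_default():
      resource_variable_ops.ResourceVariable(
          initial_value=[1.0, 2.0], name="v", dtype=dtypes.float32)
      # Inspect the graph to see the initializer ops grouped under the
      # "Initializer" name scope.
      print([op.name for op in g.get_operations()])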
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index b7913890e4..9093c12968 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -259,7 +259,8 @@ class _VariableStore(object):
applying it on a newly created variable will be added to the collection
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
- of variables.
+ of variables. In Eager mode, this argument is always forced to be
+ tf.AUTO_REUSE.
trainable: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
collections: List of graph collections keys to add the `Variable` to.
@@ -278,6 +279,7 @@ class _VariableStore(object):
use_resource: If False, creates a regular Variable. If True, creates
instead an experimental ResourceVariable which has well-defined
semantics. Defaults to False (will later change to True).
+ In Eager mode, this argument is always forced to be True.
custom_getter: Callable that takes as a first argument the true getter,
and allows overwriting the internal get_variable method.
The signature of `custom_getter` should match that of this method,
@@ -311,6 +313,10 @@ class _VariableStore(object):
raise ValueError(
"Passed a custom_getter which is not callable: %s" % custom_getter)
+ if context.in_eager_mode():
+ reuse = AUTO_REUSE
+ use_resource = True
+
# If a *_ref type is passed in an error would be triggered further down the
# stack. We prevent this using base_dtype to get a non-ref version of the
# type, before doing anything else. When _ref types are removed in favor of
@@ -498,6 +504,9 @@ class _VariableStore(object):
when violating reuse during variable creation, or if an existing
sharded variable exists for the given name but with different sharding.
"""
+ if context.in_eager_mode():
+ raise NotImplementedError("Partitioned variables are not yet supported "
+ "in Eager mode.")
initializing_from_value = initializer is not None and isinstance(
initializer, ops.Tensor)
@@ -792,14 +801,19 @@ class _VariableStore(object):
# Run the regularizer if requested and save the resulting loss.
if regularizer:
- with ops.colocate_with(v.op):
+ with ops.colocate_with(v):
with ops.name_scope(name + "/Regularizer/"):
loss = regularizer(v)
if loss is not None:
+ if context.in_graph_mode():
+ v_name = v.name
+ loss_name = loss.name
+ else:
+ v_name = "v_%s" % type(v)
+ loss_name = "loss_%s" % type(loss)
logging.vlog(1, "Applied regularizer to %s and added the result %s "
- "to REGULARIZATION_LOSSES.", v.name, loss.name)
+ "to REGULARIZATION_LOSSES.", v_name, loss_name)
ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
-
return v
# Initialize variable when no initializer provided
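
The colocation and logging changes above keep regularizers usable when ops carry no names (eager mode). A trimmed sketch of the user-facing pattern, mirroring testVarScopeRegularizer; the scope and function names are illustrative:

    # Regularizer sketch; the loss lands in GraphKeys.REGULARIZATION_LOSSES.
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import init_ops
    from tensorflow.python.ops import math_ops
    from tensorflow.python.ops import variable_scope

    def regularizer(v):
      return math_ops.reduce_mean(v) + 0.1

    with variable_scope.variable_scope(
        "reg_tower", regularizer=regularizer,
        initializer=init_ops.constant_initializer(0.3)):
      variable_scope.get_variable("v", [])
    losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    # One loss per regularized variable; it evaluates to 0.3 + 0.1 = 0.4.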
@@ -853,7 +867,8 @@ class VariableScope(object):
initializer: default initializer passed to get_variable.
regularizer: default regularizer passed to get_variable.
reuse: Boolean, None, or tf.AUTO_REUSE, setting the reuse in
- get_variable.
+ get_variable. In Eager mode, this argument is always forced to be
+ tf.AUTO_REUSE.
caching_device: string, callable, or None: the caching device passed to
get_variable.
partitioner: callable or `None`: the partitioner passed to `get_variable`.
@@ -862,7 +877,8 @@ class VariableScope(object):
dtype: default type passed to get_variable (defaults to DT_FLOAT).
use_resource: if False, create a normal Variable; if True create an
experimental ResourceVariable with well-defined semantics. Defaults
- to False (will later change to True).
+ to False (will later change to True). In Eager mode, this argument is
+ always forced to be True.
constraint: An optional projection function to be applied to the variable
after being updated by an `Optimizer` (e.g. used to implement norm
constraints or value constraints for layer weights). The function must
@@ -903,6 +919,7 @@ class VariableScope(object):
if self._partitioner is not None:
raise NotImplementedError("Partitioned variables are not yet supported "
"in Eager mode.")
+ self._reuse = AUTO_REUSE
self._use_resource = True
@property
@@ -963,6 +980,8 @@ class VariableScope(object):
def set_use_resource(self, use_resource):
"""Sets whether to use ResourceVariables for this scope."""
+ if context.in_eager_mode() and not use_resource:
+ raise ValueError("In eager mode, use_resource cannot be set to False.")
self._use_resource = use_resource
def set_regularizer(self, regularizer):
@@ -1029,8 +1048,14 @@ class VariableScope(object):
partitioner = self._partitioner
if custom_getter is None:
custom_getter = self._custom_getter
- if reuse is None:
- reuse = self._reuse
+ if context.in_graph_mode():
+ if reuse is None:
+ reuse = self._reuse
+ if use_resource is None:
+ use_resource = self._use_resource
+ else:
+ reuse = AUTO_REUSE
+ use_resource = True
full_name = self.name + "/" + name if self.name else name
# Variable names only depend on variable_scope (full_name here),
@@ -1050,12 +1075,6 @@ class VariableScope(object):
constraint = self._constraint
if dtype is None:
dtype = self._dtype
- if context.in_graph_mode():
- if use_resource is None:
- use_resource = self._use_resource
- else:
- use_resource = True
-
return var_store.get_variable(
full_name, shape=shape, dtype=dtype, initializer=initializer,
regularizer=regularizer, reuse=reuse, trainable=trainable,
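
The branch above makes the override concrete: graph mode falls back to the scope's stored reuse and use_resource defaults, while eager mode forces both. A hedged sketch of the observable effect; the variable name is illustrative, and object identity across lookups is assumed from the shared variable store:

    # Sketch of the eager-mode overrides implemented above.
    from tensorflow.python.eager import context
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.ops import variable_scope

    with context.eager_mode():
      # use_resource is forced to True: eager get_variable yields a
      # ResourceVariable even though the graph-mode default is False.
      v = variable_scope.get_variable("doc_v", shape=[3])
      assert isinstance(v, resource_variable_ops.ResourceVariable)

      # reuse is forced to AUTO_REUSE: asking for the same name returns
      # the existing variable instead of raising a ValueError.
      v_again = variable_scope.get_variable("doc_v", shape=[3])
      assert v is v_again  # Assumed from the shared default store.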
@@ -1232,7 +1251,8 @@ Args:
must be known.
use_resource: If False, creates a regular Variable. If true, creates an
experimental ResourceVariable instead with well-defined semantics.
- Defaults to False (will later change to True).
+ Defaults to False (will later change to True). In Eager mode, this argument
+ is always forced to be True.
custom_getter: Callable that takes as a first argument the true getter, and
allows overwriting the internal get_variable method.
The signature of `custom_getter` should match that of this method,
@@ -1661,12 +1681,14 @@ def variable_scope(name_or_scope,
reuse: `True`, None, or tf.AUTO_REUSE; if `True`, we go into reuse mode
for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create
variables if they do not exist, and return them otherwise; if None, we
- inherit the parent scope's reuse flag.
+ inherit the parent scope's reuse flag. In Eager mode, this argument is
+ always forced to be tf.AUTO_REUSE.
dtype: type of variables created in this scope (defaults to the type
in the passed scope, or inherited from parent scope).
use_resource: If False, all variables will be regular Variables. If True,
experimental ResourceVariables with well-defined semantics will be used
- instead. Defaults to False (will later change to True).
+ instead. Defaults to False (will later change to True). In Eager mode,
+ this argument is always forced to be True.
constraint: An optional projection function to be applied to the variable
after being updated by an `Optimizer` (e.g. used to implement norm
constraints or value constraints for layer weights). The function must