about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-10-24 13:55:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-24 14:02:32 -0700
commit134daeb4151349acf8c2b3c22f5aebc3e429d756 (patch)
treedf78ef8b62a2db297a85329b5902ae7a4583c789
parentbf1fad214febef6af5c101d8f953d0109c46dfbb (diff)
Eager reuse story is False instead of AUTO_REUSE.
We want variables with eager execution to have object semantics instead of name semantics and this is a small step in that direction. This means that the functional style layer invocations (tf.layers.dense() etc.) will NOT work when eager execution is enabled. Instead, use of the object-oriented layers is advised. PiperOrigin-RevId: 173306447
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py62
-rw-r--r--tensorflow/python/ops/variable_scope.py62
2 files changed, 60 insertions, 64 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 0ea58b4402..29f583d5ba 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -39,7 +39,6 @@ from tensorflow.python.platform import test
class VariableScopeTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
def testGetVar(self):
vs = variable_scope._get_default_variable_store()
v = vs.get_variable("v", [1])
@@ -52,7 +51,6 @@ class VariableScopeTest(test.TestCase):
v1 = vs.get_variable("v", [1], use_resource=True)
self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable))
- @test_util.run_in_graph_and_eager_modes()
def testNameExists(self):
vs = variable_scope._get_default_variable_store()
# No check by default, so we can both create and get existing names.
@@ -60,15 +58,14 @@ class VariableScopeTest(test.TestCase):
v1 = vs.get_variable("v", [1])
self.assertEqual(v, v1)
- if context.in_graph_mode():
- # When reuse is False, we fail when variables are already there.
- vs.get_variable("w", [1], reuse=False) # That's ok.
- with self.assertRaises(ValueError):
- vs.get_variable("v", [1], reuse=False) # That fails.
- # When reuse is True, we fail when variables are new.
- vs.get_variable("v", [1], reuse=True) # That's ok.
- with self.assertRaises(ValueError):
- vs.get_variable("u", [1], reuse=True) # That fails.
+ # When reuse is False, we fail when variables are already there.
+ vs.get_variable("w", [1], reuse=False) # That's ok.
+ with self.assertRaises(ValueError):
+ vs.get_variable("v", [1], reuse=False) # That fails.
+ # When reuse is True, we fail when variables are new.
+ vs.get_variable("v", [1], reuse=True) # That's ok.
+ with self.assertRaises(ValueError):
+ vs.get_variable("u", [1], reuse=True) # That fails.
@test_util.run_in_graph_and_eager_modes()
def testNamelessStore(self):
@@ -224,10 +221,12 @@ class VariableScopeTest(test.TestCase):
self.assertAllClose(self.evaluate(losses[1]), 0.4)
self.assertAllClose(self.evaluate(losses[2]), 0.5)
with variable_scope.variable_scope("foo", reuse=True):
- v = variable_scope.get_variable("v",
-                                      []) # "v" is already there, reused
- losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
- self.assertEqual(3, len(losses)) # No new loss added.
+ # reuse=True is for now only supported when eager execution is disabled.
+ if context.in_graph_mode():
+ v = variable_scope.get_variable("v",
+                                        []) # "v" is already there, reused
+ losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(3, len(losses)) # No new loss added.
@test_util.run_in_graph_and_eager_modes()
def testInitializeFromValue(self):
@@ -439,20 +438,20 @@ class VariableScopeTest(test.TestCase):
with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse:
self.assertFalse(jump_no_reuse.reuse)
- @test_util.run_in_graph_and_eager_modes()
def testVarScopeGetOrCreateReuse(self):
- def test_value(value):
- x = constant_op.constant(value)
- with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar",
- reuse=variable_scope.AUTO_REUSE):
- _ = state_ops.assign(variable_scope.get_variable("var", []), x)
- with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar",
- reuse=variable_scope.AUTO_REUSE):
- _ = variable_scope.get_variable("var", [])
- self.assertEqual(value, self.evaluate(x))
- test_value(42.) # Variable is created.
- test_value(13.) # Variable is reused hereafter.
- test_value(17.)
+ with self.test_session():
+ def test_value(value):
+ x = constant_op.constant(value)
+ with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar",
+ reuse=variable_scope.AUTO_REUSE):
+ _ = state_ops.assign(variable_scope.get_variable("var", []), x)
+ with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar",
+ reuse=variable_scope.AUTO_REUSE):
+ _ = variable_scope.get_variable("var", [])
+ self.assertEqual(value, x.eval())
+ test_value(42.) # Variable is created.
+ test_value(13.) # Variable is reused hereafter.
+ test_value(17.)
def testVarOpScope(self):
with self.test_session():
@@ -745,9 +744,10 @@ class VariableScopeTest(test.TestCase):
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
# Check that local variable respects `reuse`.
- with variable_scope.variable_scope(outer, "default", reuse=True):
- self.assertEqual(
- variable_scope.get_local_variable("w", []).name, "outer/w:0")
+ if context.in_graph_mode():
+ with variable_scope.variable_scope(outer, "default", reuse=True):
+ self.assertEqual(
+ variable_scope.get_local_variable("w", []).name, "outer/w:0")
def testGetVarWithDevice(self):
g = ops.Graph()
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 22048a0cef..8c5c639b68 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -259,8 +259,8 @@ class _VariableStore(object):
applying it on a newly created variable will be added to the collection
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
- of variables. In Eager mode, this argument is always forced to be
- tf.AUTO_REUSE.
+ of variables. When eager execution is enabled this argument is always
+ forced to be False.
trainable: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
collections: List of graph collections keys to add the `Variable` to.
@@ -279,7 +279,8 @@ class _VariableStore(object):
use_resource: If False, creates a regular Variable. If True, creates
instead an experimental ResourceVariable which has well-defined
semantics. Defaults to False (will later change to True).
- In Eager mode, this argument is always forced to be true.
+ When eager execution is enabled this argument is always forced to be
+ true.
custom_getter: Callable that takes as a first argument the true getter,
and allows overwriting the internal get_variable method.
The signature of `custom_getter` should match that of this method,
@@ -314,7 +315,7 @@ class _VariableStore(object):
"Passed a custom_getter which is not callable: %s" % custom_getter)
if context.in_eager_mode():
- reuse = AUTO_REUSE
+ reuse = False
use_resource = True
# If a *_ref type is passed in an error would be triggered further down the
@@ -506,7 +507,7 @@ class _VariableStore(object):
"""
if context.in_eager_mode():
raise NotImplementedError("Partitioned variables are not yet supported "
- "in Eager mode.")
+ "when eager execution is enabled.")
initializing_from_value = initializer is not None and isinstance(
initializer, ops.Tensor)
@@ -710,15 +711,6 @@ class _VariableStore(object):
Raises:
ValueError: See documentation of get_variable above.
"""
- # Fast-path for get_variable in eager mode when the variable already
- # exists. Note this skips error validation code, so mismatched shapes and
- # dtypes will be caught when the variable is used instead of when the call
- # to get_variable happens.
- if context.in_eager_mode():
- v = self._vars.get(name, None)
- if v is not None:
- return v
-
# Set to true if initializer is a constant.
initializing_from_value = False
if initializer is not None and not callable(initializer):
@@ -732,6 +724,9 @@ class _VariableStore(object):
if name in self._vars:
# Here we handle the case when returning an existing variable.
if reuse is False:
+ if context.in_eager_mode():
+ raise ValueError(
+ "Trying to recreate existing variable: %s" % self._vars[name])
tb = self._vars[name].op.traceback[::-1]
# Throw away internal tf entries and only take a few lines.
tb = [x for x in tb if "tensorflow/python" not in x[0]][:3]
@@ -875,8 +870,8 @@ class VariableScope(object):
initializer: default initializer passed to get_variable.
regularizer: default regularizer passed to get_variable.
reuse: Boolean, None, or tf.AUTO_REUSE, setting the reuse in
- get_variable. In Eager mode, this argument is always forced to be
- tf.AUTO_REUSE.
+ get_variable. When eager execution is enabled this argument is always
+ forced to be False.
caching_device: string, callable, or None: the caching device passed to
get_variable.
partitioner: callable or `None`: the partitioner passed to `get_variable`.
@@ -885,8 +880,8 @@ class VariableScope(object):
dtype: default type passed to get_variable (defaults to DT_FLOAT).
use_resource: if False, create a normal Variable; if True create an
experimental ResourceVariable with well-defined semantics. Defaults
- to False (will later change to True). In Eager mode, this argument is
- always forced to be True.
+ to False (will later change to True). When eager execution is enabled
+ this argument is always forced to be True.
constraint: An optional projection function to be applied to the variable
after being updated by an `Optimizer` (e.g. used to implement norm
constraints or value constraints for layer weights). The function must
@@ -923,10 +918,10 @@ class VariableScope(object):
if context.in_eager_mode():
if self._caching_device is not None:
raise NotImplementedError("Caching devices is not yet supported "
- "in Eager mode.")
+ "when eager execution is enabled.")
if self._partitioner is not None:
raise NotImplementedError("Partitioned variables are not yet supported "
- "in Eager mode.")
+ "when eager execution is enabled.")
self._reuse = AUTO_REUSE
self._use_resource = True
@@ -989,7 +984,8 @@ class VariableScope(object):
def set_use_resource(self, use_resource):
"""Sets whether to use ResourceVariables for this scope."""
if context.in_eager_mode() and not use_resource:
- raise ValueError("In eager mode, use_resource cannot be set to false.")
+ raise ValueError("When eager execution is enabled, "
+ "use_resource cannot be set to false.")
self._use_resource = use_resource
def set_regularizer(self, regularizer):
@@ -1000,14 +996,14 @@ class VariableScope(object):
"""Set caching_device for this scope."""
if context.in_eager_mode():
raise NotImplementedError("Caching devices are not yet supported "
- "in Eager mode.")
+ "when eager execution is enabled.")
self._caching_device = caching_device
def set_partitioner(self, partitioner):
"""Set partitioner for this scope."""
if partitioner and context.in_eager_mode():
raise NotImplementedError("Partitioned variables are not yet supported "
- "in Eager mode.")
+ "when eager execution is enabled.")
self._partitioner = partitioner
def set_custom_getter(self, custom_getter):
@@ -1062,7 +1058,7 @@ class VariableScope(object):
if use_resource is None:
use_resource = self._use_resource
else:
- reuse = AUTO_REUSE
+ reuse = False
use_resource = True
full_name = self.name + "/" + name if self.name else name
@@ -1108,7 +1104,7 @@ class VariableScope(object):
"""Gets an existing variable with this name or create a new one."""
if context.in_eager_mode():
raise NotImplementedError("Partitioned variables are not yet supported "
- "in Eager mode.")
+ "when eager execution is enabled.")
if initializer is None:
initializer = self._initializer
if regularizer is None:
@@ -1259,8 +1255,8 @@ Args:
must be known.
use_resource: If False, creates a regular Variable. If true, creates an
experimental ResourceVariable instead with well-defined semantics.
- Defaults to False (will later change to True). In Eager mode, this argument
- is always forced to be True.
+ Defaults to False (will later change to True). When eager execution is
+ enabled this argument is always forced to be True.
custom_getter: Callable that takes as a first argument the true getter, and
allows overwriting the internal get_variable method.
The signature of `custom_getter` should match that of this method,
@@ -1721,14 +1717,14 @@ class variable_scope(object): # pylint: disable=invalid-name
reuse: `True`, None, or tf.AUTO_REUSE; if `True`, we go into reuse mode
for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create
variables if they do not exist, and return them otherwise; if None, we
- inherit the parent scope's reuse flag. In Eager mode, this argument is
- always forced to be tf.AUTO_REUSE.
+ inherit the parent scope's reuse flag. When eager execution is enabled,
+ this argument is always forced to be tf.AUTO_REUSE.
dtype: type of variables created in this scope (defaults to the type
in the passed scope, or inherited from parent scope).
use_resource: If False, all variables will be regular Variables. If True,
experimental ResourceVariables with well-defined semantics will be used
- instead. Defaults to False (will later change to True). In Eager mode,
- this argument is always forced to be True.
+ instead. Defaults to False (will later change to True). When eager
+ execution is enabled this argument is always forced to be True.
constraint: An optional projection function to be applied to the variable
after being updated by an `Optimizer` (e.g. used to implement norm
constraints or value constraints for layer weights). The function must
@@ -1935,8 +1931,8 @@ def variable(initial_value=None,
caching_device=caching_device, name=name, dtype=dtype)
elif not use_resource and context.in_eager_mode():
raise RuntimeError(
- "VariableScope should use resource variable in Eager mode, but "
- "use_resource is False."
+ "VariableScope should use resource variable when eager execution is"
+ " enabled, but use_resource is False."
)
else:
return variables.Variable(