aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-09-27 13:18:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 13:23:04 -0700
commit4cedc8b6e738b7a188c9c091cf667bacafae44b7 (patch)
tree56de35940e5f9daedd5f39a82d2cd90cf374e4e4 /tensorflow/python/ops
parentc898e63d07fc63315be98f0772736e5d7f2fb44c (diff)
Updating the V2 variables API.
PiperOrigin-RevId: 214824023
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/gradients_test.py2
-rw-r--r--tensorflow/python/ops/matmul_benchmark.py8
-rw-r--r--tensorflow/python/ops/variable_scope.py117
-rw-r--r--tensorflow/python/ops/variables.py323
4 files changed, 386 insertions, 64 deletions
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 4f6e5dc473..3c9b7a01c7 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -273,7 +273,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
def testVariableRefGradient(self):
with ops.Graph().as_default():
init = constant_op.constant(100.0)
- var = variables.Variable(init)
+ var = variables.VariableV1(init)
gradient = gradients.gradients(var._ref(), var)
self.assertIsNotNone(gradient)
diff --git a/tensorflow/python/ops/matmul_benchmark.py b/tensorflow/python/ops/matmul_benchmark.py
index 6e5fe74290..138149e63d 100644
--- a/tensorflow/python/ops/matmul_benchmark.py
+++ b/tensorflow/python/ops/matmul_benchmark.py
@@ -49,13 +49,13 @@ def build_graph(device, n, m, k, transpose_a, transpose_b, dtype):
"""
with ops.device('%s' % device):
if not transpose_a:
- x = variables.Variable(random_ops.random_uniform([n, m], dtype=dtype))
+ x = variables.VariableV1(random_ops.random_uniform([n, m], dtype=dtype))
else:
- x = variables.Variable(random_ops.random_uniform([m, n], dtype=dtype))
+ x = variables.VariableV1(random_ops.random_uniform([m, n], dtype=dtype))
if not transpose_b:
- y = variables.Variable(random_ops.random_uniform([m, k], dtype=dtype))
+ y = variables.VariableV1(random_ops.random_uniform([m, k], dtype=dtype))
else:
- y = variables.Variable(random_ops.random_uniform([k, m], dtype=dtype))
+ y = variables.VariableV1(random_ops.random_uniform([k, m], dtype=dtype))
z = math_ops.matmul(x, y, transpose_a=transpose_a, transpose_b=transpose_b)
return control_flow_ops.group(z)
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 562e1ad6cb..af5c7d4050 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -198,7 +198,7 @@ VariableSynchronization = variables.VariableSynchronization # pylint: disable=i
VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name
AUTO_REUSE = _ReuseMode.AUTO_REUSE
-tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
+tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
get_variable() should create the requested variable if it doesn't exist or, if
@@ -908,7 +908,7 @@ class _VariableStore(object):
if use_resource is None:
# Set the default value if unspecified.
use_resource = _DEFAULT_USE_RESOURCE
- v = variable(
+ v = variables.VariableV1(
initial_value=init_val,
name=name,
trainable=trainable,
@@ -994,7 +994,7 @@ def no_regularizer(_):
# TODO(alive): support caching devices and partitioned variables in Eager mode.
-@tf_export("VariableScope")
+@tf_export(v1=["VariableScope"])
class VariableScope(object):
"""Variable scope object to carry defaults to provide to `get_variable`.
@@ -1342,7 +1342,7 @@ def get_variable_scope_store():
return scope_store
-@tf_export("get_variable_scope")
+@tf_export(v1=["get_variable_scope"])
def get_variable_scope():
"""Returns the current variable scope."""
return get_variable_scope_store().current_scope
@@ -1451,7 +1451,7 @@ class EagerVariableStore(object):
# The argument list for get_variable must match arguments to get_local_variable.
# So, if you are updating the arguments, also update arguments to
# get_local_variable below.
-@tf_export("get_variable")
+@tf_export(v1=["get_variable"])
def get_variable(name,
shape=None,
dtype=None,
@@ -1596,7 +1596,7 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
-@tf_export("get_local_variable")
+@tf_export(v1=["get_local_variable"])
def get_local_variable( # pylint: disable=missing-docstring
name,
shape=None,
@@ -1941,7 +1941,7 @@ def _get_unique_variable_scope(prefix):
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
-@tf_export("variable_scope") # pylint: disable=invalid-name
+@tf_export(v1=["variable_scope"]) # pylint: disable=invalid-name
class variable_scope(object):
"""A context manager for defining ops that creates variables (layers).
@@ -2322,7 +2322,7 @@ class variable_scope(object):
# pylint: disable=g-doc-return-or-yield
-@tf_export("variable_op_scope")
+@tf_export(v1=["variable_op_scope"])
@tf_contextlib.contextmanager
def variable_op_scope(values,
name_or_scope,
@@ -2443,7 +2443,33 @@ def default_variable_creator(next_creator=None, **kwargs):
expected_shape=expected_shape, import_scope=import_scope)
+def default_variable_creator_v2(next_creator=None, **kwargs):
+ """Default variable creator."""
+ assert next_creator is None
+ initial_value = kwargs.get("initial_value", None)
+ trainable = kwargs.get("trainable", None)
+ validate_shape = kwargs.get("validate_shape", True)
+ caching_device = kwargs.get("caching_device", None)
+ name = kwargs.get("name", None)
+ variable_def = kwargs.get("variable_def", None)
+ dtype = kwargs.get("dtype", None)
+ import_scope = kwargs.get("import_scope", None)
+ constraint = kwargs.get("constraint", None)
+
+ # Set trainable value based on synchronization value.
+ synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO)
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
+
+ return resource_variable_ops.ResourceVariable(
+ initial_value=initial_value, trainable=trainable,
+ validate_shape=validate_shape, caching_device=caching_device,
+ name=name, dtype=dtype, constraint=constraint, variable_def=variable_def,
+ import_scope=import_scope)
+
+
variables.default_variable_creator = default_variable_creator
+variables.default_variable_creator_v2 = default_variable_creator_v2
def _make_getter(captured_getter, captured_previous):
@@ -2452,11 +2478,12 @@ def _make_getter(captured_getter, captured_previous):
# TODO(apassos) remove forwarding symbol
-variable = variables.Variable
+variable = variables.VariableV1
+@tf_export(v1=["variable_creator_scope"])
@tf_contextlib.contextmanager
-def variable_creator_scope(variable_creator):
+def variable_creator_scope_v1(variable_creator):
"""Scope which defines a variable creation function to be used by variable().
variable_creator is expected to be a function with the following signature:
@@ -2527,3 +2554,73 @@ def variable_creator_scope(variable_creator):
"""
with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
yield
+
+
+# Note: only the docstrings differ between this and v1.
+@tf_export(v2=["variable_creator_scope"])
+@tf_contextlib.contextmanager
+def variable_creator_scope(variable_creator):
+ """Scope which defines a variable creation function to be used by variable().
+
+ variable_creator is expected to be a function with the following signature:
+
+ ```
+ def variable_creator(next_creator, **kwargs)
+ ```
+
+ The creator is supposed to eventually call the next_creator to create a
+ variable if it does want to create a variable and not call Variable or
+ ResourceVariable directly. This helps make creators composable. A creator may
+ choose to create multiple variables, return already existing variables, or
+ simply register that a variable was created and defer to the next creators in
+ line. Creators can also modify the keyword arguments seen by the next
+ creators.
+
+ Custom getters in the variable scope will eventually resolve down to these
+ custom creators when they do create variables.
+
+ The valid keyword arguments in kwds are:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
+ trainable: If `True`, the default, GradientTapes automatically watch
+ uses of this Variable.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ caching_device: Optional device string describing where the Variable
+ should be cached for reading. Defaults to the Variable's device.
+ If not `None`, caches on another device. Typical use is to cache
+ on the device where the Ops using the Variable reside, to deduplicate
+ copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If `None`, either the datatype will be kept (if `initial_value` is
+ a Tensor), or `convert_to_tensor` will decide.
+ constraint: A constraint function to be applied to the variable after
+ updates by some algorithms.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ `tf.VariableSynchronization`. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ `tf.VariableAggregation`.
+
+ This set may grow over time, so it's important the signature of creators is as
+ mentioned above.
+
+ Args:
+ variable_creator: the passed creator
+
+ Yields:
+ A scope in which the creator is active
+ """
+ with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
+ yield
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 7a46157739..8da1e9fe56 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -46,6 +46,11 @@ def default_variable_creator(_, **kwds):
raise NotImplementedError("variable_scope needs to be imported")
+def default_variable_creator_v2(_, **kwds):
+ del kwds
+ raise NotImplementedError("variable_scope needs to be imported")
+
+
def _make_getter(captured_getter, captured_previous):
"""To avoid capturing loop variables."""
def getter(**kwargs):
@@ -101,21 +106,21 @@ class VariableAggregation(enum.Enum):
class VariableMetaclass(type):
"""Metaclass to allow construction of tf.Variable to be overridden."""
- def _variable_call(cls,
- initial_value=None,
- trainable=None,
- collections=None,
- validate_shape=True,
- caching_device=None,
- name=None,
- variable_def=None,
- dtype=None,
- expected_shape=None,
- import_scope=None,
- constraint=None,
- use_resource=None,
- synchronization=VariableSynchronization.AUTO,
- aggregation=VariableAggregation.NONE):
+ def _variable_v1_call(cls,
+ initial_value=None,
+ trainable=None,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Call on Variable class. Useful to force the signature."""
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
@@ -140,14 +145,49 @@ class VariableMetaclass(type):
synchronization=synchronization,
aggregation=aggregation)
+ def _variable_v2_call(cls,
+ initial_value=None,
+ trainable=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ import_scope=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
+ """Call on Variable class. Useful to force the signature."""
+ previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
+ for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
+ previous_getter = _make_getter(getter, previous_getter)
+
+ # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
+ if aggregation is None:
+ aggregation = VariableAggregation.NONE
+ return previous_getter(
+ initial_value=initial_value,
+ trainable=trainable,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ variable_def=variable_def,
+ dtype=dtype,
+ import_scope=import_scope,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
def __call__(cls, *args, **kwargs):
- if cls is Variable:
- return cls._variable_call(*args, **kwargs)
+ if cls is VariableV1:
+ return cls._variable_v1_call(*args, **kwargs)
+ elif cls is Variable:
+ return cls._variable_v2_call(*args, **kwargs)
else:
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
-@tf_export("Variable")
+@tf_export(v2=["Variable"])
class Variable(six.with_metaclass(VariableMetaclass,
checkpointable.CheckpointableBase)):
"""See the [Variables Guide](https://tensorflow.org/guide/variables).
@@ -267,16 +307,13 @@ class Variable(six.with_metaclass(VariableMetaclass,
def __init__(self,
initial_value=None,
trainable=True,
- collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
- expected_shape=None,
import_scope=None,
constraint=None,
- use_resource=None,
synchronization=VariableSynchronization.AUTO,
aggregation=VariableAggregation.NONE):
"""Creates a new variable with value `initial_value`.
@@ -297,11 +334,8 @@ class Variable(six.with_metaclass(VariableMetaclass,
callable with no argument that returns the initial value when called. In
that case, `dtype` must be specified. (Note that initializer functions
from init_ops.py must first be bound to a shape before being used here.)
- trainable: If `True`, the default, also adds the variable to the graph
- collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
- the default list of variables to use by the `Optimizer` classes.
- collections: List of graph collections keys. The new variable is added to
- these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ trainable: If `True`, the default, GradientTapes automatically watch uses
+ of this variable.
validate_shape: If `False`, allows the variable to be initialized with a
value of unknown shape. If `True`, the default, the shape of
`initial_value` must be known.
@@ -319,8 +353,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
dtype: If set, initial_value will be converted to the given type.
If `None`, either the datatype will be kept (if `initial_value` is
a Tensor), or `convert_to_tensor` will decide.
- expected_shape: A TensorShape. If set, initial_value is expected
- to have this shape.
import_scope: Optional `string`. Name scope to add to the
`Variable.` Only used when initializing from protocol buffer.
constraint: An optional projection function to be applied to the variable
@@ -330,9 +362,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
- use_resource: if True, a ResourceVariable is created; otherwise an
- old-style ref-based variable is created. When eager execution is enabled
- a resource variable is always created.
synchronization: Indicates when a distributed a variable will be
aggregated. Accepted values are constants defined in the class
`tf.VariableSynchronization`. By default the synchronization is set to
@@ -1009,11 +1038,207 @@ class Variable(six.with_metaclass(VariableMetaclass,
raise NotImplementedError
+@tf_export(v1=["Variable"])
+class VariableV1(Variable):
+ """See the [Variables Guide](https://tensorflow.org/guide/variables).
+
+ A variable maintains state in the graph across calls to `run()`. You add a
+ variable to the graph by constructing an instance of the class `Variable`.
+
+ The `Variable()` constructor requires an initial value for the variable,
+ which can be a `Tensor` of any type and shape. The initial value defines the
+ type and shape of the variable. After construction, the type and shape of
+ the variable are fixed. The value can be changed using one of the assign
+ methods.
+
+ If you want to change the shape of a variable later you have to use an
+ `assign` Op with `validate_shape=False`.
+
+ Just like any `Tensor`, variables created with `Variable()` can be used as
+ inputs for other Ops in the graph. Additionally, all the operators
+ overloaded for the `Tensor` class are carried over to variables, so you can
+ also add nodes to the graph by just doing arithmetic on variables.
+
+ ```python
+ import tensorflow as tf
+
+ # Create a variable.
+ w = tf.Variable(<initial-value>, name=<optional-name>)
+
+ # Use the variable in the graph like any Tensor.
+ y = tf.matmul(w, ...another variable or tensor...)
+
+ # The overloaded operators are available too.
+ z = tf.sigmoid(w + y)
+
+ # Assign a new value to the variable with `assign()` or a related method.
+ w.assign(w + 1.0)
+ w.assign_add(1.0)
+ ```
+
+ When you launch the graph, variables have to be explicitly initialized before
+ you can run Ops that use their value. You can initialize a variable by
+ running its *initializer op*, restoring the variable from a save file, or
+ simply running an `assign` Op that assigns a value to the variable. In fact,
+ the variable *initializer op* is just an `assign` Op that assigns the
+ variable's initial value to the variable itself.
+
+ ```python
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the variable initializer.
+ sess.run(w.initializer)
+ # ...you now can run ops that use the value of 'w'...
+ ```
+
+ The most common initialization pattern is to use the convenience function
+ `global_variables_initializer()` to add an Op to the graph that initializes
+ all the variables. You then run that Op after launching the graph.
+
+ ```python
+ # Add an Op to initialize global variables.
+ init_op = tf.global_variables_initializer()
+
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the Op that initializes global variables.
+ sess.run(init_op)
+ # ...you can now run any Op that uses variable values...
+ ```
+
+ If you need to create a variable with an initial value dependent on another
+ variable, use the other variable's `initialized_value()`. This ensures that
+ variables are initialized in the right order.
+
+ All variables are automatically collected in the graph where they are
+ created. By default, the constructor adds the new variable to the graph
+ collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
+ `global_variables()` returns the contents of that collection.
+
+ When building a machine learning model it is often convenient to distinguish
+ between variables holding the trainable model parameters and other variables
+ such as a `global step` variable used to count training steps. To make this
+ easier, the variable constructor supports a `trainable=<bool>` parameter. If
+ `True`, the new variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
+ `trainable_variables()` returns the contents of this collection. The
+ various `Optimizer` classes use this collection as the default list of
+ variables to optimize.
+
+ WARNING: tf.Variable objects by default have a non-intuitive memory model. A
+ Variable is represented internally as a mutable Tensor which can
+ non-deterministically alias other Tensors in a graph. The set of operations
+ which consume a Variable and can lead to aliasing is undetermined and can
+ change across TensorFlow versions. Avoid writing code which relies on the
+ value of a Variable either changing or not changing as other operations
+ happen. For example, using Variable objects or simple functions thereof as
+ predicates in a `tf.cond` is dangerous and error-prone:
+
+ ```
+ v = tf.Variable(True)
+ tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken.
+ ```
+
+ Here replacing adding `use_resource=True` when constructing the variable will
+ fix any nondeterminism issues:
+ ```
+ v = tf.Variable(True, use_resource=True)
+ tf.cond(v, lambda: v.assign(False), my_false_fn)
+ ```
+
+ To use the replacement for variables which does
+ not have these issues:
+
+ * Add `use_resource=True` when constructing `tf.Variable`;
+ * Call `tf.get_variable_scope().set_use_resource(True)` inside a
+ `tf.variable_scope` before the `tf.get_variable()` call.
+ """
+
+ def __init__(self, # pylint: disable=super-init-not-called
+ initial_value=None,
+ trainable=True,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
+ """Creates a new variable with value `initial_value`.
+
+ The new variable is added to the graph collections listed in `collections`,
+ which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+
+ If `trainable` is `True` the variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`.
+
+ This constructor creates both a `variable` Op and an `assign` Op to set the
+ variable to its initial value.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
+ trainable: If `True`, the default, also adds the variable to the graph
+ collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+ the default list of variables to use by the `Optimizer` classes.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ caching_device: Optional device string describing where the Variable
+ should be cached for reading. Defaults to the Variable's device.
+ If not `None`, caches on another device. Typical use is to cache
+ on the device where the Ops using the Variable reside, to deduplicate
+ copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ variable_def: `VariableDef` protocol buffer. If not `None`, recreates
+ the Variable object with its contents, referencing the variable's nodes
+ in the graph, which must already exist. The graph is not changed.
+ `variable_def` and the other arguments are mutually exclusive.
+ dtype: If set, initial_value will be converted to the given type.
+ If `None`, either the datatype will be kept (if `initial_value` is
+ a Tensor), or `convert_to_tensor` will decide.
+ expected_shape: A TensorShape. If set, initial_value is expected
+ to have this shape.
+ import_scope: Optional `string`. Name scope to add to the
+ `Variable.` Only used when initializing from protocol buffer.
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+ use_resource: whether to use resource variables.
+ synchronization: unused
+ aggregation: unused
+
+ Raises:
+ ValueError: If both `variable_def` and initial_value are specified.
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ RuntimeError: If eager execution is enabled.
+ """
+
+ SaveSliceInfo = Variable.SaveSliceInfo
+
+
# TODO(apassos): do not repeat all comments here
-class RefVariable(Variable):
+class RefVariable(VariableV1):
"""Ref-based implementation of variables."""
- def __init__(self,
+ def __init__(self, # pylint: disable=super-init-not-called
initial_value=None,
trainable=True,
collections=None,
@@ -1873,7 +2098,7 @@ class RefVariable(Variable):
def _OverloadAllOperators(): # pylint: disable=invalid-name
"""Register overloads for all operators."""
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
- Variable._OverloadOperator(operator)
+ Variable._OverloadOperator(operator) # pylint: disable=protected-access
# For slicing, bind getitem differently than a tensor (use SliceHelperVar
# instead)
# pylint: disable=protected-access
@@ -2401,7 +2626,7 @@ class PartitionedVariable(object):
"assign() has not been implemented for PartitionedVariable.")
-@tf_export("global_variables")
+@tf_export(v1=["global_variables"])
def global_variables(scope=None):
"""Returns global variables.
@@ -2427,7 +2652,7 @@ def global_variables(scope=None):
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
-@tf_export("all_variables")
+@tf_export(v1=["all_variables"])
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
"""See `tf.global_variables`."""
@@ -2452,7 +2677,7 @@ def _all_saveable_objects(scope=None):
ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))
-@tf_export("local_variables")
+@tf_export(v1=["local_variables"])
def local_variables(scope=None):
"""Returns local variables.
@@ -2480,7 +2705,7 @@ def local_variables(scope=None):
return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)
-@tf_export("model_variables")
+@tf_export(v1=["model_variables"])
def model_variables(scope=None):
"""Returns all variables in the MODEL_VARIABLES collection.
@@ -2497,7 +2722,7 @@ def model_variables(scope=None):
return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)
-@tf_export("trainable_variables")
+@tf_export(v1=["trainable_variables"])
def trainable_variables(scope=None):
"""Returns all variables created with `trainable=True`.
@@ -2519,7 +2744,7 @@ def trainable_variables(scope=None):
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)
-@tf_export("moving_average_variables")
+@tf_export(v1=["moving_average_variables"])
def moving_average_variables(scope=None):
"""Returns all variables that maintain their moving averages.
@@ -2541,7 +2766,7 @@ def moving_average_variables(scope=None):
return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)
-@tf_export("initializers.variables", "variables_initializer")
+@tf_export(v1=["initializers.variables", "variables_initializer"])
def variables_initializer(var_list, name="init"):
"""Returns an Op that initializes a list of variables.
@@ -2567,7 +2792,7 @@ def variables_initializer(var_list, name="init"):
return control_flow_ops.no_op(name=name)
-@tf_export("initialize_variables")
+@tf_export(v1=["initialize_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
@@ -2575,7 +2800,7 @@ def initialize_variables(var_list, name="init"):
return variables_initializer(var_list, name=name)
-@tf_export("initializers.global_variables", "global_variables_initializer")
+@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
def global_variables_initializer():
"""Returns an Op that initializes global variables.
@@ -2589,7 +2814,7 @@ def global_variables_initializer():
return variables_initializer(global_variables())
-@tf_export("initialize_all_variables")
+@tf_export(v1=["initialize_all_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
@@ -2597,7 +2822,7 @@ def initialize_all_variables():
return global_variables_initializer()
-@tf_export("initializers.local_variables", "local_variables_initializer")
+@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
def local_variables_initializer():
"""Returns an Op that initializes all local variables.
@@ -2611,7 +2836,7 @@ def local_variables_initializer():
return variables_initializer(local_variables())
-@tf_export("initialize_local_variables")
+@tf_export(v1=["initialize_local_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
@@ -2619,7 +2844,7 @@ def initialize_local_variables():
return local_variables_initializer()
-@tf_export("is_variable_initialized")
+@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
"""Tests if a variable has been initialized.
@@ -2634,7 +2859,7 @@ def is_variable_initialized(variable):
return state_ops.is_variable_initialized(variable)
-@tf_export("assert_variables_initialized")
+@tf_export(v1=["assert_variables_initialized"])
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
"""Returns an Op to check if variables are initialized.
@@ -2677,7 +2902,7 @@ def assert_variables_initialized(var_list=None):
return array_ops.stack(ranks)
-@tf_export("report_uninitialized_variables")
+@tf_export(v1=["report_uninitialized_variables"])
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
name="report_uninitialized_variables"):