Diffstat (limited to 'tensorflow/python/ops/variables.py')
-rw-r--r--  tensorflow/python/ops/variables.py  805
1 file changed, 677 insertions(+), 128 deletions(-)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index d3172838a4..fc00ce68ae 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -17,6 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import enum # pylint: disable=g-bad-import-order
+
+import six
+
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.eager import context
@@ -36,8 +40,101 @@ from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
+def default_variable_creator(_, **kwds):
+ del kwds
+ raise NotImplementedError("variable_scope needs to be imported")
+
+
+def _make_getter(captured_getter, captured_previous):
+ """To avoid capturing loop variables."""
+ def getter(**kwargs):
+ return captured_getter(captured_previous, **kwargs)
+ return getter
+
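The factory above exists because Python closures bind loop variables late; a minimal, plain-Python sketch of the pitfall it avoids (not TensorFlow-specific):

```python
# Closures created in a loop all observe the loop variable's final value
# unless the value is bound per iteration, as _make_getter does above.
buggy = [lambda: i for i in range(3)]
print([f() for f in buggy])    # [2, 2, 2]

def make(i):
  return lambda: i             # i is bound per call

fixed = [make(i) for i in range(3)]
print([f() for f in fixed])    # [0, 1, 2]
```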
+
+@tf_export("VariableSynchronization")
+class VariableSynchronization(enum.Enum):
+ """Indicates when a distributed variable will be synced."""
+
+ # Indicates that the synchronization will be determined by the current
+  # `DistributionStrategy` (e.g., with `MirroredStrategy` this would be
+ # `ON_WRITE`).
+ AUTO = 0
+
+ # Indicates that there will only be one copy of the variable, so there is no
+ # need to sync.
+ NONE = 1
+
+ # Indicates that the variable will be aggregated across devices
+ # every time it is updated.
+ ON_WRITE = 2
+
+ # Indicates that the variable will be aggregated across devices
+  # when it is read (e.g., when checkpointing or when evaluating an op that
+  # uses the variable).
+ ON_READ = 3
+
+
+@tf_export("VariableAggregation")
+class VariableAggregation(enum.Enum):
+ """Indicates how a distributed variable will be aggregated."""
+ NONE = 0
+ SUM = 1
+ MEAN = 2
+
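For context, a hedged sketch of passing these hints when constructing a variable; outside a `DistributionStrategy` they are effectively no-ops, and `ON_READ` variables must not be trainable:

```python
import tensorflow as tf

# Minimal sketch, assuming the constructor signature added in this change.
v = tf.Variable(
    0.0,
    trainable=False,  # required when synchronization is ON_READ
    use_resource=True,
    synchronization=tf.VariableSynchronization.ON_READ,
    aggregation=tf.VariableAggregation.MEAN)
```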
+
+class VariableMetaclass(type):
+ """Metaclass to allow construction of tf.Variable to be overridden."""
+
+ def _variable_call(cls,
+ initial_value=None,
+ trainable=None,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
+ """Call on Variable class. Useful to force the signature."""
+ previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
+ for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
+ previous_getter = _make_getter(getter, previous_getter)
+
+    # Map an explicit `aggregation=None` to the enum NONE.
+ if aggregation is None:
+ aggregation = VariableAggregation.NONE
+ return previous_getter(
+ initial_value=initial_value,
+ trainable=trainable,
+ collections=collections,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ variable_def=variable_def,
+ dtype=dtype,
+ expected_shape=expected_shape,
+ import_scope=import_scope,
+ constraint=constraint,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+ def __call__(cls, *args, **kwargs):
+ if cls is Variable:
+ return cls._variable_call(*args, **kwargs)
+ else:
+ return super(VariableMetaclass, cls).__call__(*args, **kwargs)
+
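The metaclass chains every creator on the graph's `_variable_creator_stack` in front of `default_variable_creator`, so each creator sees the next one plus the constructor's keyword arguments. A hedged sketch of that creator signature (registration via `tf.variable_creator_scope` is assumed and shown commented out, since it may not be public at this revision):

```python
def logging_creator(next_creator, **kwargs):
  # Inspect or rewrite kwargs, then delegate to the rest of the chain.
  print("creating variable:", kwargs.get("name"))
  return next_creator(**kwargs)

# with tf.variable_creator_scope(logging_creator):
#   v = tf.Variable(1.0)   # routed through logging_creator first
```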
+
@tf_export("Variable")
-class Variable(checkpointable.CheckpointableBase):
+class Variable(six.with_metaclass(VariableMetaclass,
+ checkpointable.CheckpointableBase)):
"""See the @{$variables$Variables How To} for a high level overview.
A variable maintains state in the graph across calls to `run()`. You add a
@@ -123,37 +220,33 @@ class Variable(checkpointable.CheckpointableBase):
various `Optimizer` classes use this collection as the default list of
variables to optimize.
- WARNING: tf.Variable objects have a non-intuitive memory model. A Variable is
- represented internally as a mutable Tensor which can non-deterministically
- alias other Tensors in a graph. The set of operations which consume a Variable
- and can lead to aliasing is undetermined and can change across TensorFlow
- versions. Avoid writing code which relies on the value of a Variable either
- changing or not changing as other operations happen. For example, using
- Variable objects or simple functions thereof as predicates in a `tf.cond` is
- dangerous and error-prone:
+ WARNING: tf.Variable objects by default have a non-intuitive memory model. A
+ Variable is represented internally as a mutable Tensor which can
+ non-deterministically alias other Tensors in a graph. The set of operations
+ which consume a Variable and can lead to aliasing is undetermined and can
+ change across TensorFlow versions. Avoid writing code which relies on the
+ value of a Variable either changing or not changing as other operations
+ happen. For example, using Variable objects or simple functions thereof as
+ predicates in a `tf.cond` is dangerous and error-prone:
```
v = tf.Variable(True)
tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken.
```
- Here replacing tf.Variable with tf.contrib.eager.Variable will fix any
- nondeterminism issues.
+  Here adding `use_resource=True` when constructing the variable will fix any
+  nondeterminism issues:
+ ```
+ v = tf.Variable(True, use_resource=True)
+ tf.cond(v, lambda: v.assign(False), my_false_fn)
+ ```
To use the replacement for variables which do
not have these issues:
- * Replace `tf.Variable` with `tf.contrib.eager.Variable`;
+ * Add `use_resource=True` when constructing `tf.Variable`;
* Call `tf.get_variable_scope().set_use_resource(True)` inside a
`tf.variable_scope` before the `tf.get_variable()` call.
-
- @compatibility(eager)
- `tf.Variable` is not compatible with eager execution. Use
- `tf.contrib.eager.Variable` instead which is compatible with both eager
- execution and graph construction. See [the TensorFlow Eager Execution
- guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
- for details on how variables work in eager execution.
- @end_compatibility
"""
def __init__(self,
@@ -167,7 +260,10 @@ class Variable(checkpointable.CheckpointableBase):
dtype=None,
expected_shape=None,
import_scope=None,
- constraint=None):
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Creates a new variable with value `initial_value`.
The new variable is added to the graph collections listed in `collections`,
@@ -219,25 +315,565 @@ class Variable(checkpointable.CheckpointableBase):
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
+ use_resource: if True, a ResourceVariable is created; otherwise an
+        old-style ref-based variable is created. When eager execution is enabled,
+ a resource variable is always created.
+      synchronization: Indicates when a distributed variable will be
+        synchronized. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
Raises:
ValueError: If both `variable_def` and initial_value are specified.
ValueError: If the initial value is not specified, or does not have a
shape and `validate_shape` is `True`.
RuntimeError: If eager execution is enabled.
+ """
+ raise NotImplementedError
+
+ def __repr__(self):
+ raise NotImplementedError
+
+ def value(self):
+ """Returns the last snapshot of this variable.
+
+ You usually do not need to call this method as all ops that need the value
+ of the variable call it automatically through a `convert_to_tensor()` call.
+
+ Returns a `Tensor` which holds the value of the variable. You can not
+ assign a new value to this tensor as it is not a reference to the variable.
+
+ To avoid copies, if the consumer of the returned value is on the same device
+ as the variable, this actually returns the live value of the variable, not
+ a copy. Updates to the variable are seen by the consumer. If the consumer
+ is on a different device it will get a copy of the variable.
- @compatibility(eager)
- `tf.Variable` is not compatible with eager execution. Use
- `tfe.Variable` instead which is compatible with both eager execution
- and graph construction. See [the TensorFlow Eager Execution
- guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
- for details on how variables work in eager execution.
- @end_compatibility
+ Returns:
+ A `Tensor` containing the value of the variable.
+ """
+ raise NotImplementedError
+
+ def read_value(self):
+ """Returns the value of this variable, read in the current context.
+
+ Can be different from value() if it's on another device, with control
+ dependencies, etc.
+
+ Returns:
+ A `Tensor` containing the value of the variable.
+ """
+ raise NotImplementedError
+
+ def set_shape(self, shape):
+ """Overrides the shape for this variable.
+
+ Args:
+ shape: the `TensorShape` representing the overridden shape.
+ """
+ raise NotImplementedError
+
+ @property
+ def trainable(self):
+ raise NotImplementedError
+
+ def eval(self, session=None):
+ """In a session, computes and returns the value of this variable.
+
+ This is not a graph construction method, it does not add ops to the graph.
+
+ This convenience method requires a session where the graph
+ containing this variable has been launched. If no session is
+ passed, the default session is used. See @{tf.Session} for more
+ information on launching a graph and on sessions.
+
+ ```python
+ v = tf.Variable([1, 2])
+ init = tf.global_variables_initializer()
+
+ with tf.Session() as sess:
+ sess.run(init)
+ # Usage passing the session explicitly.
+ print(v.eval(sess))
+ # Usage with the default session. The 'with' block
+ # above makes 'sess' the default session.
+ print(v.eval())
+ ```
+
+ Args:
+ session: The session to use to evaluate this variable. If
+ none, the default session is used.
+
+ Returns:
+ A numpy `ndarray` with a copy of the value of this variable.
+ """
+ raise NotImplementedError
+
+ def initialized_value(self):
+ """Returns the value of the initialized variable.
+
+ You should use this instead of the variable itself to initialize another
+ variable with a value that depends on the value of this variable.
+
+ ```python
+ # Initialize 'v' with a random tensor.
+ v = tf.Variable(tf.truncated_normal([10, 40]))
+ # Use `initialized_value` to guarantee that `v` has been
+ # initialized before its value is used to initialize `w`.
+ # The random values are picked only once.
+ w = tf.Variable(v.initialized_value() * 2.0)
+ ```
+
+ Returns:
+ A `Tensor` holding the value of this variable after its initializer
+ has run.
+ """
+ raise NotImplementedError
+
+ @property
+ def initial_value(self):
+ """Returns the Tensor used as the initial value for the variable.
+
+ Note that this is different from `initialized_value()` which runs
+ the op that initializes the variable before returning its value.
+ This method returns the tensor that is used by the op that initializes
+ the variable.
+
+ Returns:
+ A `Tensor`.
+ """
+ raise NotImplementedError
+
+ @property
+ def constraint(self):
+ """Returns the constraint function associated with this variable.
+
+ Returns:
+ The constraint function that was passed to the variable constructor.
+ Can be `None` if no constraint was passed.
+ """
+ raise NotImplementedError
+
+ def assign(self, value, use_locking=False):
+ """Assigns a new value to the variable.
+
+ This is essentially a shortcut for `assign(self, value)`.
+
+ Args:
+ value: A `Tensor`. The new value for this variable.
+ use_locking: If `True`, use locking during the assignment.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the assignment has completed.
+ """
+ raise NotImplementedError
+
+ def assign_add(self, delta, use_locking=False):
+ """Adds a value to this variable.
+
+ This is essentially a shortcut for `assign_add(self, delta)`.
+
+ Args:
+ delta: A `Tensor`. The value to add to this variable.
+ use_locking: If `True`, use locking during the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the addition has completed.
+ """
+ raise NotImplementedError
+
+ def assign_sub(self, delta, use_locking=False):
+ """Subtracts a value from this variable.
+
+ This is essentially a shortcut for `assign_sub(self, delta)`.
+
+ Args:
+ delta: A `Tensor`. The value to subtract from this variable.
+ use_locking: If `True`, use locking during the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the subtraction has completed.
+ """
+ raise NotImplementedError
+
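A minimal graph-mode sketch of the assignment methods; the returned tensors carry the new value once run:

```python
import tensorflow as tf

v = tf.Variable(10)
add_op = v.assign_add(5)
sub_op = v.assign_sub(3)

with tf.Session() as sess:
  sess.run(v.initializer)
  print(sess.run(add_op))   # 15
  print(sess.run(sub_op))   # 12
```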
+ def scatter_sub(self, sparse_delta, use_locking=False):
+ """Subtracts `IndexedSlices` from this variable.
+
+ This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices,
+ sparse_delta.values)`.
+
+ Args:
+ sparse_delta: `IndexedSlices` to be subtracted from this variable.
+ use_locking: If `True`, use locking during the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered subtraction has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ raise NotImplementedError
+
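A minimal sketch of a sparse update via `scatter_sub`, subtracting values at the given indices:

```python
import tensorflow as tf

v = tf.Variable([10.0, 20.0, 30.0])
delta = tf.IndexedSlices(values=tf.constant([1.0, 2.0]),
                         indices=tf.constant([0, 2]))
update = v.scatter_sub(delta)

with tf.Session() as sess:
  sess.run(v.initializer)
  print(sess.run(update))   # [ 9. 20. 28.]
```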
+ def count_up_to(self, limit):
+ """Increments this variable until it reaches `limit`.
+
+ When that Op is run it tries to increment the variable by `1`. If
+ incrementing the variable would bring it above `limit` then the Op raises
+ the exception `OutOfRangeError`.
+
+ If no error is raised, the Op outputs the value of the variable before
+ the increment.
+
+ This is essentially a shortcut for `count_up_to(self, limit)`.
+
+ Args:
+ limit: value at which incrementing the variable raises an error.
+
+ Returns:
+ A `Tensor` that will hold the variable value before the increment. If no
+ other Op modifies this variable, the values produced will all be
+ distinct.
+ """
+ raise NotImplementedError
+
+ def load(self, value, session=None):
+ """Load new value into this variable.
+
+ Writes new value to variable's memory. Doesn't add ops to the graph.
+
+ This convenience method requires a session where the graph
+ containing this variable has been launched. If no session is
+ passed, the default session is used. See @{tf.Session} for more
+ information on launching a graph and on sessions.
+
+ ```python
+ v = tf.Variable([1, 2])
+ init = tf.global_variables_initializer()
+
+ with tf.Session() as sess:
+ sess.run(init)
+ # Usage passing the session explicitly.
+ v.load([2, 3], sess)
+ print(v.eval(sess)) # prints [2 3]
+ # Usage with the default session. The 'with' block
+ # above makes 'sess' the default session.
+ v.load([3, 4], sess)
+ print(v.eval()) # prints [3 4]
+ ```
+
+ Args:
+ value: New variable value
+ session: The session to use to evaluate this variable. If
+ none, the default session is used.
+
+ Raises:
+      ValueError: Session is not passed and no default session is available.
+ """
+ raise NotImplementedError
+
+ # Conversion to tensor.
+ @staticmethod
+ def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name
+ """Utility function for converting a Variable to a Tensor."""
+ _ = name
+ if dtype and not dtype.is_compatible_with(v.dtype):
+ raise ValueError(
+ "Incompatible type conversion requested to type '%s' for variable "
+ "of type '%s'" % (dtype.name, v.dtype.name))
+ if as_ref:
+ return v._ref() # pylint: disable=protected-access
+ else:
+ return v.value()
+
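A minimal sketch of the conversion behaviour this function implements: tensor-consuming ops accept a `Variable` directly, and an explicit conversion to an incompatible dtype raises `ValueError`:

```python
import tensorflow as tf

v = tf.Variable([1.0, 2.0])
t = tf.convert_to_tensor(v)    # routed through _TensorConversionFunction
print(t.dtype)                 # float32

try:
  tf.convert_to_tensor(v, dtype=tf.int32)
except ValueError as err:
  print("incompatible conversion:", err)
```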
+ @staticmethod
+ def _OverloadAllOperators(): # pylint: disable=invalid-name
+ """Register overloads for all operators."""
+ for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
+ Variable._OverloadOperator(operator)
+ # For slicing, bind getitem differently than a tensor (use SliceHelperVar
+ # instead)
+ # pylint: disable=protected-access
+ setattr(Variable, "__getitem__", array_ops._SliceHelperVar)
+
+ @staticmethod
+ def _OverloadOperator(operator): # pylint: disable=invalid-name
+ """Defer an operator overload to `ops.Tensor`.
+
+ We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
+
+ Args:
+ operator: string. The operator name.
+ """
+
+ def _run_op(a, *args):
+ # pylint: disable=protected-access
+ return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
+ # Propagate __doc__ to wrapper
+ try:
+ _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__
+ except AttributeError:
+ pass
+
+ setattr(Variable, operator, _run_op)
+
+ # NOTE(mrry): This enables the Variable's overloaded "right" binary
+ # operators to run when the left operand is an ndarray, because it
+ # accords the Variable class higher priority than an ndarray, or a
+ # numpy matrix.
+ # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
+ # mechanism, which allows more control over how Variables interact
+ # with ndarrays.
+ __array_priority__ = 100
+
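A minimal sketch of the effect: arithmetic on a `Variable` builds ops and returns tensors, and `__array_priority__` makes the reflected operators win even with an ndarray on the left:

```python
import numpy as np
import tensorflow as tf

v = tf.Variable([1.0, 2.0])
t1 = v * 2.0                     # Tensor, via the overloaded __mul__
t2 = np.array([3.0, 4.0]) + v    # Tensor (not ndarray), via __radd__
print(type(t1), type(t2))
```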
+ @property
+ def name(self):
+ """The name of this variable."""
+ raise NotImplementedError
+
+ @property
+ def initializer(self):
+ """The initializer operation for this variable."""
+ raise NotImplementedError
+
+ @property
+ def device(self):
+ """The device of this variable."""
+ raise NotImplementedError
+
+ @property
+ def dtype(self):
+ """The `DType` of this variable."""
+ raise NotImplementedError
+
+ @property
+ def op(self):
+ """The `Operation` of this variable."""
+ raise NotImplementedError
+
+ @property
+ def graph(self):
+ """The `Graph` of this variable."""
+ raise NotImplementedError
+
+ @property
+ def shape(self):
+ """The `TensorShape` of this variable.
+
+ Returns:
+ A `TensorShape`.
+ """
+ raise NotImplementedError
+
+ def get_shape(self):
+ """Alias of Variable.shape."""
+ raise NotImplementedError
+
+ def to_proto(self, export_scope=None):
+ """Converts a `Variable` to a `VariableDef` protocol buffer.
+
+ Args:
+ export_scope: Optional `string`. Name scope to remove.
+
+ Returns:
+ A `VariableDef` protocol buffer, or `None` if the `Variable` is not
+ in the specified name scope.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def from_proto(variable_def, import_scope=None):
+ """Returns a `Variable` object created from `variable_def`."""
+ return RefVariable(variable_def=variable_def,
+ import_scope=import_scope)
+
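A hedged round-trip sketch, assuming graph mode and a ref-based variable (the default here), where `to_proto`/`from_proto` reference the same graph nodes:

```python
import tensorflow as tf

v = tf.Variable([1.0, 2.0], name="v")
variable_def = v.to_proto()                      # VariableDef proto
v_again = tf.Variable.from_proto(variable_def)   # points at the same nodes
print(v_again.name)   # "v:0"
```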
+ class SaveSliceInfo(object):
+ """Information on how to save this Variable as a slice.
+
+ Provides internal support for saving variables as slices of a larger
+ variable. This API is not public and is subject to change.
+
+ Available properties:
+
+ * full_name
+ * full_shape
+ * var_offset
+ * var_shape
+ """
+
+ def __init__(self,
+ full_name=None,
+ full_shape=None,
+ var_offset=None,
+ var_shape=None,
+ save_slice_info_def=None,
+ import_scope=None):
+ """Create a `SaveSliceInfo`.
+
+ Args:
+ full_name: Name of the full variable of which this `Variable` is a
+ slice.
+ full_shape: Shape of the full variable, as a list of int.
+ var_offset: Offset of this `Variable` into the full variable, as a
+ list of int.
+ var_shape: Shape of this `Variable`, as a list of int.
+ save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`,
+          recreates the SaveSliceInfo object from its contents.
+ `save_slice_info_def` and other arguments are mutually
+ exclusive.
+ import_scope: Optional `string`. Name scope to add. Only used
+ when initializing from protocol buffer.
+ """
+ if save_slice_info_def:
+ assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
+ self.full_name = ops.prepend_name_scope(
+ save_slice_info_def.full_name, import_scope=import_scope)
+ self.full_shape = [i for i in save_slice_info_def.full_shape]
+ self.var_offset = [i for i in save_slice_info_def.var_offset]
+ self.var_shape = [i for i in save_slice_info_def.var_shape]
+ else:
+ self.full_name = full_name
+ self.full_shape = full_shape
+ self.var_offset = var_offset
+ self.var_shape = var_shape
+
+ @property
+ def spec(self):
+ """Computes the spec string used for saving."""
+ full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " "
+ sl_spec = ":".join([
+ "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape)
+ ])
+ return full_shape_str + sl_spec
+
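A worked example of the spec format (the values are hypothetical, and this API is internal and subject to change): a `[10, 5]` slice at offset `[0, 10]` of a full `[10, 40]` variable.

```python
import tensorflow as tf

info = tf.Variable.SaveSliceInfo(full_name="weights",   # hypothetical name
                                 full_shape=[10, 40],
                                 var_offset=[0, 10],
                                 var_shape=[10, 5])
print(info.spec)   # "10 40 0,10:10,5"
```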
+ def to_proto(self, export_scope=None):
+ """Returns a SaveSliceInfoDef() proto.
+
+ Args:
+ export_scope: Optional `string`. Name scope to remove.
+
+ Returns:
+ A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not
+ in the specified name scope.
+ """
+ if (export_scope is None or
+ self.full_name.startswith(export_scope)):
+ save_slice_info_def = variable_pb2.SaveSliceInfoDef()
+ save_slice_info_def.full_name = ops.strip_name_scope(
+ self.full_name, export_scope)
+ for i in self.full_shape:
+ save_slice_info_def.full_shape.append(i)
+ for i in self.var_offset:
+ save_slice_info_def.var_offset.append(i)
+ for i in self.var_shape:
+ save_slice_info_def.var_shape.append(i)
+ return save_slice_info_def
+ else:
+ return None
+
+ def __iadd__(self, other):
+ raise NotImplementedError
+
+ def __isub__(self, other):
+ raise NotImplementedError
+
+ def __imul__(self, other):
+ raise NotImplementedError
+
+ def __idiv__(self, other):
+ raise NotImplementedError
+
+ def __itruediv__(self, other):
+ raise NotImplementedError
+
+ def __irealdiv__(self, other):
+ raise NotImplementedError
+
+ def __ipow__(self, other):
+ raise NotImplementedError
+
+
+# TODO(apassos): do not repeat all comments here
+class RefVariable(Variable):
+ """Ref-based implementation of variables."""
+
+ def __init__(self,
+ initial_value=None,
+ trainable=True,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None):
+ """Creates a new variable with value `initial_value`.
+
+ The new variable is added to the graph collections listed in `collections`,
+ which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+
+ If `trainable` is `True` the variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`.
+
+ This constructor creates both a `variable` Op and an `assign` Op to set the
+ variable to its initial value.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
+ trainable: If `True`, the default, also adds the variable to the graph
+ collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+ the default list of variables to use by the `Optimizer` classes.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ caching_device: Optional device string describing where the Variable
+ should be cached for reading. Defaults to the Variable's device.
+ If not `None`, caches on another device. Typical use is to cache
+ on the device where the Ops using the Variable reside, to deduplicate
+ copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ variable_def: `VariableDef` protocol buffer. If not `None`, recreates
+ the Variable object with its contents, referencing the variable's nodes
+ in the graph, which must already exist. The graph is not changed.
+ `variable_def` and the other arguments are mutually exclusive.
+ dtype: If set, initial_value will be converted to the given type.
+ If `None`, either the datatype will be kept (if `initial_value` is
+ a Tensor), or `convert_to_tensor` will decide.
+ expected_shape: A TensorShape. If set, initial_value is expected
+ to have this shape.
+ import_scope: Optional `string`. Name scope to add to the
+        `Variable`. Only used when initializing from protocol buffer.
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+
+ Raises:
+ ValueError: If both `variable_def` and initial_value are specified.
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ RuntimeError: If eager execution is enabled.
"""
- if context.executing_eagerly():
- raise RuntimeError(
- "tf.Variable not supported when eager execution is enabled. "
- "Please use tf.contrib.eager.Variable instead")
self._in_graph_mode = True
if variable_def:
# If variable_def is provided, recreates the variable from its fields.
@@ -348,8 +984,7 @@ class Variable(checkpointable.CheckpointableBase):
# Ensure that we weren't lifted into the eager context.
if context.executing_eagerly():
raise RuntimeError(
- "tf.Variable not supported when eager execution is enabled. "
- "Please use tf.contrib.eager.Variable instead")
+          "RefVariable not supported when eager execution is enabled.")
with ops.name_scope(name, "Variable", [] if init_from_fn else
[initial_value]) as name:
@@ -1068,12 +1703,6 @@ class Variable(checkpointable.CheckpointableBase):
else:
return None
- @staticmethod
- def from_proto(variable_def, import_scope=None):
- """Returns a `Variable` object created from `variable_def`."""
- return Variable(variable_def=variable_def,
- import_scope=import_scope)
-
def __iadd__(self, other):
logging.log_first_n(
logging.WARN,
@@ -1130,90 +1759,6 @@ class Variable(checkpointable.CheckpointableBase):
" if you want a new python Tensor object.", 1)
return self ** other
- class SaveSliceInfo(object):
- """Information on how to save this Variable as a slice.
-
- Provides internal support for saving variables as slices of a larger
- variable. This API is not public and is subject to change.
-
- Available properties:
-
- * full_name
- * full_shape
- * var_offset
- * var_shape
- """
-
- def __init__(self,
- full_name=None,
- full_shape=None,
- var_offset=None,
- var_shape=None,
- save_slice_info_def=None,
- import_scope=None):
- """Create a `SaveSliceInfo`.
-
- Args:
- full_name: Name of the full variable of which this `Variable` is a
- slice.
- full_shape: Shape of the full variable, as a list of int.
- var_offset: Offset of this `Variable` into the full variable, as a
- list of int.
- var_shape: Shape of this `Variable`, as a list of int.
- save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`,
- recreates the SaveSliceInfo object its contents.
- `save_slice_info_def` and other arguments are mutually
- exclusive.
- import_scope: Optional `string`. Name scope to add. Only used
- when initializing from protocol buffer.
- """
- if save_slice_info_def:
- assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
- self.full_name = ops.prepend_name_scope(
- save_slice_info_def.full_name, import_scope=import_scope)
- self.full_shape = [i for i in save_slice_info_def.full_shape]
- self.var_offset = [i for i in save_slice_info_def.var_offset]
- self.var_shape = [i for i in save_slice_info_def.var_shape]
- else:
- self.full_name = full_name
- self.full_shape = full_shape
- self.var_offset = var_offset
- self.var_shape = var_shape
-
- @property
- def spec(self):
- """Computes the spec string used for saving."""
- full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " "
- sl_spec = ":".join([
- "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape)
- ])
- return full_shape_str + sl_spec
-
- def to_proto(self, export_scope=None):
- """Returns a SaveSliceInfoDef() proto.
-
- Args:
- export_scope: Optional `string`. Name scope to remove.
-
- Returns:
- A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not
- in the specified name scope.
- """
- if (export_scope is None or
- self.full_name.startswith(export_scope)):
- save_slice_info_def = variable_pb2.SaveSliceInfoDef()
- save_slice_info_def.full_name = ops.strip_name_scope(
- self.full_name, export_scope)
- for i in self.full_shape:
- save_slice_info_def.full_shape.append(i)
- for i in self.var_offset:
- save_slice_info_def.var_offset.append(i)
- for i in self.var_shape:
- save_slice_info_def.var_shape.append(i)
- return save_slice_info_def
- else:
- return None
-
def _set_save_slice_info(self, save_slice_info):
"""Sets the slice info for this `Variable`.
@@ -1230,7 +1775,7 @@ class PartitionedVariable(object):
"""A container for partitioned `Variable` objects.
@compatibility(eager) `tf.PartitionedVariable` is not compatible with
- eager execution. Use `tfe.Variable` instead which is compatible
+ eager execution. Use `tf.Variable` instead which is compatible
with both eager execution and graph construction. See [the
TensorFlow Eager Execution
guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
@@ -1404,6 +1949,10 @@ class PartitionedVariable(object):
def dtype(self):
return self._dtype
+ @property
+ def shape(self):
+ return self.get_shape()
+
def get_shape(self):
return self._shape
@@ -1723,6 +2272,8 @@ def report_uninitialized_variables(var_list=None,
var_list.append(op.outputs[0])
with ops.name_scope(name):
# Run all operations on CPU
+ if var_list:
+ init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
with ops.device("/cpu:0"):
if not var_list:
# Return an empty tensor so we only need to check for returned tensor
@@ -1730,9 +2281,7 @@ def report_uninitialized_variables(var_list=None,
return array_ops.constant([], dtype=dtypes.string)
else:
# Get a 1-D boolean tensor listing whether each variable is initialized.
- variables_mask = math_ops.logical_not(
- array_ops.stack(
- [state_ops.is_variable_initialized(v) for v in var_list]))
+ variables_mask = math_ops.logical_not(array_ops.stack(init_vars))
# Get a 1-D string tensor containing all the variable names.
variable_names_tensor = array_ops.constant(
[s.op.name for s in var_list])
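For reference, a minimal usage sketch of `tf.report_uninitialized_variables`:

```python
import tensorflow as tf

a = tf.Variable(1.0, name="a")
b = tf.Variable(2.0, name="b")
report = tf.report_uninitialized_variables()

with tf.Session() as sess:
  print(sess.run(report))   # [b'a' b'b']  -- nothing initialized yet
  sess.run(a.initializer)
  print(sess.run(report))   # [b'b']
```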