Diffstat (limited to 'tensorflow/python/ops/variables.py')
-rw-r--r-- | tensorflow/python/ops/variables.py | 805 |
1 files changed, 677 insertions, 128 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index d3172838a4..fc00ce68ae 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -17,6 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import enum # pylint: disable=g-bad-import-order + +import six + from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 from tensorflow.python.eager import context @@ -36,8 +40,101 @@ from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export +def default_variable_creator(_, **kwds): + del kwds + raise NotImplementedError("variable_scope needs to be imported") + + +def _make_getter(captured_getter, captured_previous): + """To avoid capturing loop variables.""" + def getter(**kwargs): + return captured_getter(captured_previous, **kwargs) + return getter + + +@tf_export("VariableSynchronization") +class VariableSynchronization(enum.Enum): + """Indicates when a distributed variable will be synced.""" + + # Indicates that the synchronization will be determined by the current + # `DistributionStrategy` (eg. With `MirroredStrategy` this would be + # `ON_WRITE`). + AUTO = 0 + + # Indicates that there will only be one copy of the variable, so there is no + # need to sync. + NONE = 1 + + # Indicates that the variable will be aggregated across devices + # every time it is updated. + ON_WRITE = 2 + + # Indicates that the variable will be aggregated across devices + # when it is read (eg. when checkpointing or when evaluating an op that uses + # the variable). + ON_READ = 3 + + +@tf_export("VariableAggregation") +class VariableAggregation(enum.Enum): + """Indicates how a distributed variable will be aggregated.""" + NONE = 0 + SUM = 1 + MEAN = 2 + + +class VariableMetaclass(type): + """Metaclass to allow construction of tf.Variable to be overridden.""" + + def _variable_call(cls, + initial_value=None, + trainable=None, + collections=None, + validate_shape=True, + caching_device=None, + name=None, + variable_def=None, + dtype=None, + expected_shape=None, + import_scope=None, + constraint=None, + use_resource=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): + """Call on Variable class. Useful to force the signature.""" + previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) + for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access + previous_getter = _make_getter(getter, previous_getter) + + # Reset `aggregation` that is explicitly set as `None` to the enum NONE. 
+ if aggregation is None: + aggregation = VariableAggregation.NONE + return previous_getter( + initial_value=initial_value, + trainable=trainable, + collections=collections, + validate_shape=validate_shape, + caching_device=caching_device, + name=name, + variable_def=variable_def, + dtype=dtype, + expected_shape=expected_shape, + import_scope=import_scope, + constraint=constraint, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) + + def __call__(cls, *args, **kwargs): + if cls is Variable: + return cls._variable_call(*args, **kwargs) + else: + return super(VariableMetaclass, cls).__call__(*args, **kwargs) + + @tf_export("Variable") -class Variable(checkpointable.CheckpointableBase): +class Variable(six.with_metaclass(VariableMetaclass, + checkpointable.CheckpointableBase)): """See the @{$variables$Variables How To} for a high level overview. A variable maintains state in the graph across calls to `run()`. You add a @@ -123,37 +220,33 @@ class Variable(checkpointable.CheckpointableBase): various `Optimizer` classes use this collection as the default list of variables to optimize. - WARNING: tf.Variable objects have a non-intuitive memory model. A Variable is - represented internally as a mutable Tensor which can non-deterministically - alias other Tensors in a graph. The set of operations which consume a Variable - and can lead to aliasing is undetermined and can change across TensorFlow - versions. Avoid writing code which relies on the value of a Variable either - changing or not changing as other operations happen. For example, using - Variable objects or simple functions thereof as predicates in a `tf.cond` is - dangerous and error-prone: + WARNING: tf.Variable objects by default have a non-intuitive memory model. A + Variable is represented internally as a mutable Tensor which can + non-deterministically alias other Tensors in a graph. The set of operations + which consume a Variable and can lead to aliasing is undetermined and can + change across TensorFlow versions. Avoid writing code which relies on the + value of a Variable either changing or not changing as other operations + happen. For example, using Variable objects or simple functions thereof as + predicates in a `tf.cond` is dangerous and error-prone: ``` v = tf.Variable(True) tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken. ``` - Here replacing tf.Variable with tf.contrib.eager.Variable will fix any - nondeterminism issues. + Here, adding `use_resource=True` when constructing the variable will + fix any nondeterminism issues: + ``` + v = tf.Variable(True, use_resource=True) + tf.cond(v, lambda: v.assign(False), my_false_fn) + ``` To use the replacement for variables which do not have these issues: - * Replace `tf.Variable` with `tf.contrib.eager.Variable`; + * Add `use_resource=True` when constructing `tf.Variable`; * Call `tf.get_variable_scope().set_use_resource(True)` inside a `tf.variable_scope` before the `tf.get_variable()` call. - - @compatibility(eager) - `tf.Variable` is not compatible with eager execution. Use - `tf.contrib.eager.Variable` instead which is compatible with both eager - execution and graph construction. See [the TensorFlow Eager Execution - guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) - for details on how variables work in eager execution.
- @end_compatibility """ def __init__(self, @@ -167,7 +260,10 @@ class Variable(checkpointable.CheckpointableBase): dtype=None, expected_shape=None, import_scope=None, - constraint=None): + constraint=None, + use_resource=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): """Creates a new variable with value `initial_value`. The new variable is added to the graph collections listed in `collections`, @@ -219,25 +315,565 @@ class Variable(checkpointable.CheckpointableBase): variable and return the Tensor for the projected value (which must have the same shape). Constraints are not safe to use when doing asynchronous distributed training. + use_resource: if True, a ResourceVariable is created; otherwise an + old-style ref-based variable is created. When eager execution is enabled + a resource variable is always created. + synchronization: Indicates when a distributed variable will be + synchronized. Accepted values are constants defined in the class + @{tf.VariableSynchronization}. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + @{tf.VariableAggregation}. Raises: ValueError: If both `variable_def` and initial_value are specified. ValueError: If the initial value is not specified, or does not have a shape and `validate_shape` is `True`. RuntimeError: If eager execution is enabled. + """ + raise NotImplementedError + + def __repr__(self): + raise NotImplementedError + + def value(self): + """Returns the last snapshot of this variable. + + You usually do not need to call this method as all ops that need the value + of the variable call it automatically through a `convert_to_tensor()` call. + + Returns a `Tensor` which holds the value of the variable. You can not + assign a new value to this tensor as it is not a reference to the variable. + + To avoid copies, if the consumer of the returned value is on the same device + as the variable, this actually returns the live value of the variable, not + a copy. Updates to the variable are seen by the consumer. If the consumer + is on a different device it will get a copy of the variable. - @compatibility(eager) - `tf.Variable` is not compatible with eager execution. Use - `tfe.Variable` instead which is compatible with both eager execution - and graph construction. See [the TensorFlow Eager Execution - guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) - for details on how variables work in eager execution. - @end_compatibility + Returns: + A `Tensor` containing the value of the variable. + """ + raise NotImplementedError + + def read_value(self): + """Returns the value of this variable, read in the current context. + + Can be different from value() if it's on another device, with control + dependencies, etc. + + Returns: + A `Tensor` containing the value of the variable. + """ + raise NotImplementedError + + def set_shape(self, shape): + """Overrides the shape for this variable. + + Args: + shape: the `TensorShape` representing the overridden shape. + """ + raise NotImplementedError + + @property + def trainable(self): + raise NotImplementedError + + def eval(self, session=None): + """In a session, computes and returns the value of this variable.
+ + This is not a graph construction method, it does not add ops to the graph. + + This convenience method requires a session where the graph + containing this variable has been launched. If no session is + passed, the default session is used. See @{tf.Session} for more + information on launching a graph and on sessions. + + ```python + v = tf.Variable([1, 2]) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + # Usage passing the session explicitly. + print(v.eval(sess)) + # Usage with the default session. The 'with' block + # above makes 'sess' the default session. + print(v.eval()) + ``` + + Args: + session: The session to use to evaluate this variable. If + none, the default session is used. + + Returns: + A numpy `ndarray` with a copy of the value of this variable. + """ + raise NotImplementedError + + def initialized_value(self): + """Returns the value of the initialized variable. + + You should use this instead of the variable itself to initialize another + variable with a value that depends on the value of this variable. + + ```python + # Initialize 'v' with a random tensor. + v = tf.Variable(tf.truncated_normal([10, 40])) + # Use `initialized_value` to guarantee that `v` has been + # initialized before its value is used to initialize `w`. + # The random values are picked only once. + w = tf.Variable(v.initialized_value() * 2.0) + ``` + + Returns: + A `Tensor` holding the value of this variable after its initializer + has run. + """ + raise NotImplementedError + + @property + def initial_value(self): + """Returns the Tensor used as the initial value for the variable. + + Note that this is different from `initialized_value()` which runs + the op that initializes the variable before returning its value. + This method returns the tensor that is used by the op that initializes + the variable. + + Returns: + A `Tensor`. + """ + raise NotImplementedError + + @property + def constraint(self): + """Returns the constraint function associated with this variable. + + Returns: + The constraint function that was passed to the variable constructor. + Can be `None` if no constraint was passed. + """ + raise NotImplementedError + + def assign(self, value, use_locking=False): + """Assigns a new value to the variable. + + This is essentially a shortcut for `assign(self, value)`. + + Args: + value: A `Tensor`. The new value for this variable. + use_locking: If `True`, use locking during the assignment. + + Returns: + A `Tensor` that will hold the new value of this variable after + the assignment has completed. + """ + raise NotImplementedError + + def assign_add(self, delta, use_locking=False): + """Adds a value to this variable. + + This is essentially a shortcut for `assign_add(self, delta)`. + + Args: + delta: A `Tensor`. The value to add to this variable. + use_locking: If `True`, use locking during the operation. + + Returns: + A `Tensor` that will hold the new value of this variable after + the addition has completed. + """ + raise NotImplementedError + + def assign_sub(self, delta, use_locking=False): + """Subtracts a value from this variable. + + This is essentially a shortcut for `assign_sub(self, delta)`. + + Args: + delta: A `Tensor`. The value to subtract from this variable. + use_locking: If `True`, use locking during the operation. + + Returns: + A `Tensor` that will hold the new value of this variable after + the subtraction has completed. 
+ """ + raise NotImplementedError + + def scatter_sub(self, sparse_delta, use_locking=False): + """Subtracts `IndexedSlices` from this variable. + + This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices, + sparse_delta.values)`. + + Args: + sparse_delta: `IndexedSlices` to be subtracted from this variable. + use_locking: If `True`, use locking during the operation. + + Returns: + A `Tensor` that will hold the new value of this variable after + the scattered subtraction has completed. + + Raises: + ValueError: if `sparse_delta` is not an `IndexedSlices`. + """ + raise NotImplementedError + + def count_up_to(self, limit): + """Increments this variable until it reaches `limit`. + + When that Op is run it tries to increment the variable by `1`. If + incrementing the variable would bring it above `limit` then the Op raises + the exception `OutOfRangeError`. + + If no error is raised, the Op outputs the value of the variable before + the increment. + + This is essentially a shortcut for `count_up_to(self, limit)`. + + Args: + limit: value at which incrementing the variable raises an error. + + Returns: + A `Tensor` that will hold the variable value before the increment. If no + other Op modifies this variable, the values produced will all be + distinct. + """ + raise NotImplementedError + + def load(self, value, session=None): + """Load new value into this variable. + + Writes new value to variable's memory. Doesn't add ops to the graph. + + This convenience method requires a session where the graph + containing this variable has been launched. If no session is + passed, the default session is used. See @{tf.Session} for more + information on launching a graph and on sessions. + + ```python + v = tf.Variable([1, 2]) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + # Usage passing the session explicitly. + v.load([2, 3], sess) + print(v.eval(sess)) # prints [2 3] + # Usage with the default session. The 'with' block + # above makes 'sess' the default session. + v.load([3, 4], sess) + print(v.eval()) # prints [3 4] + ``` + + Args: + value: New variable value + session: The session to use to evaluate this variable. If + none, the default session is used. + + Raises: + ValueError: Session is not passed and no default session + """ + raise NotImplementedError + + # Conversion to tensor. + @staticmethod + def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name + """Utility function for converting a Variable to a Tensor.""" + _ = name + if dtype and not dtype.is_compatible_with(v.dtype): + raise ValueError( + "Incompatible type conversion requested to type '%s' for variable " + "of type '%s'" % (dtype.name, v.dtype.name)) + if as_ref: + return v._ref() # pylint: disable=protected-access + else: + return v.value() + + @staticmethod + def _OverloadAllOperators(): # pylint: disable=invalid-name + """Register overloads for all operators.""" + for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + Variable._OverloadOperator(operator) + # For slicing, bind getitem differently than a tensor (use SliceHelperVar + # instead) + # pylint: disable=protected-access + setattr(Variable, "__getitem__", array_ops._SliceHelperVar) + + @staticmethod + def _OverloadOperator(operator): # pylint: disable=invalid-name + """Defer an operator overload to `ops.Tensor`. + + We pull the operator out of ops.Tensor dynamically to avoid ordering issues. + + Args: + operator: string. The operator name. 
+ """ + + def _run_op(a, *args): + # pylint: disable=protected-access + return getattr(ops.Tensor, operator)(a._AsTensor(), *args) + # Propagate __doc__ to wrapper + try: + _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__ + except AttributeError: + pass + + setattr(Variable, operator, _run_op) + + # NOTE(mrry): This enables the Variable's overloaded "right" binary + # operators to run when the left operand is an ndarray, because it + # accords the Variable class higher priority than an ndarray, or a + # numpy matrix. + # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ + # mechanism, which allows more control over how Variables interact + # with ndarrays. + __array_priority__ = 100 + + @property + def name(self): + """The name of this variable.""" + raise NotImplementedError + + @property + def initializer(self): + """The initializer operation for this variable.""" + raise NotImplementedError + + @property + def device(self): + """The device of this variable.""" + raise NotImplementedError + + @property + def dtype(self): + """The `DType` of this variable.""" + raise NotImplementedError + + @property + def op(self): + """The `Operation` of this variable.""" + raise NotImplementedError + + @property + def graph(self): + """The `Graph` of this variable.""" + raise NotImplementedError + + @property + def shape(self): + """The `TensorShape` of this variable. + + Returns: + A `TensorShape`. + """ + raise NotImplementedError + + def get_shape(self): + """Alias of Variable.shape.""" + raise NotImplementedError + + def to_proto(self, export_scope=None): + """Converts a `Variable` to a `VariableDef` protocol buffer. + + Args: + export_scope: Optional `string`. Name scope to remove. + + Returns: + A `VariableDef` protocol buffer, or `None` if the `Variable` is not + in the specified name scope. + """ + raise NotImplementedError + + @staticmethod + def from_proto(variable_def, import_scope=None): + """Returns a `Variable` object created from `variable_def`.""" + return RefVariable(variable_def=variable_def, + import_scope=import_scope) + + class SaveSliceInfo(object): + """Information on how to save this Variable as a slice. + + Provides internal support for saving variables as slices of a larger + variable. This API is not public and is subject to change. + + Available properties: + + * full_name + * full_shape + * var_offset + * var_shape + """ + + def __init__(self, + full_name=None, + full_shape=None, + var_offset=None, + var_shape=None, + save_slice_info_def=None, + import_scope=None): + """Create a `SaveSliceInfo`. + + Args: + full_name: Name of the full variable of which this `Variable` is a + slice. + full_shape: Shape of the full variable, as a list of int. + var_offset: Offset of this `Variable` into the full variable, as a + list of int. + var_shape: Shape of this `Variable`, as a list of int. + save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`, + recreates the SaveSliceInfo object its contents. + `save_slice_info_def` and other arguments are mutually + exclusive. + import_scope: Optional `string`. Name scope to add. Only used + when initializing from protocol buffer. 
+ """ + if save_slice_info_def: + assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef) + self.full_name = ops.prepend_name_scope( + save_slice_info_def.full_name, import_scope=import_scope) + self.full_shape = [i for i in save_slice_info_def.full_shape] + self.var_offset = [i for i in save_slice_info_def.var_offset] + self.var_shape = [i for i in save_slice_info_def.var_shape] + else: + self.full_name = full_name + self.full_shape = full_shape + self.var_offset = var_offset + self.var_shape = var_shape + + @property + def spec(self): + """Computes the spec string used for saving.""" + full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " " + sl_spec = ":".join([ + "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape) + ]) + return full_shape_str + sl_spec + + def to_proto(self, export_scope=None): + """Returns a SaveSliceInfoDef() proto. + + Args: + export_scope: Optional `string`. Name scope to remove. + + Returns: + A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not + in the specified name scope. + """ + if (export_scope is None or + self.full_name.startswith(export_scope)): + save_slice_info_def = variable_pb2.SaveSliceInfoDef() + save_slice_info_def.full_name = ops.strip_name_scope( + self.full_name, export_scope) + for i in self.full_shape: + save_slice_info_def.full_shape.append(i) + for i in self.var_offset: + save_slice_info_def.var_offset.append(i) + for i in self.var_shape: + save_slice_info_def.var_shape.append(i) + return save_slice_info_def + else: + return None + + def __iadd__(self, other): + raise NotImplementedError + + def __isub__(self, other): + raise NotImplementedError + + def __imul__(self, other): + raise NotImplementedError + + def __idiv__(self, other): + raise NotImplementedError + + def __itruediv__(self, other): + raise NotImplementedError + + def __irealdiv__(self, other): + raise NotImplementedError + + def __ipow__(self, other): + raise NotImplementedError + + +# TODO(apassos): do not repeat all comments here +class RefVariable(Variable): + """Ref-based implementation of variables.""" + + def __init__(self, + initial_value=None, + trainable=True, + collections=None, + validate_shape=True, + caching_device=None, + name=None, + variable_def=None, + dtype=None, + expected_shape=None, + import_scope=None, + constraint=None): + """Creates a new variable with value `initial_value`. + + The new variable is added to the graph collections listed in `collections`, + which defaults to `[GraphKeys.GLOBAL_VARIABLES]`. + + If `trainable` is `True` the variable is also added to the graph collection + `GraphKeys.TRAINABLE_VARIABLES`. + + This constructor creates both a `variable` Op and an `assign` Op to set the + variable to its initial value. + + Args: + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. In + that case, `dtype` must be specified. (Note that initializer functions + from init_ops.py must first be bound to a shape before being used here.) + trainable: If `True`, the default, also adds the variable to the graph + collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as + the default list of variables to use by the `Optimizer` classes. + collections: List of graph collections keys. The new variable is added to + these collections. 
Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. + validate_shape: If `False`, allows the variable to be initialized with a + value of unknown shape. If `True`, the default, the shape of + `initial_value` must be known. + caching_device: Optional device string describing where the Variable + should be cached for reading. Defaults to the Variable's device. + If not `None`, caches on another device. Typical use is to cache + on the device where the Ops using the Variable reside, to deduplicate + copying through `Switch` and other conditional statements. + name: Optional name for the variable. Defaults to `'Variable'` and gets + uniquified automatically. + variable_def: `VariableDef` protocol buffer. If not `None`, recreates + the Variable object with its contents, referencing the variable's nodes + in the graph, which must already exist. The graph is not changed. + `variable_def` and the other arguments are mutually exclusive. + dtype: If set, initial_value will be converted to the given type. + If `None`, either the datatype will be kept (if `initial_value` is + a Tensor), or `convert_to_tensor` will decide. + expected_shape: A TensorShape. If set, initial_value is expected + to have this shape. + import_scope: Optional `string`. Name scope to add to the + `Variable.` Only used when initializing from protocol buffer. + constraint: An optional projection function to be applied to the variable + after being updated by an `Optimizer` (e.g. used to implement norm + constraints or value constraints for layer weights). The function must + take as input the unprojected Tensor representing the value of the + variable and return the Tensor for the projected value + (which must have the same shape). Constraints are not safe to + use when doing asynchronous distributed training. + + Raises: + ValueError: If both `variable_def` and initial_value are specified. + ValueError: If the initial value is not specified, or does not have a + shape and `validate_shape` is `True`. + RuntimeError: If eager execution is enabled. """ - if context.executing_eagerly(): - raise RuntimeError( - "tf.Variable not supported when eager execution is enabled. " - "Please use tf.contrib.eager.Variable instead") self._in_graph_mode = True if variable_def: # If variable_def is provided, recreates the variable from its fields. @@ -348,8 +984,7 @@ class Variable(checkpointable.CheckpointableBase): # Ensure that we weren't lifted into the eager context. if context.executing_eagerly(): raise RuntimeError( - "tf.Variable not supported when eager execution is enabled. " - "Please use tf.contrib.eager.Variable instead") + "RefVariable not supported when eager execution is enabled. ") with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: @@ -1068,12 +1703,6 @@ class Variable(checkpointable.CheckpointableBase): else: return None - @staticmethod - def from_proto(variable_def, import_scope=None): - """Returns a `Variable` object created from `variable_def`.""" - return Variable(variable_def=variable_def, - import_scope=import_scope) - def __iadd__(self, other): logging.log_first_n( logging.WARN, @@ -1130,90 +1759,6 @@ class Variable(checkpointable.CheckpointableBase): " if you want a new python Tensor object.", 1) return self ** other - class SaveSliceInfo(object): - """Information on how to save this Variable as a slice. - - Provides internal support for saving variables as slices of a larger - variable. This API is not public and is subject to change. 
- - Available properties: - - * full_name - * full_shape - * var_offset - * var_shape - """ - - def __init__(self, - full_name=None, - full_shape=None, - var_offset=None, - var_shape=None, - save_slice_info_def=None, - import_scope=None): - """Create a `SaveSliceInfo`. - - Args: - full_name: Name of the full variable of which this `Variable` is a - slice. - full_shape: Shape of the full variable, as a list of int. - var_offset: Offset of this `Variable` into the full variable, as a - list of int. - var_shape: Shape of this `Variable`, as a list of int. - save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`, - recreates the SaveSliceInfo object its contents. - `save_slice_info_def` and other arguments are mutually - exclusive. - import_scope: Optional `string`. Name scope to add. Only used - when initializing from protocol buffer. - """ - if save_slice_info_def: - assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef) - self.full_name = ops.prepend_name_scope( - save_slice_info_def.full_name, import_scope=import_scope) - self.full_shape = [i for i in save_slice_info_def.full_shape] - self.var_offset = [i for i in save_slice_info_def.var_offset] - self.var_shape = [i for i in save_slice_info_def.var_shape] - else: - self.full_name = full_name - self.full_shape = full_shape - self.var_offset = var_offset - self.var_shape = var_shape - - @property - def spec(self): - """Computes the spec string used for saving.""" - full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " " - sl_spec = ":".join([ - "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape) - ]) - return full_shape_str + sl_spec - - def to_proto(self, export_scope=None): - """Returns a SaveSliceInfoDef() proto. - - Args: - export_scope: Optional `string`. Name scope to remove. - - Returns: - A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not - in the specified name scope. - """ - if (export_scope is None or - self.full_name.startswith(export_scope)): - save_slice_info_def = variable_pb2.SaveSliceInfoDef() - save_slice_info_def.full_name = ops.strip_name_scope( - self.full_name, export_scope) - for i in self.full_shape: - save_slice_info_def.full_shape.append(i) - for i in self.var_offset: - save_slice_info_def.var_offset.append(i) - for i in self.var_shape: - save_slice_info_def.var_shape.append(i) - return save_slice_info_def - else: - return None - def _set_save_slice_info(self, save_slice_info): """Sets the slice info for this `Variable`. @@ -1230,7 +1775,7 @@ class PartitionedVariable(object): """A container for partitioned `Variable` objects. @compatibility(eager) `tf.PartitionedVariable` is not compatible with - eager execution. Use `tfe.Variable` instead which is compatible + eager execution. Use `tf.Variable` instead which is compatible with both eager execution and graph construction. 
See [the TensorFlow Eager Execution guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) @@ -1404,6 +1949,10 @@ class PartitionedVariable(object): def dtype(self): return self._dtype + @property + def shape(self): + return self.get_shape() + def get_shape(self): return self._shape @@ -1723,6 +2272,8 @@ def report_uninitialized_variables(var_list=None, var_list.append(op.outputs[0]) with ops.name_scope(name): # Run all operations on CPU + if var_list: + init_vars = [state_ops.is_variable_initialized(v) for v in var_list] with ops.device("/cpu:0"): if not var_list: # Return an empty tensor so we only need to check for returned tensor @@ -1730,9 +2281,7 @@ return array_ops.constant([], dtype=dtypes.string) else: # Get a 1-D boolean tensor listing whether each variable is initialized. - variables_mask = math_ops.logical_not( - array_ops.stack( - [state_ops.is_variable_initialized(v) for v in var_list])) + variables_mask = math_ops.logical_not(array_ops.stack(init_vars)) # Get a 1-D string tensor containing all the variable names. variable_names_tensor = array_ops.constant( [s.op.name for s in var_list])
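A note on the new `VariableMetaclass`: calling `tf.Variable(...)` is now intercepted by the metaclass's `__call__`, which routes base-class construction through the variable-creator machinery while subclasses (such as `RefVariable`) construct normally. Below is a self-contained sketch of that dispatch pattern, not TensorFlow code; the `Base`, `Impl`, and `_factory` names are hypothetical stand-ins.

```python
import six


class Meta(type):
  """Metaclass that intercepts construction of the base class only."""

  def __call__(cls, *args, **kwargs):
    if cls is Base:
      # Mirrors Variable.__call__ deferring to _variable_call.
      return cls._factory(*args, **kwargs)
    return super(Meta, cls).__call__(*args, **kwargs)


class Base(six.with_metaclass(Meta, object)):

  @staticmethod
  def _factory(*args, **kwargs):
    return Impl(*args, **kwargs)


class Impl(Base):

  def __init__(self, value):
    self.value = value


print(type(Base(3)).__name__)  # "Impl": the base-class call was rerouted.
print(type(Impl(3)).__name__)  # "Impl": subclasses construct normally.
```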
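The `_make_getter` helper exists because a closure defined inside a `for` loop captures the loop variable itself, not its current value; chaining getters with bare lambdas would make every link call the last creator on the stack. A minimal standalone sketch of the chaining built in `_variable_call` (the creator functions here are hypothetical):

```python
def _make_getter(captured_getter, captured_previous):
  # Binding both values as arguments freezes them per loop iteration,
  # avoiding Python's late-binding closure behavior.
  def getter(**kwargs):
    return captured_getter(captured_previous, **kwargs)
  return getter


def base_creator(next_creator, **kwargs):
  del next_creator  # End of the chain.
  return "base(%s)" % kwargs["name"]


def logging_creator(next_creator, **kwargs):
  return "logged(%s)" % next_creator(**kwargs)


previous = lambda **kwargs: base_creator(None, **kwargs)
for creator in [logging_creator]:  # Stand-in for _variable_creator_stack.
  previous = _make_getter(creator, previous)

print(previous(name="v"))  # logged(base(v))
```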
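A hedged usage sketch for the new constructor arguments, assuming a build that contains this patch (so the enums are exported as `tf.VariableSynchronization` and `tf.VariableAggregation`). Per the new docstring, `ON_READ` synchronization requires that `trainable` not be `True`:

```python
import tensorflow as tf

# Sketch only: how the arguments are honored depends on the
# DistributionStrategy in use.
v = tf.Variable(
    0.0,
    trainable=False,    # required when synchronization=ON_READ
    use_resource=True,  # resource semantics avoid the aliasing caveat above
    synchronization=tf.VariableSynchronization.ON_READ,
    aggregation=tf.VariableAggregation.MEAN,
    name="accumulator")
```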
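`SaveSliceInfo.spec` (moved verbatim into the new base class by this patch) serializes a slice as the full shape followed by `offset,length` pairs, one pair per dimension. A standalone re-implementation, exercised with a hypothetical 10x20 slice of a 20x20 variable:

```python
def save_slice_spec(full_shape, var_offset, var_shape):
  """Mirrors SaveSliceInfo.spec: '<full_shape> <offset,len:offset,len:...>'."""
  full_shape_str = " ".join("%d" % d for d in full_shape) + " "
  sl_spec = ":".join(
      "%d,%d" % (o, s) for o, s in zip(var_offset, var_shape))
  return full_shape_str + sl_spec


# Rows 10..19, all 20 columns, of a 20x20 full variable:
print(save_slice_spec([20, 20], [10, 0], [10, 20]))  # "20 20 10,0:10,20"
```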
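The final hunk hoists the `is_variable_initialized` ops above the `with ops.device("/cpu:0")` block, apparently so each check is placed with its variable rather than pinned to the CPU; only the stacking and name gathering stay on `/cpu:0`. Caller-visible behavior should be unchanged; a graph-mode usage sketch (TF 1.x API):

```python
import tensorflow as tf

v = tf.Variable([1, 2], name="v")
w = tf.Variable([3, 4], name="w")

with tf.Session() as sess:
  sess.run(v.initializer)  # Initialize v only; leave w uninitialized.
  # Returns a 1-D string tensor naming each uninitialized variable.
  print(sess.run(tf.report_uninitialized_variables()))  # [b'w']
```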