diff options
Diffstat (limited to 'tensorflow/python/ops/resource_variable_ops.py')
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 150 |
1 files changed, 139 insertions, 11 deletions
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 15cafbbde5..8b259b6b6b 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -181,7 +181,8 @@ def shape_safe_assign_variable_handle(handle, shape, value, name=None): name=name) -class ResourceVariable(variables.Variable): +# TODO(apassos) make this be variables.Variable +class ResourceVariable(variables.RefVariable): """Variable based on resource handles. See the @{$variables$Variables How To} for a high level overview. @@ -195,15 +196,16 @@ class ResourceVariable(variables.Variable): the variable are fixed. The value can be changed using one of the assign methods. - Just like any `Tensor`, variables created with `ResourceVariable()` can be - used as inputs for other Ops in the graph. Additionally, all the operators - overloaded for the `Tensor` class are carried over to variables, so you can - also add nodes to the graph by just doing arithmetic on variables. + Just like any `Tensor`, variables created with + `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the + graph. Additionally, all the operators overloaded for the `Tensor` class are + carried over to variables, so you can also add nodes to the graph by just + doing arithmetic on variables. - Unlike tf.Variable, a tf.ResourceVariable has well-defined semantics. Each + Unlike ref-based variable, a ResourceVariable has well-defined semantics. Each usage of a ResourceVariable in a TensorFlow graph adds a read_value operation - to the graph. The Tensors returned by a read_value operation are guaranteed - to see all modifications to the value of the variable which happen in any + to the graph. The Tensors returned by a read_value operation are guaranteed to + see all modifications to the value of the variable which happen in any operation on which the read_value depends on (either directly, indirectly, or via a control dependency) and guaranteed to not see any modification to the value of the variable from operations that depend on the read_value operation. @@ -217,7 +219,7 @@ class ResourceVariable(variables.Variable): can cause tf.Variable and tf.ResourceVariable to behave differently: ```python - a = tf.ResourceVariable(1.0) + a = tf.Variable(1.0, use_resource=True) a.initializer.run() assign = a.assign(2.0) @@ -741,8 +743,14 @@ class ResourceVariable(variables.Variable): def _read_variable_op(self): if self.trainable: tape.watch_variable(self) - return gen_resource_variable_ops.read_variable_op(self._handle, - self._dtype) + result = gen_resource_variable_ops.read_variable_op(self._handle, + self._dtype) + if not context.executing_eagerly(): + # Note that if a control flow context is active the input of the read op + # might not actually be the handle. This line bypasses it. + tape.record_operation( + "ReadVariableOp", [result], [self._handle], lambda x: [x]) + return result def read_value(self): """Constructs an op which reads the value of this variable. @@ -867,6 +875,19 @@ class ResourceVariable(variables.Variable): __array_priority__ = 100 + def is_initialized(self, name=None): + """Checks whether a resource variable has been initialized. + + Outputs boolean scalar indicating whether the tensor has been initialized. + + Args: + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `bool`. + """ + return gen_resource_variable_ops.var_is_initialized_op(self.handle, name) + def assign_sub(self, delta, use_locking=None, name=None, read_value=True): """Subtracts a value from this variable. @@ -1091,6 +1112,113 @@ class _UnreadVariable(ResourceVariable): ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor) ops.register_dense_tensor_like_type(_UnreadVariable) + +class _MixedPrecisionVariable(ResourceVariable): + """Represents a variable that can return in desired dtype when read. + + In mixed precision training, it is usually desirable to use different dtypes + for variables and computation. This class will be used to wrap created + ResourceVariable when mixed precision training is enabled. It allows layers to + perform computation in a different dtype than their variable dtypes, in order + to achieve higher performance without causing quality loss. + """ + + def __init__(self, var, read_dtype): + """Creates a MixedPrecisionVariable. + + Args: + var: A ResourceVariable instance. + read_dtype: A tf.DType, the returned dtype when read, default to None. + Casting is performed if read_dtype is not None and differs from + var.dtype. + Returns: + An MixedPrecisionVariable instance. + Raises: + ValueError: if var is not a ResourceVariable instance, or read_dtype is + not a tf.DType instance. + """ + # pylint: disable=super-init-not-called + # We do not call super init on purpose. + if not isinstance(var, ResourceVariable): + raise ValueError("InvalidArgument: var must be a ResourceVariable type.") + if not isinstance(read_dtype, dtypes.DType): + raise ValueError("InvalidArgument: read_dtype must be a tf.DType type.") + + self._var = var + self._trainable = var.trainable + self._save_slice_info = None + self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + self._in_graph_mode = var._in_graph_mode # pylint: disable=protected-access + self._handle = var.handle + self._shape = var.shape + self._initial_value = None + if isinstance(self.handle, ops.EagerTensor): + self._handle_name = "" + else: + self._handle_name = self.handle.name + self._unique_id = var._unique_id # pylint: disable=protected-access + self._dtype = var.dtype + self._constraint = None + self._cached_value = None + self._is_initialized_op = var._is_initialized_op # pylint: disable=protected-access + self._initializer_op = var._initializer_op # pylint: disable=protected-access + # This needs to be set before read_value() is called. + self._read_dtype = read_dtype + if context.executing_eagerly(): + self._graph_element = None + else: + self._graph_element = self.read_value() + self._handle_deleter = ( + var._handle_deleter if not self._in_graph_mode # pylint: disable=protected-access + else None) + # pylint: enable=super-init-not-called + + @property + def name(self): + return self._var.name + + def value(self): + return self._read_variable_op() + + def read_value(self): + return self._read_variable_op() + + def _read_variable_op(self): + with ops.colocate_with(self._handle): + res = gen_resource_variable_ops.read_variable_op(self._handle, + self._dtype) + if self._read_dtype != self._dtype: + return math_ops.cast(res, self._read_dtype) + else: + return res + + def set_shape(self, shape): + self._shape = shape + self._cached_shape_as_list = None + + @property + def op(self): + """The op for this variable.""" + return self._var.op + + @property + def read_dtype(self): + """The dtype of the returned tensor when reading the var.""" + return self._read_dtype + + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + del name + dtype = dtype or self.read_dtype + if dtype != self.read_dtype or as_ref: + return NotImplemented + else: + res = self.value() + return res + + def _should_act_as_resource_variable(self): + """To pass resource_variable_ops.is_resource_variable check.""" + pass + # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. |