aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/resource_variable_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/resource_variable_ops.py')
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py150
1 files changed, 139 insertions, 11 deletions
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 15cafbbde5..8b259b6b6b 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -181,7 +181,8 @@ def shape_safe_assign_variable_handle(handle, shape, value, name=None):
name=name)
-class ResourceVariable(variables.Variable):
+# TODO(apassos) make this be variables.Variable
+class ResourceVariable(variables.RefVariable):
"""Variable based on resource handles.
See the @{$variables$Variables How To} for a high level overview.
@@ -195,15 +196,16 @@ class ResourceVariable(variables.Variable):
the variable are fixed. The value can be changed using one of the assign
methods.
- Just like any `Tensor`, variables created with `ResourceVariable()` can be
- used as inputs for other Ops in the graph. Additionally, all the operators
- overloaded for the `Tensor` class are carried over to variables, so you can
- also add nodes to the graph by just doing arithmetic on variables.
+ Just like any `Tensor`, variables created with
+ `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the
+ graph. Additionally, all the operators overloaded for the `Tensor` class are
+ carried over to variables, so you can also add nodes to the graph by just
+ doing arithmetic on variables.
- Unlike tf.Variable, a tf.ResourceVariable has well-defined semantics. Each
+ Unlike ref-based variable, a ResourceVariable has well-defined semantics. Each
usage of a ResourceVariable in a TensorFlow graph adds a read_value operation
- to the graph. The Tensors returned by a read_value operation are guaranteed
- to see all modifications to the value of the variable which happen in any
+ to the graph. The Tensors returned by a read_value operation are guaranteed to
+ see all modifications to the value of the variable which happen in any
operation on which the read_value depends on (either directly, indirectly, or
via a control dependency) and guaranteed to not see any modification to the
value of the variable from operations that depend on the read_value operation.
@@ -217,7 +219,7 @@ class ResourceVariable(variables.Variable):
can cause tf.Variable and tf.ResourceVariable to behave differently:
```python
- a = tf.ResourceVariable(1.0)
+ a = tf.Variable(1.0, use_resource=True)
a.initializer.run()
assign = a.assign(2.0)
@@ -741,8 +743,14 @@ class ResourceVariable(variables.Variable):
def _read_variable_op(self):
if self.trainable:
tape.watch_variable(self)
- return gen_resource_variable_ops.read_variable_op(self._handle,
- self._dtype)
+ result = gen_resource_variable_ops.read_variable_op(self._handle,
+ self._dtype)
+ if not context.executing_eagerly():
+ # Note that if a control flow context is active the input of the read op
+ # might not actually be the handle. This line bypasses it.
+ tape.record_operation(
+ "ReadVariableOp", [result], [self._handle], lambda x: [x])
+ return result
def read_value(self):
"""Constructs an op which reads the value of this variable.
@@ -867,6 +875,19 @@ class ResourceVariable(variables.Variable):
__array_priority__ = 100
+ def is_initialized(self, name=None):
+ """Checks whether a resource variable has been initialized.
+
+ Outputs boolean scalar indicating whether the tensor has been initialized.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `bool`.
+ """
+ return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)
+
def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
"""Subtracts a value from this variable.
@@ -1091,6 +1112,113 @@ class _UnreadVariable(ResourceVariable):
ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor)
ops.register_dense_tensor_like_type(_UnreadVariable)
+
+class _MixedPrecisionVariable(ResourceVariable):
+ """Represents a variable that can return in desired dtype when read.
+
+ In mixed precision training, it is usually desirable to use different dtypes
+ for variables and computation. This class will be used to wrap created
+ ResourceVariable when mixed precision training is enabled. It allows layers to
+ perform computation in a different dtype than their variable dtypes, in order
+ to achieve higher performance without causing quality loss.
+ """
+
+ def __init__(self, var, read_dtype):
+ """Creates a MixedPrecisionVariable.
+
+ Args:
+ var: A ResourceVariable instance.
+ read_dtype: A tf.DType, the returned dtype when read, default to None.
+ Casting is performed if read_dtype is not None and differs from
+ var.dtype.
+ Returns:
+ An MixedPrecisionVariable instance.
+ Raises:
+ ValueError: if var is not a ResourceVariable instance, or read_dtype is
+ not a tf.DType instance.
+ """
+ # pylint: disable=super-init-not-called
+ # We do not call super init on purpose.
+ if not isinstance(var, ResourceVariable):
+ raise ValueError("InvalidArgument: var must be a ResourceVariable type.")
+ if not isinstance(read_dtype, dtypes.DType):
+ raise ValueError("InvalidArgument: read_dtype must be a tf.DType type.")
+
+ self._var = var
+ self._trainable = var.trainable
+ self._save_slice_info = None
+ self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ self._in_graph_mode = var._in_graph_mode # pylint: disable=protected-access
+ self._handle = var.handle
+ self._shape = var.shape
+ self._initial_value = None
+ if isinstance(self.handle, ops.EagerTensor):
+ self._handle_name = ""
+ else:
+ self._handle_name = self.handle.name
+ self._unique_id = var._unique_id # pylint: disable=protected-access
+ self._dtype = var.dtype
+ self._constraint = None
+ self._cached_value = None
+ self._is_initialized_op = var._is_initialized_op # pylint: disable=protected-access
+ self._initializer_op = var._initializer_op # pylint: disable=protected-access
+ # This needs to be set before read_value() is called.
+ self._read_dtype = read_dtype
+ if context.executing_eagerly():
+ self._graph_element = None
+ else:
+ self._graph_element = self.read_value()
+ self._handle_deleter = (
+ var._handle_deleter if not self._in_graph_mode # pylint: disable=protected-access
+ else None)
+ # pylint: enable=super-init-not-called
+
+ @property
+ def name(self):
+ return self._var.name
+
+ def value(self):
+ return self._read_variable_op()
+
+ def read_value(self):
+ return self._read_variable_op()
+
+ def _read_variable_op(self):
+ with ops.colocate_with(self._handle):
+ res = gen_resource_variable_ops.read_variable_op(self._handle,
+ self._dtype)
+ if self._read_dtype != self._dtype:
+ return math_ops.cast(res, self._read_dtype)
+ else:
+ return res
+
+ def set_shape(self, shape):
+ self._shape = shape
+ self._cached_shape_as_list = None
+
+ @property
+ def op(self):
+ """The op for this variable."""
+ return self._var.op
+
+ @property
+ def read_dtype(self):
+ """The dtype of the returned tensor when reading the var."""
+ return self._read_dtype
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ del name
+ dtype = dtype or self.read_dtype
+ if dtype != self.read_dtype or as_ref:
+ return NotImplemented
+ else:
+ res = self.value()
+ return res
+
+ def _should_act_as_resource_variable(self):
+ """To pass resource_variable_ops.is_resource_variable check."""
+ pass
+
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.