aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-14 09:59:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-14 10:06:39 -0800
commitece139170fee56df1e185682ba800252924b0fc5 (patch)
tree4b6ad0d0fac807236e47dd15b91c108e1c84c718
parent16f656bce63cf89adb7ccb744dd3999e52c8341b (diff)
Changes ResourceVariable constructor to be more similar to Variable.
Also updates Variable to remove now-unnecessary check. Change: 142030225
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py33
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py3
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py220
-rw-r--r--tensorflow/python/ops/state_ops.py4
-rw-r--r--tensorflow/python/ops/variables.py3
-rw-r--r--tensorflow/python/training/optimizer.py18
-rw-r--r--tensorflow/python/training/training_util.py5
7 files changed, 203 insertions, 83 deletions
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index b426719912..b38f7397d7 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -19,10 +19,12 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -95,6 +97,37 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(read.eval(), [[3]])
+ def testInitFn(self):
+ with self.test_session():
+ v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
+ dtype=dtypes.float32)
+ self.assertEqual(v.handle.op.colocation_groups(),
+ v.initializer.inputs[1].op.colocation_groups())
+
+ def testInitFnDtype(self):
+ with self.test_session():
+ v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
+ dtype=dtypes.float32)
+ self.assertEqual(dtypes.float32, v.value.dtype)
+
+ def testInitFnNoDtype(self):
+ with self.test_session():
+ v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1)
+ self.assertEqual(dtypes.int32, v.value.dtype)
+
+ def testInitializeAllVariables(self):
+ with self.test_session():
+ v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32)
+ with self.assertRaises(errors.NotFoundError):
+ v.value.eval()
+ variables.global_variables_initializer().run()
+ self.assertEqual(1.0, v.value.eval())
+
+ def testOperatorOverload(self):
+ with self.test_session():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ variables.global_variables_initializer().run()
+ self.assertEqual(2.0, (v+v).eval())
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index ac50853c64..fa7071d042 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -328,9 +328,6 @@ class VariablesTestCase(tf.test.TestCase):
shape = [2, 1]
with self.test_session():
initializer = lambda: tf.constant(value)
- with self.assertRaises(ValueError):
- # Checks that dtype must be specified.
- tf.Variable(initializer)
v1 = tf.Variable(initializer, dtype=tf.float32)
self.assertEqual(shape, v1.get_shape())
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index bdb777fddf..dac8fa815e 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -19,15 +19,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resources
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.util import compat
def _register_variable_read(read, collections, trainable):
@@ -51,79 +51,111 @@ class ResourceVariable(object):
# pylint: disable=unused-argument
def __init__(self,
initial_value=None,
- name=None,
- caching_device=None,
trainable=True,
collections=None,
- dtype=None,
- shape=None):
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ dtype=None):
"""Creates a variable.
Args:
- initial_value: A `Tensor` or Python object convertible to a `Tensor`
- representing the initial value of this variable.
- name: The name of this variable. Automatically uniquified.
- caching_device: device where the variable value's read by default.
- trainable: Whether the global read of this variable will be used for
- training.
- collections: Additional collections to which the `read` operation for
- this variable is to be added. Defaults to [].
- dtype: The type of this variable. Can be omitted if it can be deduced
- from the initial_value. If different from the type of the initial
- value it will be cast to this type.
- shape: The shape of this variable. Only specify if there is no initial
- value but shape inference is desired.
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called.
+ (Note that initializer functions from init_ops.py must first be bound
+ to a shape before being used here.)
+ trainable: If `True`, the default, also adds the variable to the graph
+ collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+ the default list of variables to use by the `Optimizer` classes.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ validate_shape: Ignored. Provided for compatibility with tf.Variable.
+ caching_device: Optional device string or function describing where the
+ Variable should be cached for reading. Defaults to the Variable's
+ device. If not `None`, caches on another device. Typical use is to
+ cache on the device where the Ops using the Variable reside, to
+ deduplicate copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If None, either the datatype will be kept (if initial_value is
+ a Tensor) or float32 will be used (if it is a Python object convertible
+ to a Tensor).
+
+ Raises:
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
"""
- if initial_value is not None:
- if callable(initial_value):
- initial_value = initial_value()
- initial_value = ops.convert_to_tensor(initial_value)
- if dtype is None:
- assert initial_value is not None, ("Trying to create a resource variable "
- "with no dtype or initial value. At"
- " least one of these must be set.")
- dtype = initial_value.dtype
- elif initial_value is not None:
- initial_value = math_ops.cast(initial_value, dtype)
- if shape is None:
- if initial_value is not None:
- shape = initial_value.get_shape().as_proto()
- else:
- shape = tensor_shape.unknown_shape()
- else:
- shape = tensor_shape.as_shape(shape)
-
- self._dtype = dtype
- with ops.name_scope(name, "Variable", [initial_value]) as name:
- self._handle = gen_resource_variable_ops.var_handle_op(shared_name=name,
- name=name,
- dtype=dtype,
- shape=shape)
-
- with ops.name_scope("IsInitialized"):
- self._is_initialized_op = (
- gen_resource_variable_ops.var_is_initialized_op(self._handle))
- if initial_value is not None:
- with ops.name_scope("Create"):
- self._initialize_op = gen_resource_variable_ops.assign_variable_op(
- self._handle, initial_value)
- resources.register_resource(self._handle,
- self._initialize_op,
- self._is_initialized_op)
-
- with ops.name_scope("Read"):
- if caching_device is not None:
- with ops.device(caching_device):
- self._value = gen_resource_variable_ops.read_variable_op(
- self._handle, dtype=self._dtype)
+ if initial_value is None:
+ raise ValueError("initial_value must be specified.")
+ init_from_fn = callable(initial_value)
+
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ if not isinstance(collections, (list, tuple, set)):
+ raise ValueError(
+ "collections argument to Variable constructor must be a list, tuple, "
+ "or set. Got %s of type %s" % (collections, type(collections)))
+ if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
+ collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
+ with ops.control_dependencies(None):
+ with ops.name_scope(name, "Variable", [] if init_from_fn else
+ [initial_value]) as name:
+ if init_from_fn:
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ # pylint: disable=protected-access
+ true_name = ops._name_from_scope_name(name)
+ attr = attr_value_pb2.AttrValue(
+ list=attr_value_pb2.AttrValue.ListValue(
+ s=[compat.as_bytes("loc:@%s" % true_name)]))
+ # pylint: disable=protected-access
+ with ops.get_default_graph()._attr_scope({"_class": attr}):
+ with ops.name_scope("Initializer"), ops.device(None):
+ self._initial_value = ops.convert_to_tensor(
+ initial_value(), name="initial_value", dtype=dtype)
+ self._handle = gen_resource_variable_ops.var_handle_op(
+ shape=self._initial_value.get_shape(),
+ dtype=self._initial_value.dtype.base_dtype,
+ shared_name=name, name=name)
+
+ # Or get the initial value from a Tensor or Python object.
else:
+ self._initial_value = ops.convert_to_tensor(
+ initial_value, name="initial_value", dtype=dtype)
+ self._handle = gen_resource_variable_ops.var_handle_op(
+ shape=self._initial_value.get_shape(),
+ dtype=self._initial_value.dtype.base_dtype,
+ shared_name=name, name=name)
+
+ self._dtype = self._initial_value.dtype.base_dtype
+
+ with ops.name_scope("IsInitialized"):
+ self._is_initialized_op = (
+ gen_resource_variable_ops.var_is_initialized_op(self._handle))
+ if initial_value is not None:
+ with ops.name_scope("Create"):
+ self._initialize_op = gen_resource_variable_ops.assign_variable_op(
+ self._handle, self._initial_value)
+
+ with ops.name_scope("Read"):
self._value = gen_resource_variable_ops.read_variable_op(
self._handle, dtype=self._dtype)
- # TODO(apassos) this is terrible
- self._value.initializer = self._initialize_op
- _register_variable_read(
- self._value, trainable=trainable, collections=collections)
+ if caching_device is not None:
+ with ops.device(caching_device):
+ self._cached_value = array_ops.identity(self._value)
+ else:
+ with ops.colocate_with(self._handle.op):
+ self._cached_value = array_ops.identity(self._value)
+ # TODO(apassos) this is terrible monkey-patching required to make
+ # initialize_all_variables work. Replace self._value with an explicit
+ # class instead of monkey-patching.
+ self._value.initializer = self._initialize_op
+ ops.add_to_collections(collections, self)
@property
def dtype(self):
@@ -152,7 +184,7 @@ class ResourceVariable(object):
@property
def value(self):
"""A cached operation which reads the value of this variable."""
- return self._value
+ return self._cached_value
def _as_graph_element(self):
"""Conversion function for Graph.as_graph_element()."""
@@ -165,8 +197,8 @@ class ResourceVariable(object):
@property
def op(self):
- """The op which reads the value of this variable."""
- return self._value.op
+ """The op for this variable."""
+ return self._handle.op
def eval(self, session=None):
"""Evaluates and returns the value of this variable."""
@@ -189,7 +221,7 @@ class ResourceVariable(object):
value = gen_resource_variable_ops.read_variable_op(
self._handle, dtype=self._dtype)
_register_variable_read(value, collections=collections, trainable=trainable)
- return value
+ return array_ops.identity(value)
def sparse_read(self, indices, collections=None, trainable=True, name=None):
"""Reads the value of this variable sparsely, using `gather`."""
@@ -197,16 +229,58 @@ class ResourceVariable(object):
value = gen_resource_variable_ops.resource_gather(
self._handle, indices, dtype=self._dtype)
_register_variable_read(value, collections=collections, trainable=trainable)
- return value
+ return array_ops.identity(value)
+
+ @staticmethod
+ def _OverloadAllOperators(): # pylint: disable=invalid-name
+ """Register overloads for all operators."""
+ for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
+ ResourceVariable._OverloadOperator(operator)
+ # For slicing, bind getitem differently than a tensor (use SliceHelperVar
+ # instead)
+ # pylint: disable=protected-access
+ setattr(ResourceVariable, "__getitem__", array_ops._SliceHelperVar)
+
+ def _AsTensor(self):
+ return self.value
+
+ @staticmethod
+ def _OverloadOperator(operator): # pylint: disable=invalid-name
+ """Defer an operator overload to `ops.Tensor`.
+ We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
-# pylint: disable=unused-argument
+ Args:
+ operator: string. The operator name.
+ """
+
+ def _run_op(a, *args):
+ # pylint: disable=protected-access
+ return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
+ # Propagate __doc__ to wrapper
+ try:
+ _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__
+ except AttributeError:
+ pass
+
+ setattr(ResourceVariable, operator, _run_op)
+
+ __array_priority__ = 100
+
+
+# pylint: disable=unused-argument,protected-access
def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
if dtype is not None and dtype != var.value.dtype:
print("trying to switch the dtype to ", dtype, " from ", var.value.dtype)
return NotImplemented
- return var.value
+ if as_ref:
+ return var._value
+ return var._cached_value
+# pylint: enable=unused-argument,protected-access
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
+
+# pylint: disable=protected-access
+ResourceVariable._OverloadAllOperators()
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index daffdf79d6..e8a79a2fe1 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -244,6 +244,6 @@ def is_variable_initialized(ref, name=None):
if ref.dtype._is_ref_dtype:
return gen_state_ops.is_variable_initialized(ref=ref, name=name)
# Handle resource variables.
- if ref.op.type == "ReadVariableOp":
- return gen_resource_variable_ops.var_is_initialized_op(ref.op.inputs[0],
+ if ref.op.type == "VarHandleOp":
+ return gen_resource_variable_ops.var_is_initialized_op(ref.handle,
name=name)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 69134f8a74..89a19f377d 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -274,9 +274,6 @@ class Variable(object):
if initial_value is None:
raise ValueError("initial_value must be specified.")
init_from_fn = callable(initial_value)
- if init_from_fn and dtype is None:
- raise ValueError(
- "dtype must also be specified when initial_value is callable.")
if collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES]
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 3166864cbe..0e8f55f22c 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -63,7 +63,7 @@ class _RefVariableProcessor(_OptimizableVariable):
return optimizer._apply_sparse(g, self._v) # pylint: disable=protected-access
-class _DenseResourceVariableProcessor(_OptimizableVariable):
+class _DenseReadResourceVariableProcessor(_OptimizableVariable):
"""Processor for dense ResourceVariables."""
def __init__(self, v):
@@ -77,6 +77,20 @@ class _DenseResourceVariableProcessor(_OptimizableVariable):
return optimizer._resource_apply_dense(g, self._v.op.inputs[0])
+class _DenseResourceVariableProcessor(_OptimizableVariable):
+ """Processor for dense ResourceVariables."""
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v
+
+ def update_op(self, optimizer, g):
+ # pylint: disable=protected-access
+ return optimizer._resource_apply_dense(g, self._v.handle)
+
+
class _SparseResourceVariableProcessor(_OptimizableVariable):
"""Processor for sparse ResourceVariables."""
@@ -96,6 +110,8 @@ def _get_processor(v):
if isinstance(v, variables.Variable):
return _RefVariableProcessor(v)
if v.op.type == "ReadVariableOp":
+ return _DenseReadResourceVariableProcessor(v)
+ if v.op.type == "VarHandleOp":
return _DenseResourceVariableProcessor(v)
if v.op.type == "ResourceGather":
return _SparseResourceVariableProcessor(v)
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index a8a1c5acf0..14d20cf2a6 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
@@ -95,7 +96,9 @@ def assert_global_step(global_step_tensor):
global_step_tensor: `Tensor` to test.
"""
if not (isinstance(global_step_tensor, variables.Variable) or
- isinstance(global_step_tensor, ops.Tensor)):
+ isinstance(global_step_tensor, ops.Tensor) or
+ isinstance(global_step_tensor,
+ resource_variable_ops.ResourceVariable)):
raise TypeError(
'Existing "global_step" must be a Variable or Tensor: %s.' %
global_step_tensor)