author      2016-12-14 09:59:54 -0800
committer   2016-12-14 10:06:39 -0800
commit      ece139170fee56df1e185682ba800252924b0fc5 (patch)
tree        4b6ad0d0fac807236e47dd15b91c108e1c84c718
parent      16f656bce63cf89adb7ccb744dd3999e52c8341b (diff)
Changes ResourceVariable constructor to be more similar to Variable.
Also updates Variable to remove now-unnecessary check.
Change: 142030225
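
For orientation, a minimal usage sketch of the new constructor behavior, based on the tests added in this change (TF 1.x-era graph mode; illustrative only, not part of the commit):

    import tensorflow as tf
    from tensorflow.python.ops import resource_variable_ops

    # A callable initial_value no longer requires an explicit dtype; the
    # dtype is inferred from the value the callable returns.
    v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1)  # int32
    w = resource_variable_ops.ResourceVariable(1.0)                      # float32

    with tf.Session():
      # ResourceVariables are now added to GraphKeys.GLOBAL_VARIABLES, so the
      # stock global initializer picks them up.
      tf.global_variables_initializer().run()
      # Operator overloads are now registered, so arithmetic works directly.
      print((w + w).eval())  # 2.0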
-rw-r--r--   tensorflow/python/kernel_tests/resource_variable_ops_test.py    33
-rw-r--r--   tensorflow/python/kernel_tests/variables_test.py                  3
-rw-r--r--   tensorflow/python/ops/resource_variable_ops.py                  220
-rw-r--r--   tensorflow/python/ops/state_ops.py                                4
-rw-r--r--   tensorflow/python/ops/variables.py                                3
-rw-r--r--   tensorflow/python/training/optimizer.py                          18
-rw-r--r--   tensorflow/python/training/training_util.py                       5
7 files changed, 203 insertions, 83 deletions
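
One observable effect worth noting before the per-file diff: a ResourceVariable's `op` property now refers to the VarHandleOp that owns the resource rather than to a read of it, which is the op type the state_ops and optimizer changes below dispatch on. A small illustrative sketch (assumed TF 1.x graph mode; not part of the commit):

    from tensorflow.python.ops import resource_variable_ops

    v = resource_variable_ops.ResourceVariable(1.0)
    print(v.op.type)  # "VarHandleOp" -- previously this was the read op
    # v.value is a cached read of the handle, distinct from the handle op:
    print(v.value.op is v.op)  # False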
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index b426719912..b38f7397d7 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -19,10 +19,12 @@ from __future__ import print_function
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
@@ -95,6 +97,37 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
       read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
       self.assertEqual(read.eval(), [[3]])
 
+  def testInitFn(self):
+    with self.test_session():
+      v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
+                                                 dtype=dtypes.float32)
+      self.assertEqual(v.handle.op.colocation_groups(),
+                       v.initializer.inputs[1].op.colocation_groups())
+
+  def testInitFnDtype(self):
+    with self.test_session():
+      v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
+                                                 dtype=dtypes.float32)
+      self.assertEqual(dtypes.float32, v.value.dtype)
+
+  def testInitFnNoDtype(self):
+    with self.test_session():
+      v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1)
+      self.assertEqual(dtypes.int32, v.value.dtype)
+
+  def testInitializeAllVariables(self):
+    with self.test_session():
+      v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32)
+      with self.assertRaises(errors.NotFoundError):
+        v.value.eval()
+      variables.global_variables_initializer().run()
+      self.assertEqual(1.0, v.value.eval())
+
+  def testOperatorOverload(self):
+    with self.test_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      variables.global_variables_initializer().run()
+      self.assertEqual(2.0, (v+v).eval())
 
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index ac50853c64..fa7071d042 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -328,9 +328,6 @@ class VariablesTestCase(tf.test.TestCase):
     shape = [2, 1]
     with self.test_session():
       initializer = lambda: tf.constant(value)
-      with self.assertRaises(ValueError):
-        # Checks that dtype must be specified.
-        tf.Variable(initializer)
       v1 = tf.Variable(initializer, dtype=tf.float32)
       self.assertEqual(shape, v1.get_shape())
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index bdb777fddf..dac8fa815e 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -19,15 +19,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_resource_variable_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resources
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_resource_variable_ops import *
 # pylint: enable=wildcard-import
+from tensorflow.python.util import compat
 
 
 def _register_variable_read(read, collections, trainable):
@@ -51,79 +51,111 @@ class ResourceVariable(object):
 
   # pylint: disable=unused-argument
   def __init__(self,
                initial_value=None,
-               name=None,
-               caching_device=None,
                trainable=True,
                collections=None,
-               dtype=None,
-               shape=None):
+               validate_shape=True,
+               caching_device=None,
+               name=None,
+               dtype=None):
     """Creates a variable.
 
     Args:
-      initial_value: A `Tensor` or Python object convertible to a `Tensor`
-        representing the initial value of this variable.
-      name: The name of this variable. Automatically uniquified.
-      caching_device: device where the variable value's read by default.
-      trainable: Whether the global read of this variable will be used for
-        training.
-      collections: Additional collections to which the `read` operation for
-        this variable is to be added. Defaults to [].
-      dtype: The type of this variable. Can be omitted if it can be deduced
-        from the initial_value. If different from the type of the initial
-        value it will be cast to this type.
-      shape: The shape of this variable. Only specify if there is no initial
-        value but shape inference is desired.
+      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+        which is the initial value for the Variable. The initial value must have
+        a shape specified unless `validate_shape` is set to False. Can also be a
+        callable with no argument that returns the initial value when called.
+        (Note that initializer functions from init_ops.py must first be bound
+        to a shape before being used here.)
+      trainable: If `True`, the default, also adds the variable to the graph
+        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+        the default list of variables to use by the `Optimizer` classes.
+      collections: List of graph collections keys. The new variable is added to
+        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+      validate_shape: Ignored. Provided for compatibility with tf.Variable.
+      caching_device: Optional device string or function describing where the
+        Variable should be cached for reading. Defaults to the Variable's
+        device. If not `None`, caches on another device. Typical use is to
+        cache on the device where the Ops using the Variable reside, to
+        deduplicate copying through `Switch` and other conditional statements.
+      name: Optional name for the variable. Defaults to `'Variable'` and gets
+        uniquified automatically.
+      dtype: If set, initial_value will be converted to the given type.
+        If None, either the datatype will be kept (if initial_value is
+        a Tensor) or float32 will be used (if it is a Python object convertible
+        to a Tensor).
+
+    Raises:
+      ValueError: If the initial value is not specified, or does not have a
+        shape and `validate_shape` is `True`.
     """
-    if initial_value is not None:
-      if callable(initial_value):
-        initial_value = initial_value()
-      initial_value = ops.convert_to_tensor(initial_value)
-    if dtype is None:
-      assert initial_value is not None, ("Trying to create a resource variable "
-                                         "with no dtype or initial value. At"
-                                         " least one of these must be set.")
-      dtype = initial_value.dtype
-    elif initial_value is not None:
-      initial_value = math_ops.cast(initial_value, dtype)
-    if shape is None:
-      if initial_value is not None:
-        shape = initial_value.get_shape().as_proto()
-      else:
-        shape = tensor_shape.unknown_shape()
-    else:
-      shape = tensor_shape.as_shape(shape)
-
-    self._dtype = dtype
-    with ops.name_scope(name, "Variable", [initial_value]) as name:
-      self._handle = gen_resource_variable_ops.var_handle_op(shared_name=name,
-                                                             name=name,
-                                                             dtype=dtype,
-                                                             shape=shape)
-
-      with ops.name_scope("IsInitialized"):
-        self._is_initialized_op = (
-            gen_resource_variable_ops.var_is_initialized_op(self._handle))
-      if initial_value is not None:
-        with ops.name_scope("Create"):
-          self._initialize_op = gen_resource_variable_ops.assign_variable_op(
-              self._handle, initial_value)
-        resources.register_resource(self._handle,
-                                    self._initialize_op,
-                                    self._is_initialized_op)
-
-      with ops.name_scope("Read"):
-        if caching_device is not None:
-          with ops.device(caching_device):
-            self._value = gen_resource_variable_ops.read_variable_op(
-                self._handle, dtype=self._dtype)
+    if initial_value is None:
+      raise ValueError("initial_value must be specified.")
+    init_from_fn = callable(initial_value)
+
+    if collections is None:
+      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+    if not isinstance(collections, (list, tuple, set)):
+      raise ValueError(
+          "collections argument to Variable constructor must be a list, tuple, "
+          "or set. Got %s of type %s" % (collections, type(collections)))
+    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
+      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
+    with ops.control_dependencies(None):
+      with ops.name_scope(name, "Variable", [] if init_from_fn else
+                          [initial_value]) as name:
+        if init_from_fn:
+          # Use attr_scope and device(None) to simulate the behavior of
+          # colocate_with when the variable we want to colocate with doesn't
+          # yet exist.
+          # pylint: disable=protected-access
+          true_name = ops._name_from_scope_name(name)
+          attr = attr_value_pb2.AttrValue(
+              list=attr_value_pb2.AttrValue.ListValue(
+                  s=[compat.as_bytes("loc:@%s" % true_name)]))
+          # pylint: disable=protected-access
+          with ops.get_default_graph()._attr_scope({"_class": attr}):
+            with ops.name_scope("Initializer"), ops.device(None):
+              self._initial_value = ops.convert_to_tensor(
+                  initial_value(), name="initial_value", dtype=dtype)
+            self._handle = gen_resource_variable_ops.var_handle_op(
+                shape=self._initial_value.get_shape(),
+                dtype=self._initial_value.dtype.base_dtype,
+                shared_name=name, name=name)
+
+        # Or get the initial value from a Tensor or Python object.
         else:
+          self._initial_value = ops.convert_to_tensor(
+              initial_value, name="initial_value", dtype=dtype)
+          self._handle = gen_resource_variable_ops.var_handle_op(
+              shape=self._initial_value.get_shape(),
+              dtype=self._initial_value.dtype.base_dtype,
+              shared_name=name, name=name)
+
+        self._dtype = self._initial_value.dtype.base_dtype
+
+        with ops.name_scope("IsInitialized"):
+          self._is_initialized_op = (
+              gen_resource_variable_ops.var_is_initialized_op(self._handle))
+        if initial_value is not None:
+          with ops.name_scope("Create"):
+            self._initialize_op = gen_resource_variable_ops.assign_variable_op(
+                self._handle, self._initial_value)
+
+        with ops.name_scope("Read"):
          self._value = gen_resource_variable_ops.read_variable_op(
              self._handle, dtype=self._dtype)
-    # TODO(apassos) this is terrible
-    self._value.initializer = self._initialize_op
-    _register_variable_read(
-        self._value, trainable=trainable, collections=collections)
+          if caching_device is not None:
+            with ops.device(caching_device):
+              self._cached_value = array_ops.identity(self._value)
+          else:
+            with ops.colocate_with(self._handle.op):
+              self._cached_value = array_ops.identity(self._value)
+          # TODO(apassos) this is terrible monkey-patching required to make
+          # initialize_all_variables work. Replace self._value with an explicit
+          # class instead of monkey-patching.
+          self._value.initializer = self._initialize_op
+        ops.add_to_collections(collections, self)
 
   @property
   def dtype(self):
@@ -152,7 +184,7 @@ class ResourceVariable(object):
   @property
   def value(self):
     """A cached operation which reads the value of this variable."""
-    return self._value
+    return self._cached_value
 
   def _as_graph_element(self):
     """Conversion function for Graph.as_graph_element()."""
@@ -165,8 +197,8 @@ class ResourceVariable(object):
 
   @property
   def op(self):
-    """The op which reads the value of this variable."""
-    return self._value.op
+    """The op for this variable."""
+    return self._handle.op
 
   def eval(self, session=None):
     """Evaluates and returns the value of this variable."""
@@ -189,7 +221,7 @@ class ResourceVariable(object):
       value = gen_resource_variable_ops.read_variable_op(
           self._handle, dtype=self._dtype)
     _register_variable_read(value, collections=collections, trainable=trainable)
-    return value
+    return array_ops.identity(value)
 
   def sparse_read(self, indices, collections=None, trainable=True, name=None):
     """Reads the value of this variable sparsely, using `gather`."""
@@ -197,16 +229,58 @@
       value = gen_resource_variable_ops.resource_gather(
          self._handle, indices, dtype=self._dtype)
     _register_variable_read(value, collections=collections, trainable=trainable)
-    return value
+    return array_ops.identity(value)
+
+  @staticmethod
+  def _OverloadAllOperators():  # pylint: disable=invalid-name
+    """Register overloads for all operators."""
+    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
+      ResourceVariable._OverloadOperator(operator)
+    # For slicing, bind getitem differently than a tensor (use SliceHelperVar
+    # instead)
+    # pylint: disable=protected-access
+    setattr(ResourceVariable, "__getitem__", array_ops._SliceHelperVar)
+
+  def _AsTensor(self):
+    return self.value
+
+  @staticmethod
+  def _OverloadOperator(operator):  # pylint: disable=invalid-name
+    """Defer an operator overload to `ops.Tensor`.
+
+    We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
 
-# pylint: disable=unused-argument
+    Args:
+      operator: string. The operator name.
+    """
+
+    def _run_op(a, *args):
+      # pylint: disable=protected-access
+      return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
+    # Propagate __doc__ to wrapper
+    try:
+      _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__
+    except AttributeError:
+      pass
+
+    setattr(ResourceVariable, operator, _run_op)
+
+  __array_priority__ = 100
+
+
+# pylint: disable=unused-argument,protected-access
 def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
   if dtype is not None and dtype != var.value.dtype:
     print("trying to switch the dtype to ", dtype, " from ", var.value.dtype)
     return NotImplemented
-  return var.value
+  if as_ref:
+    return var._value
+  return var._cached_value
+# pylint: enable=unused-argument,protected-access
 
 
 # Register a conversion function which reads the value of the variable,
 # allowing instances of the class to be used as tensors.
 ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
+
+# pylint: disable=protected-access
+ResourceVariable._OverloadAllOperators()
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index daffdf79d6..e8a79a2fe1 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -244,6 +244,6 @@ def is_variable_initialized(ref, name=None):
   if ref.dtype._is_ref_dtype:
     return gen_state_ops.is_variable_initialized(ref=ref, name=name)
   # Handle resource variables.
-  if ref.op.type == "ReadVariableOp":
-    return gen_resource_variable_ops.var_is_initialized_op(ref.op.inputs[0],
+  if ref.op.type == "VarHandleOp":
+    return gen_resource_variable_ops.var_is_initialized_op(ref.handle,
                                                            name=name)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 69134f8a74..89a19f377d 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -274,9 +274,6 @@ class Variable(object):
     if initial_value is None:
       raise ValueError("initial_value must be specified.")
     init_from_fn = callable(initial_value)
-    if init_from_fn and dtype is None:
-      raise ValueError(
-          "dtype must also be specified when initial_value is callable.")
 
     if collections is None:
       collections = [ops.GraphKeys.GLOBAL_VARIABLES]
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 3166864cbe..0e8f55f22c 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -63,7 +63,7 @@ class _RefVariableProcessor(_OptimizableVariable):
     return optimizer._apply_sparse(g, self._v)  # pylint: disable=protected-access
 
 
-class _DenseResourceVariableProcessor(_OptimizableVariable):
+class _DenseReadResourceVariableProcessor(_OptimizableVariable):
   """Processor for dense ResourceVariables."""
 
   def __init__(self, v):
@@ -77,6 +77,20 @@
     return optimizer._resource_apply_dense(g, self._v.op.inputs[0])
 
 
+class _DenseResourceVariableProcessor(_OptimizableVariable):
+  """Processor for dense ResourceVariables."""
+
+  def __init__(self, v):
+    self._v = v
+
+  def target(self):
+    return self._v
+
+  def update_op(self, optimizer, g):
+    # pylint: disable=protected-access
+    return optimizer._resource_apply_dense(g, self._v.handle)
+
+
 class _SparseResourceVariableProcessor(_OptimizableVariable):
   """Processor for sparse ResourceVariables."""
 
@@ -96,6 +110,8 @@ def _get_processor(v):
   if isinstance(v, variables.Variable):
     return _RefVariableProcessor(v)
   if v.op.type == "ReadVariableOp":
+    return _DenseReadResourceVariableProcessor(v)
+  if v.op.type == "VarHandleOp":
     return _DenseResourceVariableProcessor(v)
   if v.op.type == "ResourceGather":
     return _SparseResourceVariableProcessor(v)
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index a8a1c5acf0..14d20cf2a6 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 
 from tensorflow.python.framework import graph_io
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 
@@ -95,7 +96,9 @@ def assert_global_step(global_step_tensor):
     global_step_tensor: `Tensor` to test.
   """
   if not (isinstance(global_step_tensor, variables.Variable) or
-          isinstance(global_step_tensor, ops.Tensor)):
+          isinstance(global_step_tensor, ops.Tensor) or
+          isinstance(global_step_tensor,
+                     resource_variable_ops.ResourceVariable)):
     raise TypeError(
         'Existing "global_step" must be a Variable or Tensor: %s.' %
         global_step_tensor)
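
The training_util.py change means a ResourceVariable can now serve as the global step. A hedged usage sketch (names as in the diff; illustrative only, and assuming the remaining dtype/shape checks pass as they do for a scalar integer tf.Variable):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.training import training_util

    global_step = resource_variable_ops.ResourceVariable(
        0, trainable=False, dtype=dtypes.int64, name="global_step")
    # Previously this raised TypeError, because a ResourceVariable is neither
    # a variables.Variable nor an ops.Tensor.
    training_util.assert_global_step(global_step)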