diff options
-rw-r--r-- | tensorflow/python/kernel_tests/variables_test.py | 47 | ||||
-rw-r--r-- | tensorflow/python/ops/partitioned_variables.py | 17 | ||||
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 13 | ||||
-rw-r--r-- | tensorflow/python/ops/variables.py | 83 |
4 files changed, 129 insertions, 31 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 593cd5f25a..f2325779c1 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -302,6 +302,53 @@ class VariablesTestCase(tf.test.TestCase): self.assertEqual(var.op.device, init_op.device) sess.run(init_op) + def testInitializerFunction(self): + value = [[-42], [133.7]] + shape = [2, 1] + with self.test_session(): + initializer = lambda: tf.constant(value) + with self.assertRaises(ValueError): + # Checks that dtype must be specified. + tf.Variable(initializer) + + v1 = tf.Variable(initializer, dtype=tf.float32) + self.assertEqual(shape, v1.get_shape()) + self.assertAllClose(value, v1.initial_value.eval()) + with self.assertRaises(tf.errors.FailedPreconditionError): + v1.eval() + + v2 = tf.Variable(tf.neg(v1.initialized_value()), dtype=tf.float32) + self.assertEqual(v1.get_shape(), v2.get_shape()) + self.assertAllClose(np.negative(value), v2.initial_value.eval()) + + # Once v2.initial_value.eval() has been called, v1 has effectively been + # initialized. + self.assertAllClose(value, v1.eval()) + + with self.assertRaises(tf.errors.FailedPreconditionError): + v2.eval() + tf.initialize_all_variables().run() + self.assertAllClose(np.negative(value), v2.eval()) + + def testInitializerFunctionDevicePlacement(self): + with self.test_session(): + initializer = lambda: tf.constant(42.0) + with tf.device("/cpu:100"): + v1 = tf.Variable(initializer, dtype=tf.float32, name="v1") + expected_device = "/device:CPU:100" + expected_group_v1 = [b"loc:@v1"] + self.assertEqual(expected_device, v1.op.device) + self.assertEqual(expected_group_v1, v1.op.colocation_groups()) + for i in v1.initializer.inputs: + self.assertEqual(expected_device, i.op.device) + self.assertEqual(expected_group_v1, i.op.colocation_groups()) + + v2 = tf.Variable(initializer, dtype=tf.float32, name="v2") + expected_group_v2 = [b"loc:@v2"] + self.assertEqual(expected_group_v2, v2.op.colocation_groups()) + for i in v2.initializer.inputs: + self.assertEqual(expected_group_v2, i.op.colocation_groups()) + class IsInitializedTest(tf.test.TestCase): diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py index 9d4d19668a..c16ba0f814 100644 --- a/tensorflow/python/ops/partitioned_variables.py +++ b/tensorflow/python/ops/partitioned_variables.py @@ -167,19 +167,22 @@ def create_partitioned_variables( slice_offset[slice_dim] += var_shape[slice_dim] if callable(initializer): - init_val = initializer(var_shape, dtype=dtype) - init_val = ops.convert_to_tensor(init_val, dtype=dtype) + init = initializer + init_shape = var_shape elif isinstance(initializer, ops.Tensor): - init_val = array_ops.slice(initializer, var_offset, var_shape) + init = array_ops.slice(initializer, var_offset, var_shape) # Use the dtype of the given tensor. - dtype = init_val.dtype.base_dtype + dtype = init.dtype.base_dtype + init_shape = None else: - init_val = ops.convert_to_tensor(initializer, dtype=dtype) - init_val = array_ops.slice(init_val, var_offset, var_shape) + init = ops.convert_to_tensor(initializer, dtype=dtype) + init = array_ops.slice(init, var_offset, var_shape) + init_shape = None var = variable_scope.get_variable(name="part_%d" % i, + shape=init_shape, dtype=dtype, - initializer=init_val, + initializer=init, trainable=trainable, collections=collections) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 64ad23674b..f60816e022 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -144,14 +144,19 @@ class _VariableStore(object): with ops.control_dependencies(None): if initializing_from_value: init_val = initializer + variable_dtype = None else: - with ops.name_scope(name + "/Initializer/"): - init_val = initializer(shape.as_list(), dtype=dtype) + init_val = lambda: initializer(shape.as_list(), dtype=dtype) + variable_dtype = dtype.base_dtype # Create the variable. - v = variables.Variable(init_val, name=name, trainable=trainable, + v = variables.Variable(initial_value=init_val, + name=name, + trainable=trainable, collections=collections, - caching_device=caching_device) + caching_device=caching_device, + dtype=variable_dtype) + self._vars[name] = v logging.info("Created variable %s with shape %s and init %s", v.name, format(shape), initializer) diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index a21724194b..3f2571bc06 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -156,9 +156,12 @@ class Variable(object): variable to its initial value. Args: - initial_value: A `Tensor`, or Python object convertible to a `Tensor`. - The initial value for the Variable. Must have a shape specified unless - `validate_shape` is set to False. + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. In + that case, `dtype` must be specified. (Note that initializer functions + from init_ops.py must first be bound to a shape before being used here.) trainable: If `True`, the default, also adds the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default list of variables to use by the `Optimizer` classes. @@ -211,9 +214,12 @@ class Variable(object): """Creates a new variable from arguments. Args: - initial_value: A `Tensor`, or Python object convertible to a `Tensor`. - The initial value for the Variable. Must have a shape specified unless - `validate_shape` is set to False. + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. In + that case, `dtype` must be specified. (Note that initializer functions + from init_ops.py must first be bound to a shape before being used here.) trainable: If `True`, the default, also adds the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default list of variables to use by the `Optimizer` classes. @@ -240,25 +246,62 @@ class Variable(object): """ if initial_value is None: raise ValueError("initial_value must be specified.") + init_from_fn = callable(initial_value) + if init_from_fn and dtype is None: + raise ValueError( + "dtype must also be specified when initial_value is callable.") + if collections is None: collections = [ops.GraphKeys.VARIABLES] if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] with ops.control_dependencies(None): - with ops.op_scope([initial_value], name, "Variable") as name: - self._initial_value = ops.convert_to_tensor(initial_value, - name="initial_value", - dtype=dtype) - initial_value_shape = self._initial_value.get_shape() - if validate_shape and not initial_value_shape.is_fully_defined(): - raise ValueError("initial_value must have a shape specified: %s" - % self._initial_value) - shape_to_set = initial_value_shape if validate_shape else [] - - self._variable = state_ops.variable_op( - shape_to_set, self._initial_value.dtype.base_dtype, - set_shape=validate_shape, name=name) - + with ops.op_scope( + [] if init_from_fn else [initial_value], name, "Variable") as name: + + # Get the initial value from a callable function. The real shape of the + # variable will be set later, since under the init_from_fn case, the + # shape won't be known until after the function is invoked. + if init_from_fn: + self._variable = state_ops.variable_op( + [], + dtype.base_dtype, + set_shape=False, + name=name) + with ops.colocate_with(self._variable.op): + with ops.name_scope("Initializer"): + # Colocate the tensors created by the initial_value() function + # with the variable itself. + self._initial_value = ops.convert_to_tensor(initial_value(), + name="initial_value", + dtype=dtype) + + # Or get the initial value from a Tensor or Python object. + else: + self._initial_value = ops.convert_to_tensor(initial_value, + name="initial_value", + dtype=dtype) + # In this case, the variable op can't be created until after the + # initial_value has been converted to a Tensor with a known type. + self._variable = state_ops.variable_op( + [], + self._initial_value.dtype.base_dtype, + set_shape=False, + name=name) + + # Manually overrides the variable's shape with the initial value's. + if validate_shape: + initial_value_shape = self._initial_value.get_shape() + if not initial_value_shape.is_fully_defined(): + raise ValueError("initial_value must have a shape specified: %s" + % self._initial_value) + self._variable.set_shape(initial_value_shape) + # TODO(b/28152992): Remove the below hack modifying the node_def shape + # directly once set_shape() handles it. + self._variable.op.node_def.attr["shape"].shape.CopyFrom( + initial_value_shape.as_proto()) + + # Assigns initial value. with ops.colocate_with(self._variable.op): self._initializer_op = state_ops.assign( self._variable, self._initial_value, |