 tensorflow/python/kernel_tests/variables_test.py | 47
 tensorflow/python/ops/partitioned_variables.py   | 17
 tensorflow/python/ops/variable_scope.py          | 13
 tensorflow/python/ops/variables.py               | 83
 4 files changed, 129 insertions(+), 31 deletions(-)
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 593cd5f25a..f2325779c1 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -302,6 +302,53 @@ class VariablesTestCase(tf.test.TestCase):
self.assertEqual(var.op.device, init_op.device)
sess.run(init_op)
+ def testInitializerFunction(self):
+ value = [[-42], [133.7]]
+ shape = [2, 1]
+ with self.test_session():
+ initializer = lambda: tf.constant(value)
+ with self.assertRaises(ValueError):
+ # Checks that dtype must be specified.
+ tf.Variable(initializer)
+
+ v1 = tf.Variable(initializer, dtype=tf.float32)
+ self.assertEqual(shape, v1.get_shape())
+ self.assertAllClose(value, v1.initial_value.eval())
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ v1.eval()
+
+ v2 = tf.Variable(tf.neg(v1.initialized_value()), dtype=tf.float32)
+ self.assertEqual(v1.get_shape(), v2.get_shape())
+ self.assertAllClose(np.negative(value), v2.initial_value.eval())
+
+ # Once v2.initial_value.eval() has been called, v1 has effectively been
+ # initialized.
+ self.assertAllClose(value, v1.eval())
+
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ v2.eval()
+ tf.initialize_all_variables().run()
+ self.assertAllClose(np.negative(value), v2.eval())
+
+ def testInitializerFunctionDevicePlacement(self):
+ with self.test_session():
+ initializer = lambda: tf.constant(42.0)
+ with tf.device("/cpu:100"):
+ v1 = tf.Variable(initializer, dtype=tf.float32, name="v1")
+ expected_device = "/device:CPU:100"
+ expected_group_v1 = [b"loc:@v1"]
+ self.assertEqual(expected_device, v1.op.device)
+ self.assertEqual(expected_group_v1, v1.op.colocation_groups())
+ for i in v1.initializer.inputs:
+ self.assertEqual(expected_device, i.op.device)
+ self.assertEqual(expected_group_v1, i.op.colocation_groups())
+
+ v2 = tf.Variable(initializer, dtype=tf.float32, name="v2")
+ expected_group_v2 = [b"loc:@v2"]
+ self.assertEqual(expected_group_v2, v2.op.colocation_groups())
+ for i in v2.initializer.inputs:
+ self.assertEqual(expected_group_v2, i.op.colocation_groups())
+
class IsInitializedTest(tf.test.TestCase):
diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py
index 9d4d19668a..c16ba0f814 100644
--- a/tensorflow/python/ops/partitioned_variables.py
+++ b/tensorflow/python/ops/partitioned_variables.py
@@ -167,19 +167,22 @@ def create_partitioned_variables(
slice_offset[slice_dim] += var_shape[slice_dim]
if callable(initializer):
- init_val = initializer(var_shape, dtype=dtype)
- init_val = ops.convert_to_tensor(init_val, dtype=dtype)
+ init = initializer
+ init_shape = var_shape
elif isinstance(initializer, ops.Tensor):
- init_val = array_ops.slice(initializer, var_offset, var_shape)
+ init = array_ops.slice(initializer, var_offset, var_shape)
# Use the dtype of the given tensor.
- dtype = init_val.dtype.base_dtype
+ dtype = init.dtype.base_dtype
+ init_shape = None
else:
- init_val = ops.convert_to_tensor(initializer, dtype=dtype)
- init_val = array_ops.slice(init_val, var_offset, var_shape)
+ init = ops.convert_to_tensor(initializer, dtype=dtype)
+ init = array_ops.slice(init, var_offset, var_shape)
+ init_shape = None
var = variable_scope.get_variable(name="part_%d" % i,
+ shape=init_shape,
dtype=dtype,
- initializer=init_val,
+ initializer=init,
trainable=trainable,
collections=collections)
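
With this change, a callable initializer is handed to get_variable together with the slice shape instead of being materialized up front, so each partition's initial value is built lazily and colocated with its slice. A minimal usage sketch, assuming the pre-1.0 TensorFlow API this file targets (the initializer, shapes, and name below are illustrative, not part of the patch):

    import tensorflow as tf
    from tensorflow.python.ops import partitioned_variables

    # Split a [4, 8] variable into two [2, 8] slices along dimension 0. The
    # callable initializer is invoked once per slice with that slice's shape.
    parts = partitioned_variables.create_partitioned_variables(
        shape=[4, 8],
        slicing=[2, 1],
        initializer=tf.random_uniform_initializer(),
        dtype=tf.float32,
        name="w")

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      print([p.get_shape().as_list() for p in parts])  # [[2, 8], [2, 8]]
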
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 64ad23674b..f60816e022 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -144,14 +144,19 @@ class _VariableStore(object):
with ops.control_dependencies(None):
if initializing_from_value:
init_val = initializer
+ variable_dtype = None
else:
- with ops.name_scope(name + "/Initializer/"):
- init_val = initializer(shape.as_list(), dtype=dtype)
+ init_val = lambda: initializer(shape.as_list(), dtype=dtype)
+ variable_dtype = dtype.base_dtype
# Create the variable.
- v = variables.Variable(init_val, name=name, trainable=trainable,
+ v = variables.Variable(initial_value=init_val,
+ name=name,
+ trainable=trainable,
collections=collections,
- caching_device=caching_device)
+ caching_device=caching_device,
+ dtype=variable_dtype)
+
self._vars[name] = v
logging.info("Created variable %s with shape %s and init %s", v.name,
format(shape), initializer)
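
Because get_variable now defers the initializer call behind a lambda, the ops the initializer creates are colocated with the variable rather than placed by the surrounding scope. A hedged sketch of the observable effect, again assuming the pre-1.0 API (the scope and variable names are illustrative):

    import tensorflow as tf

    with tf.variable_scope("scope"), tf.device("/cpu:0"):
      v = tf.get_variable("v", shape=[2, 3], dtype=tf.float32,
                          initializer=tf.constant_initializer(0.0))

    # The initializer's input ops should now share the variable's colocation
    # group instead of being pinned by whatever device scope was active when
    # get_variable was called.
    for i in v.initializer.inputs:
      print(i.op.device, i.op.colocation_groups())  # expect [b'loc:@scope/v']
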
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index a21724194b..3f2571bc06 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -156,9 +156,12 @@ class Variable(object):
variable to its initial value.
Args:
- initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
- The initial value for the Variable. Must have a shape specified unless
- `validate_shape` is set to False.
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
@@ -211,9 +214,12 @@ class Variable(object):
"""Creates a new variable from arguments.
Args:
- initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
- The initial value for the Variable. Must have a shape specified unless
- `validate_shape` is set to False.
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
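
As the updated docstring describes, initial_value may now be a zero-argument callable, in which case dtype must be given explicitly. A minimal usage sketch with the pre-1.0 API, mirroring the tests above (variable names are illustrative; omitting dtype raises the ValueError added below):

    import tensorflow as tf

    w = tf.Variable(lambda: tf.truncated_normal([3, 4]), dtype=tf.float32)
    b = tf.Variable(lambda: tf.zeros([4]), dtype=tf.float32)
    # tf.Variable(lambda: tf.zeros([4]))  # ValueError: dtype must be specified

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      print(sess.run(w).shape)  # (3, 4)
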
@@ -240,25 +246,62 @@ class Variable(object):
"""
if initial_value is None:
raise ValueError("initial_value must be specified.")
+ init_from_fn = callable(initial_value)
+ if init_from_fn and dtype is None:
+ raise ValueError(
+ "dtype must also be specified when initial_value is callable.")
+
if collections is None:
collections = [ops.GraphKeys.VARIABLES]
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
with ops.control_dependencies(None):
- with ops.op_scope([initial_value], name, "Variable") as name:
- self._initial_value = ops.convert_to_tensor(initial_value,
- name="initial_value",
- dtype=dtype)
- initial_value_shape = self._initial_value.get_shape()
- if validate_shape and not initial_value_shape.is_fully_defined():
- raise ValueError("initial_value must have a shape specified: %s"
- % self._initial_value)
- shape_to_set = initial_value_shape if validate_shape else []
-
- self._variable = state_ops.variable_op(
- shape_to_set, self._initial_value.dtype.base_dtype,
- set_shape=validate_shape, name=name)
-
+ with ops.op_scope(
+ [] if init_from_fn else [initial_value], name, "Variable") as name:
+
+ # Get the initial value from a callable function. The real shape of the
+ # variable will be set later, since under the init_from_fn case, the
+ # shape won't be known until after the function is invoked.
+ if init_from_fn:
+ self._variable = state_ops.variable_op(
+ [],
+ dtype.base_dtype,
+ set_shape=False,
+ name=name)
+ with ops.colocate_with(self._variable.op):
+ with ops.name_scope("Initializer"):
+ # Colocate the tensors created by the initial_value() function
+ # with the variable itself.
+ self._initial_value = ops.convert_to_tensor(initial_value(),
+ name="initial_value",
+ dtype=dtype)
+
+ # Or get the initial value from a Tensor or Python object.
+ else:
+ self._initial_value = ops.convert_to_tensor(initial_value,
+ name="initial_value",
+ dtype=dtype)
+ # In this case, the variable op can't be created until after the
+ # initial_value has been converted to a Tensor with a known type.
+ self._variable = state_ops.variable_op(
+ [],
+ self._initial_value.dtype.base_dtype,
+ set_shape=False,
+ name=name)
+
+ # Manually overrides the variable's shape with the initial value's.
+ if validate_shape:
+ initial_value_shape = self._initial_value.get_shape()
+ if not initial_value_shape.is_fully_defined():
+ raise ValueError("initial_value must have a shape specified: %s"
+ % self._initial_value)
+ self._variable.set_shape(initial_value_shape)
+ # TODO(b/28152992): Remove the below hack modifying the node_def shape
+ # directly once set_shape() handles it.
+ self._variable.op.node_def.attr["shape"].shape.CopyFrom(
+ initial_value_shape.as_proto())
+
+ # Assigns initial value.
with ops.colocate_with(self._variable.op):
self._initializer_op = state_ops.assign(
self._variable, self._initial_value,