aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Wei Ho <weiho4+github@gmail.com>2016-04-12 16:28:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-12 17:32:00 -0700
commit8b34de66cf15b97ef5bb44ce6217c760c63aadc2 (patch)
tree7c6bc8adddb94c53e861c3f918d8a84acfc075cb
parentc052df404bc19e00f8b7592f120987e8925d16bb (diff)
Adds option to pass callable initializer function to Variable constructor to allow colocation of variable initialization with the device the variable is on, instead of always being on the chief supervisor.
Also updates variable_scope.get_variable() and create_partitioned_variables() to take advantage of this when an initializer fn is passed in. Change: 119697860
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py47
-rw-r--r--tensorflow/python/ops/partitioned_variables.py17
-rw-r--r--tensorflow/python/ops/variable_scope.py13
-rw-r--r--tensorflow/python/ops/variables.py83
4 files changed, 129 insertions, 31 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 593cd5f25a..f2325779c1 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -302,6 +302,53 @@ class VariablesTestCase(tf.test.TestCase):
self.assertEqual(var.op.device, init_op.device)
sess.run(init_op)
+ def testInitializerFunction(self):
+ value = [[-42], [133.7]]
+ shape = [2, 1]
+ with self.test_session():
+ initializer = lambda: tf.constant(value)
+ with self.assertRaises(ValueError):
+ # Checks that dtype must be specified.
+ tf.Variable(initializer)
+
+ v1 = tf.Variable(initializer, dtype=tf.float32)
+ self.assertEqual(shape, v1.get_shape())
+ self.assertAllClose(value, v1.initial_value.eval())
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ v1.eval()
+
+ v2 = tf.Variable(tf.neg(v1.initialized_value()), dtype=tf.float32)
+ self.assertEqual(v1.get_shape(), v2.get_shape())
+ self.assertAllClose(np.negative(value), v2.initial_value.eval())
+
+ # Once v2.initial_value.eval() has been called, v1 has effectively been
+ # initialized.
+ self.assertAllClose(value, v1.eval())
+
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ v2.eval()
+ tf.initialize_all_variables().run()
+ self.assertAllClose(np.negative(value), v2.eval())
+
+ def testInitializerFunctionDevicePlacement(self):
+ with self.test_session():
+ initializer = lambda: tf.constant(42.0)
+ with tf.device("/cpu:100"):
+ v1 = tf.Variable(initializer, dtype=tf.float32, name="v1")
+ expected_device = "/device:CPU:100"
+ expected_group_v1 = [b"loc:@v1"]
+ self.assertEqual(expected_device, v1.op.device)
+ self.assertEqual(expected_group_v1, v1.op.colocation_groups())
+ for i in v1.initializer.inputs:
+ self.assertEqual(expected_device, i.op.device)
+ self.assertEqual(expected_group_v1, i.op.colocation_groups())
+
+ v2 = tf.Variable(initializer, dtype=tf.float32, name="v2")
+ expected_group_v2 = [b"loc:@v2"]
+ self.assertEqual(expected_group_v2, v2.op.colocation_groups())
+ for i in v2.initializer.inputs:
+ self.assertEqual(expected_group_v2, i.op.colocation_groups())
+
class IsInitializedTest(tf.test.TestCase):
diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py
index 9d4d19668a..c16ba0f814 100644
--- a/tensorflow/python/ops/partitioned_variables.py
+++ b/tensorflow/python/ops/partitioned_variables.py
@@ -167,19 +167,22 @@ def create_partitioned_variables(
slice_offset[slice_dim] += var_shape[slice_dim]
if callable(initializer):
- init_val = initializer(var_shape, dtype=dtype)
- init_val = ops.convert_to_tensor(init_val, dtype=dtype)
+ init = initializer
+ init_shape = var_shape
elif isinstance(initializer, ops.Tensor):
- init_val = array_ops.slice(initializer, var_offset, var_shape)
+ init = array_ops.slice(initializer, var_offset, var_shape)
# Use the dtype of the given tensor.
- dtype = init_val.dtype.base_dtype
+ dtype = init.dtype.base_dtype
+ init_shape = None
else:
- init_val = ops.convert_to_tensor(initializer, dtype=dtype)
- init_val = array_ops.slice(init_val, var_offset, var_shape)
+ init = ops.convert_to_tensor(initializer, dtype=dtype)
+ init = array_ops.slice(init, var_offset, var_shape)
+ init_shape = None
var = variable_scope.get_variable(name="part_%d" % i,
+ shape=init_shape,
dtype=dtype,
- initializer=init_val,
+ initializer=init,
trainable=trainable,
collections=collections)
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 64ad23674b..f60816e022 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -144,14 +144,19 @@ class _VariableStore(object):
with ops.control_dependencies(None):
if initializing_from_value:
init_val = initializer
+ variable_dtype = None
else:
- with ops.name_scope(name + "/Initializer/"):
- init_val = initializer(shape.as_list(), dtype=dtype)
+ init_val = lambda: initializer(shape.as_list(), dtype=dtype)
+ variable_dtype = dtype.base_dtype
# Create the variable.
- v = variables.Variable(init_val, name=name, trainable=trainable,
+ v = variables.Variable(initial_value=init_val,
+ name=name,
+ trainable=trainable,
collections=collections,
- caching_device=caching_device)
+ caching_device=caching_device,
+ dtype=variable_dtype)
+
self._vars[name] = v
logging.info("Created variable %s with shape %s and init %s", v.name,
format(shape), initializer)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index a21724194b..3f2571bc06 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -156,9 +156,12 @@ class Variable(object):
variable to its initial value.
Args:
- initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
- The initial value for the Variable. Must have a shape specified unless
- `validate_shape` is set to False.
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
@@ -211,9 +214,12 @@ class Variable(object):
"""Creates a new variable from arguments.
Args:
- initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
- The initial value for the Variable. Must have a shape specified unless
- `validate_shape` is set to False.
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
@@ -240,25 +246,62 @@ class Variable(object):
"""
if initial_value is None:
raise ValueError("initial_value must be specified.")
+ init_from_fn = callable(initial_value)
+ if init_from_fn and dtype is None:
+ raise ValueError(
+ "dtype must also be specified when initial_value is callable.")
+
if collections is None:
collections = [ops.GraphKeys.VARIABLES]
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
with ops.control_dependencies(None):
- with ops.op_scope([initial_value], name, "Variable") as name:
- self._initial_value = ops.convert_to_tensor(initial_value,
- name="initial_value",
- dtype=dtype)
- initial_value_shape = self._initial_value.get_shape()
- if validate_shape and not initial_value_shape.is_fully_defined():
- raise ValueError("initial_value must have a shape specified: %s"
- % self._initial_value)
- shape_to_set = initial_value_shape if validate_shape else []
-
- self._variable = state_ops.variable_op(
- shape_to_set, self._initial_value.dtype.base_dtype,
- set_shape=validate_shape, name=name)
-
+ with ops.op_scope(
+ [] if init_from_fn else [initial_value], name, "Variable") as name:
+
+ # Get the initial value from a callable function. The real shape of the
+ # variable will be set later, since under the init_from_fn case, the
+ # shape won't be known until after the function is invoked.
+ if init_from_fn:
+ self._variable = state_ops.variable_op(
+ [],
+ dtype.base_dtype,
+ set_shape=False,
+ name=name)
+ with ops.colocate_with(self._variable.op):
+ with ops.name_scope("Initializer"):
+ # Colocate the tensors created by the initial_value() function
+ # with the variable itself.
+ self._initial_value = ops.convert_to_tensor(initial_value(),
+ name="initial_value",
+ dtype=dtype)
+
+ # Or get the initial value from a Tensor or Python object.
+ else:
+ self._initial_value = ops.convert_to_tensor(initial_value,
+ name="initial_value",
+ dtype=dtype)
+ # In this case, the variable op can't be created until after the
+ # initial_value has been converted to a Tensor with a known type.
+ self._variable = state_ops.variable_op(
+ [],
+ self._initial_value.dtype.base_dtype,
+ set_shape=False,
+ name=name)
+
+ # Manually overrides the variable's shape with the initial value's.
+ if validate_shape:
+ initial_value_shape = self._initial_value.get_shape()
+ if not initial_value_shape.is_fully_defined():
+ raise ValueError("initial_value must have a shape specified: %s"
+ % self._initial_value)
+ self._variable.set_shape(initial_value_shape)
+ # TODO(b/28152992): Remove the below hack modifying the node_def shape
+ # directly once set_shape() handles it.
+ self._variable.op.node_def.attr["shape"].shape.CopyFrom(
+ initial_value_shape.as_proto())
+
+ # Assigns initial value.
with ops.colocate_with(self._variable.op):
self._initializer_op = state_ops.assign(
self._variable, self._initial_value,