about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-14 09:28:17 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-14 09:31:09 -0700
commit: 4ec3fcdc87687d33c1597aff9296041a6bb00434 (patch)
tree: 6ea9f6cabbe0a06f6ac13ac0dfe6b14f3ed32037 /tensorflow/contrib/framework
parent: b704ab9e65a3e44568e91eeded277fdd1b072508 (diff)
Adds support for explicitly assigning a replica to the VariableDeviceChooser. This is necessary when a device that includes a replica is set in a surrounding arg_scope.
PiperOrigin-RevId: 200567897
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r-- tensorflow/contrib/framework/python/ops/variables.py | 10
-rw-r--r-- tensorflow/contrib/framework/python/ops/variables_test.py | 120
2 files changed, 83 insertions, 47 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index 40ae01bfcc..e8e3180019 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -712,7 +712,8 @@ class VariableDeviceChooser(object):
num_tasks=0,
job_name='ps',
device_type='CPU',
- device_index=0):
+ device_index=0,
+ replica=None):
"""Initialize VariableDeviceChooser.
Usage:
@@ -733,12 +734,15 @@ class VariableDeviceChooser(object):
self._job_name = job_name
self._device_type = device_type
self._device_index = device_index
+ self._replica = replica
self._num_tasks = num_tasks
self._next_task_id = 0
def __call__(self, op):
- device_spec = tf_device.DeviceSpec(device_type=self._device_type,
- device_index=self._device_index)
+ device_spec = tf_device.DeviceSpec(
+ replica=self._replica,
+ device_type=self._device_type,
+ device_index=self._device_index)
if self._num_tasks > 0:
task_id = self._next_task_id
self._next_task_id = (self._next_task_id + 1) % self._num_tasks
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index 37ea6eb12a..7e0c7dbec1 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -506,6 +506,35 @@ class VariablesTest(test.TestCase):
self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0')
self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+ def testVariableWithVariableDeviceChooserWithReplica(self):
+
+ with ops.Graph().as_default():
+ device_fn = variables_lib2.VariableDeviceChooser(replica=3, num_tasks=2)
+ with arg_scope([variables_lib2.variable], device=device_fn):
+ a = variables_lib2.variable('a', [])
+ b = variables_lib2.variable('b', [])
+ c = variables_lib2.variable('c', [], device='cpu:12')
+ d = variables_lib2.variable('d', [])
+ with ops.device('cpu:99'):
+ e_init = constant_op.constant(12)
+ e = variables_lib2.variable('e', initializer=e_init)
+ # The values below highlight how the VariableDeviceChooser puts initial
+ # values on the same device as the variable job.
+ self.assertDeviceEqual(a.device, '/job:ps/replica:3/task:0/cpu:0')
+ self.assertEqual(a.initial_value.op.colocation_groups(),
+ a.op.colocation_groups())
+ self.assertDeviceEqual(b.device, '/job:ps/replica:3/task:1/cpu:0')
+ self.assertEqual(b.initial_value.op.colocation_groups(),
+ b.op.colocation_groups())
+ self.assertDeviceEqual(c.device, '/cpu:12')
+ self.assertEqual(c.initial_value.op.colocation_groups(),
+ c.op.colocation_groups())
+ self.assertDeviceEqual(d.device, '/job:ps/replica:3/task:0/cpu:0')
+ self.assertEqual(d.initial_value.op.colocation_groups(),
+ d.op.colocation_groups())
+ self.assertDeviceEqual(e.device, '/job:ps/replica:3/task:1/cpu:0')
+ self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+
def testVariableGPUPlacement(self):
with ops.Graph().as_default():
@@ -930,8 +959,8 @@ class AssignFromCheckpointTest(test.TestCase):
return saver.save(sess, checkpoint_dir, global_step=global_step)
def testLoadExistingVariables(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'load_existing_variables'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables'))
init_value0 = 10.0
init_value1 = 20.0
@@ -944,8 +973,8 @@ class AssignFromCheckpointTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1}
- op, feed_dict = variables_lib2.assign_from_checkpoint(model_path,
- vars_to_restore)
+ op, feed_dict = variables_lib2.assign_from_checkpoint(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -960,8 +989,8 @@ class AssignFromCheckpointTest(test.TestCase):
# Tests restoring PartitionedVariables and tests using a dictionary
# of lists as the assign_from_checkpoint() var_list param.
def testLoadPartitionedVariables(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'load_partitioned_variables'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'load_partitioned_variables'))
init_value0 = np.array([[10.0, 11.0], [12.0, 13.0]])
init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case.
@@ -974,15 +1003,14 @@ class AssignFromCheckpointTest(test.TestCase):
partitioner = partitioned_variables.variable_axis_size_partitioner(2)
var0 = variables_lib2.variable(
'var0', shape=init_value0.shape, partitioner=partitioner)
- var0full = variables_lib2.variable(
- 'var0full', shape=init_value0.shape)
+ var0full = variables_lib2.variable('var0full', shape=init_value0.shape)
var1 = variables_lib2.variable(
'var1', shape=init_value1.shape, partitioner=partitioner)
# Convert var0 and var1 into a list of underlying variables.
vars_to_restore = {'var0': list(var0) + [var0full], 'var1': list(var1)}
- op, feed_dict = variables_lib2.assign_from_checkpoint(model_path,
- vars_to_restore)
+ op, feed_dict = variables_lib2.assign_from_checkpoint(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -992,16 +1020,18 @@ class AssignFromCheckpointTest(test.TestCase):
# Request and test the variable values. PartitionedVariables can't
# be evaled so we wrap them in an identity.
- self.assertTrue(np.array_equal(
- init_value0, array_ops.identity(var0).eval()))
- self.assertTrue(np.array_equal(
- init_value0, var0full.eval()))
- self.assertTrue(np.array_equal(
- init_value1, array_ops.identity(var1).eval()))
+ self.assertTrue(
+ np.array_equal(init_value0,
+ array_ops.identity(var0).eval()))
+ self.assertTrue(np.array_equal(init_value0, var0full.eval()))
+ self.assertTrue(
+ np.array_equal(init_value1,
+ array_ops.identity(var1).eval()))
def testRaisesValueErrorIfAVariableIsntFound(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'raises_value_error_if_var_isnt_found'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(),
+ 'raises_value_error_if_var_isnt_found'))
init_value0 = 10.0
init_value1 = 20.0
@@ -1019,8 +1049,9 @@ class AssignFromCheckpointTest(test.TestCase):
variables_lib2.assign_from_checkpoint(model_path, vars_to_restore)
def testInitFromCheckpointWithScopes(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'init_from_checkpoint_with_scopes'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(),
+ 'init_from_checkpoint_with_scopes'))
init_value0 = np.asarray(
[1.0, 3.0, 9.0], dtype=np.float32).reshape((1, 3, 1))
@@ -1038,8 +1069,8 @@ class AssignFromCheckpointTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=init_value1.shape)
vars_to_restore = {'layer0/v0': var0, 'layer1/v1': var1}
- op, feed_dict = variables_lib2.assign_from_checkpoint(model_path,
- vars_to_restore)
+ op, feed_dict = variables_lib2.assign_from_checkpoint(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1081,8 +1112,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
return saver.save(sess, checkpoint_dir, global_step=global_step)
def testLoadExistingVariables(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'load_existing_variables'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1097,8 +1128,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1}
- init_fn = variables_lib2.assign_from_checkpoint_fn(model_path,
- vars_to_restore)
+ init_fn = variables_lib2.assign_from_checkpoint_fn(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1111,8 +1142,9 @@ class AssignFromCheckpointFnTest(test.TestCase):
self.assertEqual(init_value1, var1.eval())
def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'load_existing_vars_no_reshape'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(),
+ 'load_existing_vars_no_reshape'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1127,8 +1159,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1}
- init_fn = variables_lib2.assign_from_checkpoint_fn(model_path,
- vars_to_restore)
+ init_fn = variables_lib2.assign_from_checkpoint_fn(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1138,9 +1170,10 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_fn(sess)
def testLoadExistingVariablesDifferentShapeAllowReshape(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(),
- 'load_existing_variables_different_shape_allow_reshape'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(
+ self.get_temp_dir(),
+ 'load_existing_variables_different_shape_allow_reshape'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1169,8 +1202,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
self.assertEqual(init_value1, var1.eval())
def testNotFoundError(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'not_found_error'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'not_found_error'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1186,8 +1219,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
var2 = variables_lib2.variable('my_var2', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2}
- init_fn = variables_lib2.assign_from_checkpoint_fn(model_path,
- vars_to_restore)
+ init_fn = variables_lib2.assign_from_checkpoint_fn(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1197,8 +1230,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_fn(sess)
def testMissingVariablesList(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'missing_variables_list'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'missing_variables_list'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1228,8 +1261,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
self.assertEqual(init_value1, var1.eval())
def testMissingVariablesDict(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'missing_variables_dict'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'missing_variables_dict'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1279,9 +1312,8 @@ class ZeroInitializerOpTest(test.TestCase):
def testZeroInitializer(self):
for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64):
for use_init in (False, True):
- self._testZeroInitializer(
- [10, 20], array_ops.ones(
- [10, 20], dtype=dtype), use_init)
+ self._testZeroInitializer([10, 20], array_ops.ones(
+ [10, 20], dtype=dtype), use_init)
class ZeroVarInitializerOpTest(test.TestCase):