diff options
author | 2018-06-14 09:28:17 -0700 | |
---|---|---|
committer | 2018-06-14 09:31:09 -0700 | |
commit | 4ec3fcdc87687d33c1597aff9296041a6bb00434 (patch) | |
tree | 6ea9f6cabbe0a06f6ac13ac0dfe6b14f3ed32037 /tensorflow/contrib/framework | |
parent | b704ab9e65a3e44568e91eeded277fdd1b072508 (diff) |
Adds support for explicitly assigning the replica to the VariableDeviceChooser. This is necessary for when the device with replica is set in a surrounding arg_scope.
PiperOrigin-RevId: 200567897
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r-- | tensorflow/contrib/framework/python/ops/variables.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/framework/python/ops/variables_test.py | 120 |
2 files changed, 83 insertions, 47 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 40ae01bfcc..e8e3180019 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -712,7 +712,8 @@ class VariableDeviceChooser(object): num_tasks=0, job_name='ps', device_type='CPU', - device_index=0): + device_index=0, + replica=None): """Initialize VariableDeviceChooser. Usage: @@ -733,12 +734,15 @@ class VariableDeviceChooser(object): self._job_name = job_name self._device_type = device_type self._device_index = device_index + self._replica = replica self._num_tasks = num_tasks self._next_task_id = 0 def __call__(self, op): - device_spec = tf_device.DeviceSpec(device_type=self._device_type, - device_index=self._device_index) + device_spec = tf_device.DeviceSpec( + replica=self._replica, + device_type=self._device_type, + device_index=self._device_index) if self._num_tasks > 0: task_id = self._next_task_id self._next_task_id = (self._next_task_id + 1) % self._num_tasks diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 37ea6eb12a..7e0c7dbec1 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -506,6 +506,35 @@ class VariablesTest(test.TestCase): self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0') self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + def testVariableWithVariableDeviceChooserWithReplica(self): + + with ops.Graph().as_default(): + device_fn = variables_lib2.VariableDeviceChooser(replica=3, num_tasks=2) + with arg_scope([variables_lib2.variable], device=device_fn): + a = variables_lib2.variable('a', []) + b = variables_lib2.variable('b', []) + c = variables_lib2.variable('c', [], device='cpu:12') + d = variables_lib2.variable('d', []) + with ops.device('cpu:99'): + e_init = constant_op.constant(12) + e = variables_lib2.variable('e', initializer=e_init) + # The values below highlight how the VariableDeviceChooser puts initial + # values on the same device as the variable job. + self.assertDeviceEqual(a.device, '/job:ps/replica:3/task:0/cpu:0') + self.assertEqual(a.initial_value.op.colocation_groups(), + a.op.colocation_groups()) + self.assertDeviceEqual(b.device, '/job:ps/replica:3/task:1/cpu:0') + self.assertEqual(b.initial_value.op.colocation_groups(), + b.op.colocation_groups()) + self.assertDeviceEqual(c.device, '/cpu:12') + self.assertEqual(c.initial_value.op.colocation_groups(), + c.op.colocation_groups()) + self.assertDeviceEqual(d.device, '/job:ps/replica:3/task:0/cpu:0') + self.assertEqual(d.initial_value.op.colocation_groups(), + d.op.colocation_groups()) + self.assertDeviceEqual(e.device, '/job:ps/replica:3/task:1/cpu:0') + self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + def testVariableGPUPlacement(self): with ops.Graph().as_default(): @@ -930,8 +959,8 @@ class AssignFromCheckpointTest(test.TestCase): return saver.save(sess, checkpoint_dir, global_step=global_step) def testLoadExistingVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'load_existing_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables')) init_value0 = 10.0 init_value1 = 20.0 @@ -944,8 +973,8 @@ class AssignFromCheckpointTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -960,8 +989,8 @@ class AssignFromCheckpointTest(test.TestCase): # Tests restoring PartitionedVariables and tests using a dictionary # of lists as the assign_from_checkpoint() var_list param. def testLoadPartitionedVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'load_partitioned_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_partitioned_variables')) init_value0 = np.array([[10.0, 11.0], [12.0, 13.0]]) init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case. @@ -974,15 +1003,14 @@ class AssignFromCheckpointTest(test.TestCase): partitioner = partitioned_variables.variable_axis_size_partitioner(2) var0 = variables_lib2.variable( 'var0', shape=init_value0.shape, partitioner=partitioner) - var0full = variables_lib2.variable( - 'var0full', shape=init_value0.shape) + var0full = variables_lib2.variable('var0full', shape=init_value0.shape) var1 = variables_lib2.variable( 'var1', shape=init_value1.shape, partitioner=partitioner) # Convert var0 and var1 into a list of underlying variables. vars_to_restore = {'var0': list(var0) + [var0full], 'var1': list(var1)} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -992,16 +1020,18 @@ class AssignFromCheckpointTest(test.TestCase): # Request and test the variable values. PartitionedVariables can't # be evaled so we wrap them in an identity. - self.assertTrue(np.array_equal( - init_value0, array_ops.identity(var0).eval())) - self.assertTrue(np.array_equal( - init_value0, var0full.eval())) - self.assertTrue(np.array_equal( - init_value1, array_ops.identity(var1).eval())) + self.assertTrue( + np.array_equal(init_value0, + array_ops.identity(var0).eval())) + self.assertTrue(np.array_equal(init_value0, var0full.eval())) + self.assertTrue( + np.array_equal(init_value1, + array_ops.identity(var1).eval())) def testRaisesValueErrorIfAVariableIsntFound(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'raises_value_error_if_var_isnt_found')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'raises_value_error_if_var_isnt_found')) init_value0 = 10.0 init_value1 = 20.0 @@ -1019,8 +1049,9 @@ class AssignFromCheckpointTest(test.TestCase): variables_lib2.assign_from_checkpoint(model_path, vars_to_restore) def testInitFromCheckpointWithScopes(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'init_from_checkpoint_with_scopes')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'init_from_checkpoint_with_scopes')) init_value0 = np.asarray( [1.0, 3.0, 9.0], dtype=np.float32).reshape((1, 3, 1)) @@ -1038,8 +1069,8 @@ class AssignFromCheckpointTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=init_value1.shape) vars_to_restore = {'layer0/v0': var0, 'layer1/v1': var1} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1081,8 +1112,8 @@ class AssignFromCheckpointFnTest(test.TestCase): return saver.save(sess, checkpoint_dir, global_step=global_step) def testLoadExistingVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'load_existing_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1097,8 +1128,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1111,8 +1142,9 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'load_existing_vars_no_reshape')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'load_existing_vars_no_reshape')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1127,8 +1159,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1138,9 +1170,10 @@ class AssignFromCheckpointFnTest(test.TestCase): init_fn(sess) def testLoadExistingVariablesDifferentShapeAllowReshape(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), - 'load_existing_variables_different_shape_allow_reshape')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join( + self.get_temp_dir(), + 'load_existing_variables_different_shape_allow_reshape')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1169,8 +1202,8 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testNotFoundError(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'not_found_error')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'not_found_error')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1186,8 +1219,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var2 = variables_lib2.variable('my_var2', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1197,8 +1230,8 @@ class AssignFromCheckpointFnTest(test.TestCase): init_fn(sess) def testMissingVariablesList(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'missing_variables_list')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'missing_variables_list')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1228,8 +1261,8 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testMissingVariablesDict(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'missing_variables_dict')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'missing_variables_dict')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1279,9 +1312,8 @@ class ZeroInitializerOpTest(test.TestCase): def testZeroInitializer(self): for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64): for use_init in (False, True): - self._testZeroInitializer( - [10, 20], array_ops.ones( - [10, 20], dtype=dtype), use_init) + self._testZeroInitializer([10, 20], array_ops.ones( + [10, 20], dtype=dtype), use_init) class ZeroVarInitializerOpTest(test.TestCase): |