diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-09-22 09:11:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-22 09:15:41 -0700 |
commit | e692dda4c8b199555e2fa32132a7784e0893c870 (patch) | |
tree | 72326d46432ab785b5ab4e978f492b69e3ec59e8 /tensorflow/contrib/distribute | |
parent | ca552d54ac67be8837aeabdb43269846d9df4eb5 (diff) |
Fixed a bug in CollectiveAllReduce where the variable names it sees are sometimes incomplete and therefore not unique, leading to the same collective keys being used for different variables.
PiperOrigin-RevId: 214117466
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r-- | tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py | 8 | ||||
-rw-r--r-- | tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py | 78 |
2 files changed, 85 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 77079d0df9..297cacf192 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -143,8 +143,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): def _real_mirrored_creator(devices, *args, **kwargs): """Creates one MirroredVariable on the current worker.""" index = {} + unique_var_name = ops.get_default_graph().unique_name( + kwargs["name"], mark_as_used=False).rstrip("/") collective_instance_key = self._collective_keys.get_instance_key( - key_id=kwargs["name"]) + key_id=unique_var_name) if "initial_value" not in kwargs: raise ValueError("Initial value must be specified.") initial_value = kwargs["initial_value"] @@ -188,6 +190,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) + if i == 0: + actual_var_name = v.name.split(":")[0] + assert unique_var_name == actual_var_name, "%r vs %r" % ( + unique_var_name, actual_var_name) assert not isinstance(v, values.DistributedVariable) index[d] = v return index diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 36e9761073..33ffbf6abe 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.core.protobuf import config_pb2 +from tensorflow.python 
import keras from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,9 +35,14 @@ from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test +from tensorflow.python.training import adam +from tensorflow.python.training import training_util class CollectiveAllReduceStrategyTestBase( @@ -146,6 +152,56 @@ class CollectiveAllReduceStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before + def _test_complex_model(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_object(task_type, task_id, num_gpus) + + def model_fn(): + """Mnist model with synthetic input.""" + data_format = 'channels_last' + input_shape = [28, 28, 1] + l = keras.layers + max_pool = l.MaxPooling2D((2, 2), (2, 2), + padding='same', + data_format=data_format) + model = keras.Sequential([ + l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)), + l.Conv2D( + 32, + 5, + padding='same', + data_format=data_format, + activation=nn.relu), max_pool, + l.Conv2D( + 64, + 5, + padding='same', + data_format=data_format, + activation=nn.relu), max_pool, + l.Flatten(), + l.Dense(1024, activation=nn.relu), + l.Dropout(0.4), + l.Dense(10) + ]) + image = random_ops.random_uniform([2, 28, 28]) + label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32) + logits = model(image, training=True) + loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits) + optimizer = adam.AdamOptimizer(learning_rate=1e-4) + train_op = optimizer.minimize(loss, + 
training_util.get_or_create_global_step()) + return train_op + + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess: + with d.scope(): + train_op = d.call_for_each_tower(model_fn) + train_op = d.group(d.unwrap(train_op)) + + sess.run(variables.global_variables_initializer()) + sess.run(train_op) + return True + def _test_variable_initialization(self, task_type, task_id, num_gpus): distribution, master_target = self._get_test_object(task_type, task_id, num_gpus) @@ -206,6 +262,14 @@ class DistributedCollectiveAllReduceStrategyTest( self._cluster_spec, num_gpus=num_gpus) + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testComplexModel(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + class DistributedCollectiveAllReduceStrategyTestWithChief( CollectiveAllReduceStrategyTestBase, parameterized.TestCase): @@ -236,6 +300,14 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( self._cluster_spec, num_gpus=num_gpus) + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testComplexModel(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + class LocalCollectiveAllReduceStrategy( CollectiveAllReduceStrategyTestBase, parameterized.TestCase): @@ -246,6 +318,12 @@ class LocalCollectiveAllReduceStrategy( return self._test_minimize_loss_graph(None, None, num_gpus) + def testComplexModel(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. + if context.num_gpus() < num_gpus: + return + self._test_complex_model(None, None, num_gpus) + if __name__ == '__main__': test.main() |