author     Yuefeng Zhou <yuefengz@google.com>  2018-09-22 09:11:05 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-22 09:15:41 -0700
commit     e692dda4c8b199555e2fa32132a7784e0893c870 (patch)
tree       72326d46432ab785b5ab4e978f492b69e3ec59e8 /tensorflow/contrib/distribute
parent     ca552d54ac67be8837aeabdb43269846d9df4eb5 (diff)
Fixed a bug in CollectiveAllReduce where the variable names it sees are sometimes incomplete and therefore not unique, leading to the same collective keys being used for different variables.
PiperOrigin-RevId: 214117466
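
To make the failure mode concrete, the sketch below is a hypothetical stand-in for the CollectiveKeys bookkeeping (illustration only, not the real implementation): when the variable creator only sees a truncated, non-unique name, two distinct variables map to the same instance key and their collective all-reduces exchange the wrong tensors, whereas graph-unique names keep the keys distinct.

# Hypothetical stand-in for CollectiveKeys.get_instance_key (illustration only).
class ToyCollectiveKeys(object):

  def __init__(self):
    self._instance_keys = {}

  def get_instance_key(self, key_id):
    # One key per distinct key_id; duplicate key_ids silently collide.
    if key_id not in self._instance_keys:
      self._instance_keys[key_id] = 1000 + len(self._instance_keys)
    return self._instance_keys[key_id]


keys = ToyCollectiveKeys()
# Two different variables whose creator only sees the bare name "kernel" collide:
assert keys.get_instance_key("kernel") == keys.get_instance_key("kernel")
# Graph-unique names ("kernel" vs. "kernel_1") keep the keys distinct:
assert keys.get_instance_key("kernel_1") != keys.get_instance_key("kernel")
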
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--  tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py        8
-rw-r--r--  tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py  78
2 files changed, 85 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 77079d0df9..297cacf192 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -143,8 +143,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
    def _real_mirrored_creator(devices, *args, **kwargs):
      """Creates one MirroredVariable on the current worker."""
      index = {}
+     unique_var_name = ops.get_default_graph().unique_name(
+         kwargs["name"], mark_as_used=False).rstrip("/")
      collective_instance_key = self._collective_keys.get_instance_key(
-         key_id=kwargs["name"])
+         key_id=unique_var_name)
      if "initial_value" not in kwargs:
        raise ValueError("Initial value must be specified.")
      initial_value = kwargs["initial_value"]
@@ -188,6 +190,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
          with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(*args, **kwargs)
+         if i == 0:
+           actual_var_name = v.name.split(":")[0]
+           assert unique_var_name == actual_var_name, "%r vs %r" % (
+               unique_var_name, actual_var_name)
          assert not isinstance(v, values.DistributedVariable)
          index[d] = v
      return index
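
The fix keys each collective instance on a graph-unique variable name obtained from Graph.unique_name with mark_as_used=False, which returns the deduplicated name the graph would assign without reserving it. A minimal sketch of that behavior, assuming a TF 1.x graph-mode build:

import tensorflow as tf  # assumes a TF 1.x graph-mode build

g = tf.Graph()
with g.as_default():
  # Creating an op named "v" marks that name as used in the graph.
  v = tf.Variable(0.0, name="v")

  # Ask which name the graph *would* assign next, without consuming it.
  print(g.unique_name("v", mark_as_used=False))  # -> "v_1"
  print(g.unique_name("v", mark_as_used=False))  # -> still "v_1", not reserved

  # A second variable requesting the same name actually receives "v_1".
  w = tf.Variable(0.0, name="v")
  print(w.name)  # -> "v_1:0"

The .rstrip("/") in the patch presumably drops a trailing slash that unique_name can return for scope-like names, so the key is a plain variable op name.
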
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index 36e9761073..33ffbf6abe 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -34,9 +35,14 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+from tensorflow.python.training import training_util
class CollectiveAllReduceStrategyTestBase(
@@ -146,6 +152,56 @@ class CollectiveAllReduceStrategyTestBase(
      self.assertLess(error_after, error_before)
      return error_after < error_before
+ def _test_complex_model(self, task_type, task_id, num_gpus):
+   d, master_target = self._get_test_object(task_type, task_id, num_gpus)
+
+   def model_fn():
+     """Mnist model with synthetic input."""
+     data_format = 'channels_last'
+     input_shape = [28, 28, 1]
+     l = keras.layers
+     max_pool = l.MaxPooling2D((2, 2), (2, 2),
+                               padding='same',
+                               data_format=data_format)
+     model = keras.Sequential([
+         l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
+         l.Conv2D(
+             32,
+             5,
+             padding='same',
+             data_format=data_format,
+             activation=nn.relu), max_pool,
+         l.Conv2D(
+             64,
+             5,
+             padding='same',
+             data_format=data_format,
+             activation=nn.relu), max_pool,
+         l.Flatten(),
+         l.Dense(1024, activation=nn.relu),
+         l.Dropout(0.4),
+         l.Dense(10)
+     ])
+     image = random_ops.random_uniform([2, 28, 28])
+     label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32)
+     logits = model(image, training=True)
+     loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
+     optimizer = adam.AdamOptimizer(learning_rate=1e-4)
+     train_op = optimizer.minimize(loss,
+                                   training_util.get_or_create_global_step())
+     return train_op
+
+   with ops.Graph().as_default(), \
+       self.test_session(config=self._sess_config,
+                         target=master_target) as sess:
+     with d.scope():
+       train_op = d.call_for_each_tower(model_fn)
+       train_op = d.group(d.unwrap(train_op))
+
+     sess.run(variables.global_variables_initializer())
+     sess.run(train_op)
+   return True
+
  def _test_variable_initialization(self, task_type, task_id, num_gpus):
    distribution, master_target = self._get_test_object(task_type, task_id,
                                                        num_gpus)
@@ -206,6 +262,14 @@ class DistributedCollectiveAllReduceStrategyTest(
        self._cluster_spec,
        num_gpus=num_gpus)
+ @combinations.generate(
+     combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testComplexModel(self, num_gpus):
+   if context.num_gpus() < num_gpus:
+     return
+   self._run_between_graph_clients(
+       self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
+
class DistributedCollectiveAllReduceStrategyTestWithChief(
    CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
@@ -236,6 +300,14 @@ class DistributedCollectiveAllReduceStrategyTestWithChief(
        self._cluster_spec,
        num_gpus=num_gpus)
+ @combinations.generate(
+     combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testComplexModel(self, num_gpus):
+   if context.num_gpus() < num_gpus:
+     return
+   self._run_between_graph_clients(
+       self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
+
class LocalCollectiveAllReduceStrategy(
    CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
@@ -246,6 +318,12 @@ class LocalCollectiveAllReduceStrategy(
      return
    self._test_minimize_loss_graph(None, None, num_gpus)
+ def testComplexModel(self, num_gpus=2):
+   # Collective ops doesn't support strategy with one device.
+   if context.num_gpus() < num_gpus:
+     return
+   self._test_complex_model(None, None, num_gpus)
+
if __name__ == '__main__':
  test.main()