Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py                | 23
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py  | 58
2 files changed, 72 insertions(+), 9 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index a32424b316..0f82508428 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -293,7 +293,8 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
       collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
       l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
       for v in index.values():
-        l.remove(v)
+        if v in l:
+          l.remove(v)
     g.add_to_collections(collections, result)
   elif ops.GraphKeys.GLOBAL_STEP in collections:
     ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
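Editor's note: the guard above matters because list.remove raises ValueError when the element is absent, e.g. if a per-device variable was never added to (or was already removed from) the TRAINABLE_VARIABLES collection. A minimal plain-Python sketch of the pattern, with hypothetical string values standing in for the variables:

  # Guarded removal is idempotent: skipping an absent element is a
  # no-op instead of a ValueError.
  trainable = ["var_a"]
  for v in ["var_a", "var_b"]:
    if v in trainable:
      trainable.remove(v)
  assert trainable == []  # "var_b" was skipped rather than raising
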
@@ -461,16 +462,20 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
             # name as the absolute name of the variable.
             kwargs["name"] = "%s/replica_%d/" % (var0name, i)
             # Initialize replicas with the same value:
-            if context.executing_eagerly():
-              kwargs["initial_value"] = array_ops.identity(
-                  index[devices[0]].value())
-            else:
-              def initial_value_fn(device=d):
+            def initial_value_fn(device=d):
+              if context.executing_eagerly():
+                init_value = index[devices[0]].value()
+                return array_ops.identity(init_value)
+              else:
                 with ops.device(device):
-                  return array_ops.identity(index[devices[0]].initial_value)
-              kwargs["initial_value"] = initial_value_fn
+                  init_value = index[devices[0]].initial_value
+                  return array_ops.identity(init_value)
+            kwargs["initial_value"] = initial_value_fn
           with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
-            v = next_creator(*args, **kwargs)
+            # Don't record operations (e.g. other variable reads) during
+            # variable creation.
+            with tape.stop_recording():
+              v = next_creator(*args, **kwargs)
           assert not isinstance(v, values.DistributedVariable)
           index[d] = v
       return index
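
Editor's note: the hunk above replaces an eager/graph branch taken at variable-creation time with a single closure, deferring the decision until the initial value is actually needed, and wraps next_creator in tape.stop_recording so reads of the primary replica's variable are not recorded on the gradient tape. A minimal sketch of the deferred-initializer pattern, assuming TF 1.x-era public APIs and hypothetical names (make_initial_value_fn, primary_var):

  import tensorflow as tf

  def make_initial_value_fn(primary_var, device):
    """Returns a closure yielding the primary replica's initial value."""
    def initial_value_fn():
      if tf.executing_eagerly():
        # Eager: copy the already-initialized value of the primary variable.
        return tf.identity(primary_var.value())
      # Graph: re-run the primary variable's initializer on the target device.
      with tf.device(device):
        return tf.identity(primary_var.initial_value)
    return initial_value_fn
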
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index eeac528329..ed36639ce8 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import sys
 
+import numpy as np
+
 from tensorflow.contrib.distribute.python import mirrored_strategy
 from tensorflow.contrib.distribute.python import multi_worker_test_base
 from tensorflow.contrib.distribute.python import strategy_test_lib
@@ -34,7 +36,10 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
+from tensorflow.python.keras.layers import core as keras_core
 from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell_impl
@@ -43,6 +48,8 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import optimizer as optimizer_lib
 from tensorflow.python.training import server_lib
 
 
@@ -1245,6 +1252,22 @@ class MockModel(object):
     return x
 
 
+class MiniModel(keras_training.Model):
+  """Minimal model for mnist.
+
+  Useful for testing and debugging on slow TPU simulators.
+  """
+
+  def __init__(self):
+    super(MiniModel, self).__init__(name="")
+    self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones",
+                               bias_initializer="ones")
+
+  def call(self, inputs, training=True):
+    inputs = array_ops.ones([1, 10])
+    return self.fc(inputs)
+
+
 class MirroredStrategyDefunTest(test.TestCase):
 
   def _skip_eager_if_gpus_less_than(self, num_gpus):
@@ -1365,6 +1388,41 @@ class MirroredStrategyDefunTest(test.TestCase):
                                         "GPU:0": 3.0 * 1.25})
     self._call_and_check(fn1, [factors], expected_result, [fn1])
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testTrain(self):
+    self._skip_eager_if_gpus_less_than(1)
+
+    cpu_dev = device_util.canonicalize("CPU:0")
+    gpu_dev = device_util.canonicalize("GPU:0")
+    devices = [cpu_dev, gpu_dev]
+    dist = mirrored_strategy.MirroredStrategy(devices)
+
+    with dist.scope():
+      mock_model = MiniModel()
+      mock_model.call = function.defun(mock_model.call)
+
+      def loss_fn(ctx):
+        del ctx
+        return mock_model(array_ops.ones([1, 10]))
+
+      gradients_fn = backprop.implicit_grad(loss_fn)
+      gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
+      grads_and_vars = dist.call_for_each_tower(
+          gradients_fn, None, run_concurrently=False)
+
+      optimizer = gradient_descent.GradientDescentOptimizer(0.25)
+      update_ops = optimizer._distributed_apply(dist, grads_and_vars)  # pylint: disable=protected-access
+
+      if not context.executing_eagerly():
+        self.evaluate(variables.global_variables_initializer())
+        self.evaluate(update_ops)
+
+      updated_var_values = self.evaluate(mock_model.variables)
+      # All variables start at 1.0 and get two updates of 0.25.
+      self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
+      self.assertAllEqual([0.5], updated_var_values[1])
+
+
 class MultiWorkerMirroredStrategyTest(
     multi_worker_test_base.MultiWorkerTestBase,
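
Editor's note: a quick check of the expected numbers in testTrain. The loss is MiniModel applied to a [1, 10] tensor of ones, so the gradient with respect to every kernel entry and the bias is 1; with GradientDescentOptimizer(0.25) and, per the test's own comment, two updates of 0.25 (one contributed per mirrored device), each variable moves from 1.0 to 0.5. A numpy re-derivation as a hypothetical standalone script, not part of the patch:

  import numpy as np

  kernel = np.ones((10, 1))  # "fc" kernel, initialized to ones
  bias = np.ones((1,))       # "fc" bias, initialized to ones
  lr = 0.25

  # loss = ones([1, 10]) @ kernel + bias, so d(loss)/d(kernel) and
  # d(loss)/d(bias) are all ones; apply one step per mirrored device.
  for _ in range(2):
    kernel -= lr * np.ones_like(kernel)
    bias -= lr * np.ones_like(bias)

  assert np.allclose(kernel, 0.5) and np.allclose(bias, 0.5)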