path: root/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
author    Igor Ganichev <iga@google.com> 2018-10-09 15:07:47 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-09 15:12:12 -0700
commit    5f69248a692f7b47ea11930621f4f19d0397fe8c (patch)
tree      e6ae69c17d798afc96ba83644bf2ce6656181856 /tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
parent    c1093a3757224257fed0f7a1959d0fc99d5c757f (diff)
Make defun work under distributed strategies.
The core of the change is to have the gradient tape capture distributed variables instead of plain ResourceVariables. In other words, we move the distribution awareness from defun down to the tape and rely on the distributed-variable machinery to provide the right variable at runtime. In tower context, we always watch the container (e.g. MirroredVariable). In cross-tower context, we always watch all the components.

PiperOrigin-RevId: 216430530
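For orientation, a minimal sketch of the usage pattern this commit enables, assembled from the testTrain case added in the diff below. The device list is an assumption for illustration, MiniModel refers to the helper class defined in the diff, and the _distributed_apply call is a private API used the same way as in the test; treat this as an illustrative sketch, not documentation of a stable API.

    from tensorflow.contrib.distribute.python import mirrored_strategy
    from tensorflow.python.eager import backprop
    from tensorflow.python.eager import function
    from tensorflow.python.ops import array_ops
    from tensorflow.python.training import gradient_descent
    from tensorflow.python.training import optimizer as optimizer_lib

    # Mirror across two devices; the device names here are placeholders.
    dist = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])

    with dist.scope():
      model = MiniModel()                      # Keras model defined in the diff below.
      model.call = function.defun(model.call)  # defun-wrapped call under the strategy.

      def loss_fn(ctx):
        del ctx  # The aggregation context is unused here.
        return model(array_ops.ones([1, 10]))

      # The tape inside implicit_grad now captures the MirroredVariables
      # themselves (the containers), not their per-device components.
      gradients_fn = backprop.implicit_grad(loss_fn)
      gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
      grads_and_vars = dist.call_for_each_tower(
          gradients_fn, None, run_concurrently=False)

      optimizer = gradient_descent.GradientDescentOptimizer(0.25)
      # Private API, invoked the same way as in the test below.
      update_ops = optimizer._distributed_apply(dist, grads_and_vars)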
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py  58
1 file changed, 58 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index eeac528329..ed36639ce8 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import sys
+import numpy as np
+
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
@@ -34,7 +36,10 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
+from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
@@ -43,6 +48,8 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import server_lib
@@ -1245,6 +1252,22 @@ class MockModel(object):
return x
+class MiniModel(keras_training.Model):
+ """Minimal model for mnist.
+
+ Useful for testing and debugging on slow TPU simulators.
+ """
+
+ def __init__(self):
+ super(MiniModel, self).__init__(name="")
+ self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones",
+ bias_initializer="ones")
+
+ def call(self, inputs, training=True):
+ inputs = array_ops.ones([1, 10])
+ return self.fc(inputs)
+
+
class MirroredStrategyDefunTest(test.TestCase):
def _skip_eager_if_gpus_less_than(self, num_gpus):
@@ -1365,6 +1388,41 @@ class MirroredStrategyDefunTest(test.TestCase):
"GPU:0": 3.0 * 1.25})
self._call_and_check(fn1, [factors], expected_result, [fn1])
+ @test_util.run_in_graph_and_eager_modes()
+ def testTrain(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ cpu_dev = device_util.canonicalize("CPU:0")
+ gpu_dev = device_util.canonicalize("GPU:0")
+ devices = [cpu_dev, gpu_dev]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+
+ with dist.scope():
+ mock_model = MiniModel()
+ mock_model.call = function.defun(mock_model.call)
+
+ def loss_fn(ctx):
+ del ctx
+ return mock_model(array_ops.ones([1, 10]))
+
+ gradients_fn = backprop.implicit_grad(loss_fn)
+ gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
+ grads_and_vars = dist.call_for_each_tower(
+ gradients_fn, None, run_concurrently=False)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.25)
+ update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(update_ops)
+
+ updated_var_values = self.evaluate(mock_model.variables)
+ # All variables start at 1.0 and get two updates of 0.25.
+ self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
+ self.assertAllEqual([0.5], updated_var_values[1])
+
+
class MultiWorkerMirroredStrategyTest(
multi_worker_test_base.MultiWorkerTestBase,