aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index eeac528329..ed36639ce8 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import sys
+import numpy as np
+
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
@@ -34,7 +36,10 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
+from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
@@ -43,6 +48,8 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import server_lib
@@ -1245,6 +1252,22 @@ class MockModel(object):
return x
+class MiniModel(keras_training.Model):
+ """Minimal model for mnist.
+
+ Useful for testing and debugging on slow TPU simulators.
+ """
+
+ def __init__(self):
+ super(MiniModel, self).__init__(name="")
+ self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones",
+ bias_initializer="ones")
+
+ def call(self, inputs, training=True):
+ inputs = array_ops.ones([1, 10])
+ return self.fc(inputs)
+
+
class MirroredStrategyDefunTest(test.TestCase):
def _skip_eager_if_gpus_less_than(self, num_gpus):
@@ -1365,6 +1388,41 @@ class MirroredStrategyDefunTest(test.TestCase):
"GPU:0": 3.0 * 1.25})
self._call_and_check(fn1, [factors], expected_result, [fn1])
+ @test_util.run_in_graph_and_eager_modes()
+ def testTrain(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ cpu_dev = device_util.canonicalize("CPU:0")
+ gpu_dev = device_util.canonicalize("GPU:0")
+ devices = [cpu_dev, gpu_dev]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+
+ with dist.scope():
+ mock_model = MiniModel()
+ mock_model.call = function.defun(mock_model.call)
+
+ def loss_fn(ctx):
+ del ctx
+ return mock_model(array_ops.ones([1, 10]))
+
+ gradients_fn = backprop.implicit_grad(loss_fn)
+ gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
+ grads_and_vars = dist.call_for_each_tower(
+ gradients_fn, None, run_concurrently=False)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.25)
+ update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(update_ops)
+
+ updated_var_values = self.evaluate(mock_model.variables)
+ # All variables start at 1.0 and get two updates of 0.25.
+ self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
+ self.assertAllEqual([0.5], updated_var_values[1])
+
+
class MultiWorkerMirroredStrategyTest(
multi_worker_test_base.MultiWorkerTestBase,