From 5f69248a692f7b47ea11930621f4f19d0397fe8c Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Tue, 9 Oct 2018 15:07:47 -0700 Subject: Make defun work under distributed strategies. The core of the change is have the gradient tape capture distributed variables instead of plain ResourceVariables. In other words, we move the distribution awareness from defun down to tape and rely on distributed variable magic to provide us with the right variable at runtime. In tower context, we always watch the container (e.g. MirroredVariable). In cross tower context, we always watch all the components. PiperOrigin-RevId: 216430530 --- .../python/mirrored_strategy_multigpu_test.py | 58 ++++++++++++++++++++++ 1 file changed, 58 insertions(+) (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py') diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index eeac528329..ed36639ce8 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import sys +import numpy as np + from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib @@ -34,7 +36,10 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training as keras_training +from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl @@ -43,6 +48,8 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import server_lib @@ -1245,6 +1252,22 @@ class MockModel(object): return x +class MiniModel(keras_training.Model): + """Minimal model for mnist. + + Useful for testing and debugging on slow TPU simulators. + """ + + def __init__(self): + super(MiniModel, self).__init__(name="") + self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones", + bias_initializer="ones") + + def call(self, inputs, training=True): + inputs = array_ops.ones([1, 10]) + return self.fc(inputs) + + class MirroredStrategyDefunTest(test.TestCase): def _skip_eager_if_gpus_less_than(self, num_gpus): @@ -1365,6 +1388,41 @@ class MirroredStrategyDefunTest(test.TestCase): "GPU:0": 3.0 * 1.25}) self._call_and_check(fn1, [factors], expected_result, [fn1]) + @test_util.run_in_graph_and_eager_modes() + def testTrain(self): + self._skip_eager_if_gpus_less_than(1) + + cpu_dev = device_util.canonicalize("CPU:0") + gpu_dev = device_util.canonicalize("GPU:0") + devices = [cpu_dev, gpu_dev] + dist = mirrored_strategy.MirroredStrategy(devices) + + with dist.scope(): + mock_model = MiniModel() + mock_model.call = function.defun(mock_model.call) + + def loss_fn(ctx): + del ctx + return mock_model(array_ops.ones([1, 10])) + + gradients_fn = backprop.implicit_grad(loss_fn) + gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) + grads_and_vars = dist.call_for_each_tower( + gradients_fn, None, run_concurrently=False) + + optimizer = gradient_descent.GradientDescentOptimizer(0.25) + update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + self.evaluate(update_ops) + + updated_var_values = self.evaluate(mock_model.variables) + # All variables start at 1.0 and get two updates of 0.25. + self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0]) + self.assertAllEqual([0.5], updated_var_values[1]) + + class MultiWorkerMirroredStrategyTest( multi_worker_test_base.MultiWorkerTestBase, -- cgit v1.2.3