# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for class MirroredStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import distribution_strategy_context


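# Exercises the core DistributionStrategy API (loss minimization, map-reduce,
# device indices, tower ids) against a single-CPU MirroredStrategy, using the
# shared helpers in strategy_test_lib.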
class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):

  def _get_distribution_strategy(self):
    return mirrored_strategy.MirroredStrategy(["/device:CPU:0"])

  def testMinimizeLossEager(self):
    self._test_minimize_loss_eager(self._get_distribution_strategy())

  def testMinimizeLossGraph(self):
    self._test_minimize_loss_graph(self._get_distribution_strategy())

  def testMapReduce(self):
    self._test_map_reduce(self._get_distribution_strategy())

  def testDeviceIndex(self):
    self._test_device_index(self._get_distribution_strategy())

  def testTowerId(self):
    self._test_tower_id(self._get_distribution_strategy())

  @test_util.run_in_graph_and_eager_modes
  def testCallAndMergeExceptions(self):
    self._test_call_and_merge_exceptions(self._get_distribution_strategy())


class VariableCreatorStackTest(test.TestCase):

  def testCreatorStacksAreThreadLocal(self):
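    # `call_for_each_tower` runs `model_fn` in a separate thread per device;
    # each thread installs its own variable creator below, so the creator
    # stack must be thread-local for the creators not to leak across towers.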
    devices = ["/device:CPU:0", "/device:GPU:0"]
    dist = mirrored_strategy.MirroredStrategy(devices)

    def model_fn(device_id):
      assert isinstance(device_id, int)

      def thread_creator_fn(next_creator, *args, **kwargs):
        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribution_strategy_context.get_tower_context().merge_call(
            lambda _: _)
      return v

    def main_thread_creator(next_creator, *args, **kwargs):
      # For test purposes we deliberately ignore the underlying next_creator.
      del next_creator, args, kwargs
      return "main_thread"

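    # The main-thread creator is installed outside call_for_each_tower, so it
    # sits below each thread-local creator on that thread's creator stack.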
    with context.graph_mode(), \
        dist.scope(), \
        variable_scope.variable_creator_scope(main_thread_creator):
      result = dist.call_for_each_tower(model_fn, dist.worker_device_index)
      result = dist.unwrap(result)
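      # Each thread-local creator delegates to the main-thread creator via
      # next_creator, then appends its own ":thread_<id>" suffix.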
      expected = ["main_thread:thread_0", "main_thread:thread_1"]
      self.assertEqual(expected, result)


class MultiWorkerMirroredStrategyTest(test.TestCase):

  def testDeviceScope(self):
    """Test the device scope of multi-worker MirroredStrategy."""
    with context.graph_mode():
      strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
      strategy.configure(
          cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]})
      with strategy.scope():
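        # Ops created under a multi-worker strategy scope are placed on the
        # first worker task by default; an inner device scope only refines
        # the device within that task.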
        a = constant_op.constant(1.)
        with ops.device("/cpu:0"):
          b = constant_op.constant(1.)
        self.assertEqual(a.device, "/job:worker/task:0")
        self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0")


if __name__ == "__main__":
  test.main()