aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/one_device_strategy.py
blob: f5259190485e701c190beb49220caff743f8fdcb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class OneDeviceStrategy implementing DistributionStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.distribute.python import values
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest


# TODO(josh11b): Replace asserts in this file with if ...: raise ...


class OneDeviceStrategy(distribute_lib.DistributionStrategy):
  """A distribution strategy for running on a single device.

  There is exactly one tower (`num_towers` is 1) and every device list this
  strategy reports contains only the device given to the constructor.
  """
  # TODO(josh11b): Do we wrap values in types to generate errors if you are
  # doing something that won't work with other DistributionStrategy
  # implementations?

  def __init__(self, device, prefetch_on_device=None):
    """Creates a strategy that uses `device` for everything.

    Args:
      device: The device to run on; it is passed to `ops.device()` by the
        methods of this class.
      prefetch_on_device: Forwarded unchanged to `values.PerDeviceDataset` by
        `distribute_dataset()`.
    """
    super(OneDeviceStrategy, self).__init__()
    self._device = device
    self._prefetch_on_device = prefetch_on_device
    self._default_device = device

  def _create_variable(self, next_creator, *args, **kwargs):
    """Calls `next_creator` under the appropriate device/colocation scope.

    The optional `colocate_with` keyword argument is removed from `kwargs`
    before `next_creator` is called and selects the scope:

    * missing or `None`: `ops.device(self._device)`;
    * a string: `ops.device(colocate_with)`;
    * a one-element list holding a string: `ops.device(colocate_with[0])`;
    * anything else: `ops.colocate_with(colocate_with)`.

    Args:
      next_creator: Callable that actually creates the variable.
      *args: Positional arguments forwarded to `next_creator`.
      **kwargs: Keyword arguments forwarded to `next_creator` (minus
        `colocate_with`).

    Returns:
      Whatever `next_creator` returns.
    """
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      with ops.device(self._device):
        return next_creator(*args, **kwargs)
    if isinstance(colocate_with, six.string_types):
      with ops.device(colocate_with):
        return next_creator(*args, **kwargs)
    if (isinstance(colocate_with, list) and len(colocate_with) == 1 and
        isinstance(colocate_with[0], six.string_types)):
      with ops.device(colocate_with[0]):
        return next_creator(*args, **kwargs)
    with ops.colocate_with(colocate_with):
      return next_creator(*args, **kwargs)

  def distribute_dataset(self, dataset_fn):
    """Wraps the dataset produced by `dataset_fn` for this strategy's device.

    Args:
      dataset_fn: Passed to `self._call_dataset_fn()` to obtain the dataset.

    Returns:
      A `values.PerDeviceDataset` built from that dataset, the single-element
      device list `[self._device]` and the `prefetch_on_device` setting given
      to the constructor.
    """
    return values.PerDeviceDataset(
        self._call_dataset_fn(dataset_fn), [self._device],
        self._prefetch_on_device)

  def _broadcast(self, tensor, destinations):
    """Returns `tensor` unchanged; `destinations` is ignored."""
    del destinations
    return tensor

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):
    """Builds a while loop that calls `fn` on `iterations` dataset elements.

    Args:
      fn: Callable invoked in the loop body as `fn(ctx, *inputs)`, where `ctx`
        is a `values.MultiStepContext` and `inputs` is the result of
        `iterator.get_next()` (wrapped in a 1-tuple if it is not already a
        tuple). Its result is used as a control dependency of the loop body's
        outputs.
      iterator: Object whose `get_next()` supplies the inputs for each step.
      iterations: The loop runs while its counter, starting at 0, is less than
        this value.
      initial_loop_values: Optional structure that is flattened with
        `nest.flatten()` and used as the initial loop variables following the
        counter. `None` is treated as an empty dict.

    Returns:
      The `values.MultiStepContext` passed to `fn`, with `run_op` set to a
      group of the loop outputs and its last step outputs replaced by the
      loop's output tensors packed into the structure of
      `ctx.last_step_outputs`.
    """
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = values.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      # The incoming loop values are not read; the values carried to the next
      # iteration are taken from `ctx.last_step_outputs` after `fn` has run.
      del args
      fn_inputs = iterator.get_next()
      if not isinstance(fn_inputs, tuple):
        fn_inputs = (fn_inputs,)
      fn_result = fn(ctx, *fn_inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, for
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    try:
      # TODO(priyag): Use max_iterations instead of an explicit counter.
      cond = lambda i, *args: i < iterations
      i = constant_op.constant(0)
      loop_result = control_flow_ops.while_loop(
          cond, body, [i] + initial_loop_values, name="",
          parallel_iterations=1, back_prop=False, swap_memory=False,
          return_same_structure=True)
    finally:
      # Always drop the captured context, even when building the loop raises,
      # so a stale context is never left behind on the strategy object.
      del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _call_for_each_tower(self, fn, *args, **kwargs):
    """Calls `fn` once, on this strategy's device, inside a tower context.

    Args:
      fn: Callable to run.
      *args: Positional arguments forwarded to `fn`.
      **kwargs: Keyword arguments forwarded to `fn`. A `run_concurrently`
        entry, if present, is dropped and not forwarded.

    Returns:
      Whatever `fn` returns.
    """
    # We don't run `fn` in multiple threads in OneDeviceStrategy.
    kwargs.pop("run_concurrently", None)
    with ops.device(self._device), _OneDeviceTowerContext(self):
      return fn(*args, **kwargs)

  def map(self, map_over, fn, *args, **kwargs):
    """Applies `fn` to every element of `map_over` on this strategy's device.

    Args:
      map_over: Iterable of elements; each is passed as the first argument of
        `fn`.
      fn: Callable invoked as `fn(element, *args, **kwargs)`.
      *args: Extra positional arguments for `fn`.
      **kwargs: Extra keyword arguments for `fn`.

    Returns:
      A `values.MapOutput` wrapping the list of results.
    """
    with ops.device(self._device):
      return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])

  def _reduce(self, aggregation, value, destinations):
    """Reduces a `values.MapOutput` to one value; passes other values through.

    Args:
      aggregation: `vs.VariableAggregation.SUM` or
        `vs.VariableAggregation.MEAN`. Only consulted when `value` is a
        `values.MapOutput`.
      value: The value to reduce. Anything that is not a `values.MapOutput`
        is returned unchanged.
      destinations: Ignored.

    Returns:
      `value` itself, or the sum (SUM) / sum divided by the number of elements
      (MEAN) of the elements held by the `values.MapOutput`.

    Raises:
      ValueError: If the `values.MapOutput` holds no elements, or if
        `aggregation` is neither SUM nor MEAN.
    """
    del destinations
    if not isinstance(value, values.MapOutput):
      return value
    map_values = value.get()
    # Explicit raises instead of asserts: asserts are stripped under
    # `python -O`, which would let invalid input through silently.
    if not map_values:
      raise ValueError("Cannot reduce a MapOutput that contains no values.")
    with ops.device(self._device):
      if aggregation == vs.VariableAggregation.SUM:
        return math_ops.add_n(map_values)
      elif aggregation == vs.VariableAggregation.MEAN:
        return math_ops.add_n(map_values) / len(map_values)
      else:
        raise ValueError(
            "Unsupported aggregation %r; expected VariableAggregation.SUM or "
            "VariableAggregation.MEAN." % (aggregation,))

  def _update(self, var, options, fn, *args, **kwargs):
    """Runs `fn(var, *args, **kwargs)` via `_update_non_slot()`.

    Args:
      var: Passed to `fn` as its first argument.
      options: Options dict; see `_update_non_slot()`.
      fn: Callable to run.
      *args: Extra positional arguments for `fn`.
      **kwargs: Extra keyword arguments for `fn`.

    Returns:
      The result of `_update_non_slot()`.
    """
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, options, fn, var, *args, **kwargs)

  def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
    """Runs `fn(*args, **kwargs)` on this strategy's device in an update scope.

    Args:
      colocate_with: Ignored.
      options: Mapping that must contain exactly one key, `"grouped"`. The
        mapping passed in is not modified.
      fn: Callable to run.
      *args: Positional arguments forwarded to `fn`.
      **kwargs: Keyword arguments forwarded to `fn`.

    Returns:
      The result of `fn` as is when `options["grouped"]` is true; otherwise
      the result with each element of its structure wrapped in a one-element
      list.

    Raises:
      KeyError: If `options` has no `"grouped"` entry.
      ValueError: If `options` contains any key other than `"grouped"`.
    """
    del colocate_with
    # Work on a copy so that the caller's dict is left untouched.
    remaining_options = dict(options)
    should_group = remaining_options.pop("grouped")
    # Validate that we are processing all of the options.
    if remaining_options:
      raise ValueError(
          "Unrecognized update options: %r" % (list(remaining_options),))
    with ops.device(self._device), distribute_lib.UpdateContext(self._device):
      result = fn(*args, **kwargs)
      if should_group:
        return result
      else:
        return nest.map_structure(self._unwrap, result)

  def read_var(self, tower_local_var):
    """Read the aggregate value of a tower-local variable."""
    return array_ops.identity(tower_local_var)

  def _unwrap(self, value):
    """Returns `value` wrapped in a one-element list."""
    return [value]

  def value_container(self, value):
    """Returns `value` unchanged."""
    return value

  @property
  def is_single_tower(self):
    """Always `True`."""
    return True

  @property
  def num_towers(self):
    """Always 1."""
    return 1

  @property
  def worker_devices(self):
    """One-element list holding the device given to the constructor."""
    return [self._device]

  @property
  def parameter_devices(self):
    """One-element list holding the device given to the constructor."""
    return [self._device]

  def non_slot_devices(self, var_list):
    """Returns `[self._device]`; `var_list` is ignored."""
    del var_list
    return [self._device]

  def _worker_device_index(self):
    """Always 0."""
    return 0


class _OneDeviceTowerContext(distribute_lib.TowerContext):
  """`TowerContext` entered by `OneDeviceStrategy._call_for_each_tower()`."""

  def __init__(self, distribution_strategy):
    """Initializes the base `TowerContext` with a fixed `tower_id` of 0.

    Args:
      distribution_strategy: The strategy this context belongs to; forwarded
        to the base class constructor.
    """
    super(_OneDeviceTowerContext, self).__init__(
        distribution_strategy, tower_id=0)

  @property
  def device(self):
    """The first entry of the owning strategy's `worker_devices`."""
    # NOTE(review): `_distribution_strategy` is presumably set by the base
    # class constructor from the argument above -- confirm in distribute_lib.
    strategy_devices = self._distribution_strategy.worker_devices
    return strategy_devices[0]