aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/distributed_training_utils.py
blob: 050602868a16d282d2ee9707678dbfaf00d684dc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distributed training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.client import session as session_module
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest


def set_weights(distribution_strategy, dist_model, weights):
  """Copies the original model's weights into the replicated model.

  `weights` is consumed front to back: each layer of `dist_model` takes as many
  entries as it has weight variables. The replicated weights are Mirrored
  variables, so each `assign` result is unwrapped through the strategy, and all
  of the resulting assign ops are executed together in one session call.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training
        and validation.
    dist_model: The replicated models on the different devices.
    weights: The weights of the original model, ordered layer by layer.
  """
  pending_assignments = []
  remaining = weights
  for layer in dist_model.layers:
    layer_vars = layer.weights
    count = len(layer_vars)
    for variable, value in zip(layer_vars, remaining[:count]):
      pending_assignments.append(
          distribution_strategy.unwrap(variable.assign(value)))
    # Drop the entries this layer consumed before moving to the next layer.
    remaining = remaining[count:]
  K.get_session().run(pending_assignments)


def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
                  grouped_updates, grouped_session_args,
                  with_loss_tensor=False):
  """Flattens the PerDevice results of a train/test function into plain lists.

  Every grouped argument is passed through `flatten_perdevice_values`. When
  `with_loss_tensor` is True, the first entry of `grouped_outputs` is treated
  as the per-device loss. It is reduced with the strategy's loss reduction onto
  '/device:CPU:0', and the resulting single tensor is placed at the front of
  the flattened outputs.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    grouped_inputs: PerDevice inputs returned from the train or test function
        that we ran on each device.
    grouped_outputs: PerDevice outputs returned from the train or test function
        that we ran on each device.
    grouped_updates: PerDevice updates returned from the train or test function
        that we ran on each device.
    grouped_session_args: PerDevice session args returned from the train or
        test function that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
        tensor as one of the outputs.

  Returns:
    A tuple `(all_inputs, all_outputs, all_updates, all_session_args)`.
    `all_session_args` is a dict that contains flattened 'feed_dict' and/or
    'fetches' entries only when those keys hold truthy values in
    `grouped_session_args`.
  """
  def _flatten(values):
    return flatten_perdevice_values(distribution_strategy, values)

  # These unwrapped values are used to build the main train function.
  all_inputs = _flatten(grouped_inputs)

  if not with_loss_tensor:
    all_outputs = _flatten(grouped_outputs)
  else:
    # Collapse the per-device losses into a single tensor before it joins the
    # list of fetches.
    reduced_loss = distribution_strategy.reduce(
        distribute_lib.get_loss_reduction(),
        grouped_outputs[0],
        destinations='/device:CPU:0')
    loss = distribution_strategy.unwrap(reduced_loss)[0]
    all_outputs = [loss] + _flatten(grouped_outputs[1:])

  all_updates = _flatten(grouped_updates)

  all_session_args = {}
  for arg_name in ('feed_dict', 'fetches'):
    grouped_value = grouped_session_args.get(arg_name)
    if grouped_value:
      all_session_args[arg_name] = _flatten(grouped_value)

  return all_inputs, all_outputs, all_updates, all_session_args


def flatten_perdevice_values(distribution_strategy, perdevice_values):
  """Unwraps and flattens a nest of PerDevice parameters.

  `perdevice_values` is first flattened with `nest.flatten`. Each resulting
  entry is then unwrapped through the strategy into its per-device values, and
  all of those values are concatenated into one list, in order.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    perdevice_values: A single PerDevice object, or a nested structure (e.g. a
        list) of PerDevice objects.

  Returns:
    A flat list of all the per-device values of every PerDevice object.
  """
  flat_values = []
  for per_device in nest.flatten(perdevice_values):
    flat_values.extend(distribution_strategy.unwrap(per_device))
  return flat_values


def validate_callbacks(input_callbacks):
  """Validate whether given callbacks are supported by DistributionStrategy.

  A warning is logged for any callback that is not an instance of one of the
  predefined Keras callbacks listed below. Certain callbacks, and certain
  TensorBoard features, are rejected outright.

  Args:
    input_callbacks: List of callbacks passed by the user to fit. May be None
      or empty, in which case nothing is checked.

  Raises:
    ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the
        callbacks passed.
    ValueError: If `histogram_freq` or `write_grads` is set on a TensorBoard
        callback that is passed.
  """
  if not input_callbacks:
    return

  supported_callback_types = (
      callbacks.TensorBoard, callbacks.ReduceLROnPlateau,
      callbacks.LearningRateScheduler, callbacks.CSVLogger,
      callbacks.EarlyStopping, callbacks.ModelCheckpoint,
      callbacks.TerminateOnNaN, callbacks.ProgbarLogger,
      callbacks.History, callbacks.RemoteMonitor)

  for callback in input_callbacks:
    # `callback` is an instance, so it must be checked against the classes
    # with isinstance; a plain membership test against a list of classes
    # never matches an instance.
    if not isinstance(callback, supported_callback_types):
      logging.warning('Your input callback is not one of the predefined '
                      'Callbacks that supports DistributionStrategy. You '
                      'might encounter an error if you access one of the '
                      'model\'s attributes as part of the callback since '
                      'these attributes are not set. You can access each of '
                      'the individual distributed models using the '
                      '`_grouped_model` attribute of your original model.')
    if isinstance(callback, callbacks.LearningRateScheduler):
      raise ValueError('LearningRateScheduler callback is not supported with '
                       'DistributionStrategy.')
    if isinstance(callback, callbacks.ReduceLROnPlateau):
      raise ValueError('ReduceLROnPlateau callback is not supported with '
                       'DistributionStrategy.')

    # If users want to use the TensorBoard callback they cannot use certain
    # features of the callback that involve accessing model attributes and
    # running ops.
    if isinstance(callback, callbacks.TensorBoard):
      if callback.histogram_freq:
        raise ValueError('histogram_freq in the TensorBoard callback is not '
                         'supported when using DistributionStrategy.')
      if callback.write_grads:
        raise ValueError('write_grads in the TensorBoard callback is not '
                         'supported when using DistributionStrategy.')


def validate_distributed_dataset_inputs(distribution_strategy, x, y):
  """Validate all the components of a DistributedValue Dataset input.

  `x` and, when it is not None, `y` are each checked with
  `validate_per_device_inputs`. Every element must be a tensor whose
  per-device values agree in shape and dtype.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
        `fit`/`evaluate`.
    x: Input Dataset DistributedValue object. For example, when we use
        `MirroredStrategy` this is a PerDevice object with a tensor for each
        device set in the dict. x can also be a tuple or dict. The keys of the
        dict should match the names of the input layers of the model.
    y: Target Dataset DistributedValue object, structured like `x`, or None.
        The keys of a dict `y` should match the names of the output layers of
        the model.

  Returns:
    A pair `(x_values_list, y_values_list)` of unwrapped values, where
    `y_values_list` is None when `y` is None.

  Raises:
    ValueError: If x or y contain elements that are not tensors, or elements
        whose per-device values have a shape or dtype mismatch.
  """
  # Inputs and targets must come from a `tf.data.Dataset`, i.e. be tensors;
  # anything else cannot be standardized or validated.
  x_values_list = validate_per_device_inputs(distribution_strategy, x)
  y_values_list = (None if y is None
                   else validate_per_device_inputs(distribution_strategy, y))

  # Hand back the unwrapped values so callers need not call `unwrap` again.
  return x_values_list, y_values_list


def validate_per_device_inputs(distribution_strategy, x):
  """Validates PerDevice dataset input list.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
      `fit`, `evaluate` and `predict`.
    x: A PerDevice object, or a nested structure of PerDevice objects
      (anything accepted by `nest.flatten`), that represents the input or
      target values.

  Returns:
    List containing the first unwrapped per-device value of each flattened
    PerDevice object, in flattening order.

  Raises:
    ValueError: If any flattened element of `x` is not a tensor, or if the
      per-device values of an element differ in shape or dtype.
  """
  # Flatten the (possibly nested) inputs or targets into a list of PerDevice
  # objects.
  per_device_list = nest.flatten(x)
  x_values_list = []
  # Use a distinct loop name so the `x` parameter is not shadowed.
  for per_device_value in per_device_list:
    if not tensor_util.is_tensor(per_device_value):
      raise ValueError('Dataset input to the model should be tensors instead '
                       'they are of type {}'.format(type(per_device_value)))

    # The element is a tensor in a `DistributedValues` structure; unwrap it
    # into its per-device components.
    x_values = distribution_strategy.unwrap(per_device_value)

    # Every device must see the same shape and dtype.
    validate_all_tensor_shapes(per_device_value, x_values)
    validate_all_tensor_types(per_device_value, x_values)

    x_values_list.append(x_values[0])
  return x_values_list


def validate_all_tensor_types(x, x_values):
  """Raises if the per-device tensors in `x_values` disagree on dtype.

  The first element's dtype is the reference; every later element is compared
  against it.

  Args:
    x: The distributed value the tensors were unwrapped from; only used in the
      error message.
    x_values: Sequence of per-device tensors (must be non-empty).

  Raises:
    ValueError: If any element's dtype differs from the first element's.
  """
  reference_dtype = x_values[0].dtype
  mismatch = any(reference_dtype != value.dtype for value in x_values[1:])
  if mismatch:
    raise ValueError('Input tensor dtypes do not match for distributed tensor'
                     ' inputs {}'.format(x))


def validate_all_tensor_shapes(x, x_values):
  """Raises if the per-device tensors in `x_values` disagree on static shape.

  Shapes are compared as the lists returned by `get_shape().as_list()`, with
  the first element as the reference.

  Args:
    x: The distributed value the tensors were unwrapped from; only used in the
      error message.
    x_values: Sequence of per-device tensors (must be non-empty).

  Raises:
    ValueError: If any element's shape list differs from the first element's.
  """
  reference_shape = x_values[0].get_shape().as_list()
  mismatch = any(reference_shape != value.get_shape().as_list()
                 for value in x_values[1:])
  if mismatch:
    raise ValueError('Input tensor shapes do not match for distributed tensor'
                     ' inputs {}'.format(x))


def configure_and_create_session(distribution_strategy):
  """Builds a session configured for `distribution_strategy` and installs it.

  The default Keras session config is handed to the strategy's `configure`
  method. A Session is then created from it (targeting the TPU cluster
  resolver's master for `TPUStrategy`) and set as the Keras backend session.
  """
  # TODO(priyag): Throw error if a session already exists.
  session_config = K.get_default_session_config()
  distribution_strategy.configure(session_config)

  session_kwargs = {'config': session_config}
  if distribution_strategy.__class__.__name__ == 'TPUStrategy':
    # TODO(priyag): Remove this workaround when Distributed Coordinator is
    # integrated with keras and we can create a session from there.
    session_kwargs['target'] = distribution_strategy._tpu_cluster_resolver.master()  # pylint: disable=protected-access

  K.set_session(session_module.Session(**session_kwargs))


def validate_inputs(x, y, distribution_strategy):
  """Validate inputs when using DistributionStrategy.

  Args:
    x: Model Inputs.
    y: Model Targets.
    distribution_strategy: The DistributionStrategy with which the model is
      compiled.

  Raises:
    ValueError: If `x` or `y` is a dict or an `Iterator`, or, when the
      strategy is a `TPUStrategy`, if `x` or `y` is a `Dataset` with an output
      shape that is not fully defined.
  """
  if isinstance(x, dict) or isinstance(y, dict):
    raise ValueError('`DistributionStrategy` does not support inputs of type '
                     'dict. You must pass a `tf.data.Dataset` object or a '
                     'numpy array as input.')

  if (isinstance(x, iterator_ops.Iterator) or
      isinstance(y, iterator_ops.Iterator)):
    raise ValueError('`DistributionStrategy` does not support inputs of type '
                     'Iterator. You must pass a `tf.data.Dataset` object or a '
                     'numpy array as input.')

  if distribution_strategy.__class__.__name__ == 'TPUStrategy':
    for data in (x, y):
      if not isinstance(data, dataset_ops.Dataset):
        continue
      # Loop explicitly so the error can name the exact shape that is not
      # fully defined (a comprehension variable is out of scope afterwards).
      for shape in nest.flatten(data.output_shapes):
        if not shape.is_fully_defined():
          raise ValueError(
              'Using TPUs currently requires fully defined shapes. Either use '
              'set_shape() on the input tensors or use '
              'dataset.batch(..., drop_remainder=True). '
              'Found unknown shape {} in input {}.'.format(shape, data))


def get_input_batch_params(first_x_value, batch_size, distribution_strategy):
  """Calculate the number of batches and steps/steps_per_epoch.

  Only whole units are counted. Trailing samples that do not fill a batch are
  dropped, and so are trailing batches that do not give every tower a batch.

  Args:
    first_x_value: This is the first input numpy array that is passed in as the
      model input. Its first dimension is taken as the number of samples.
    batch_size: The specified batch_size or the default batch_size of 32.
    distribution_strategy: The current DistributionStrategy used to compile the
      model.

  Returns:
    The steps or steps_per_epoch argument depending on if a user is
    calling `fit`, `evaluate` or `predict`.

  Raises:
    ValueError: If the number of batches or steps evaluates to 0.
  """
  num_samples = first_x_value.shape[0]
  num_batches = num_samples // batch_size
  if not num_batches:
    raise ValueError('Please specify a batch_size that is smaller than '
                     'the number of input samples %d.' % num_samples)
  # TODO(anjalisridhar): TPU currently supports using the num_towers property.
  # We might want to look into implementing worker_devices. In multi worker
  # strategy, perhaps num_towers works better?
  num_towers = distribution_strategy.num_towers
  steps = num_batches // num_towers
  if not steps:
    # TODO(anjalisridhar): Number of towers in the error message may not convey
    # what we want to the user. Is there another terminology that we can use
    # that is consistent across different strategies.
    raise ValueError('The number of batches %d is smaller than the number '
                     'of towers %d used for DistributionStrategy. ' %
                     (num_batches, num_towers))
  return steps


def get_batch_dimension(iterator):
  """Returns the leading dimension of the first output shape of `iterator`.

  Only the first flattened output shape is inspected, on the assumption that
  every component shares the same batch size.

  Args:
    iterator: An object exposing `output_shapes`.

  Returns:
    The first entry of the first output shape's `dims`, or None when those
    `dims` are empty or None.
  """
  first_shape = nest.flatten(iterator.output_shapes)[0]
  dims = first_shape.dims
  if not dims:
    return None
  return dims[0]


def get_cpu_device(distribution_strategy):
  """Returns the CPU device of the TPU host or the default CPU device string.

  Args:
    distribution_strategy: The DistributionStrategy used to compile the model.

  Returns:
    A device string. For a strategy whose class is named `TPUStrategy`, this is
    the CPU device of host 0, as reported by `get_host_cpu_device(0)`. For all
    other strategies it is '/CPU:0'.

  Raises:
    NotImplementedError: We currently don't support copying numpy data to
    multiple hosts in the case of Cloud TPU pods (`num_hosts > 1`).
  """
  if distribution_strategy.__class__.__name__ == 'TPUStrategy':
    if distribution_strategy.num_hosts > 1:
      raise NotImplementedError('TPUDistributionStrategy does not '
                                'support numpy inputs when running on Cloud '
                                'TPU pods.')
    return distribution_strategy.get_host_cpu_device(0)
  else:
    # For all strategies except TPUDistributionStrategy
    # TODO(anjalisridhar): We may need to modify this when we add support for
    # multi-worker strategy.
    return '/CPU:0'


def get_var_for_numpy(distribution_strategy, x):
  """Copies numpy input(s) into variables via `_get_var_for_numpy`.

  Args:
    distribution_strategy: The DistributionStrategy used to compile the model.
    x: A numpy array, or a list of numpy arrays.

  Returns:
    A single variable when `x` is not a list. Otherwise, a tuple with one
    variable per list entry, in the same order.
  """
  if not isinstance(x, list):
    return _get_var_for_numpy(distribution_strategy, x)
  return tuple(_get_var_for_numpy(distribution_strategy, single_input)
               for single_input in x)


def _get_var_for_numpy(distribution_strategy, input_array):
  """Creates a variable and assigns the value of the numpy array to it.

  The variable is created zero-filled, non-trainable, and with
  `use_resource=True`, under the device returned by `get_cpu_device`. The
  array is then copied into it along its first axis, in slices of roughly
  64 MB.

  Args:
    distribution_strategy: The DistributionStrategy used to compile the model.
    input_array: The input numpy array whose value will be assigned to the
      variable we create.

  Returns:
    The variable to which we will copy the value of the input numpy array.
  """
  with ops.device(get_cpu_device(distribution_strategy)):
    # Create and initialize a variable on the CPU device. This is the CPU
    # device of the host in the case of TPUDistributionStrategy.
    input_var = variables.VariableV1(array_ops.zeros(input_array.shape,
                                                     input_array.dtype),
                                     trainable=False, use_resource=True)
  # Fetch the session once; it is reused for every slice below.
  session = K.get_session()
  session.run(input_var.initializer)

  # Create a placeholder for the numpy array input slices. We copy the value
  # of the input numpy array to the variable in slices of size 64 MB to avoid
  # running into memory issues or RPC message limits.
  start_placeholder = array_ops.placeholder(dtypes.int64, ())
  end_placeholder = array_ops.placeholder(dtypes.int64, ())
  slice_placeholder = array_ops.placeholder(input_var.dtype)
  assign_slice_op = input_var[start_placeholder:end_placeholder].assign(
      slice_placeholder)

  # If each batch element is > 64 MB, then we copy each batch element
  # individually. Otherwise, the slices will be < 128 MB. There might be padding
  # which might mean that the slices are 128 MB even if the size of the
  # tensor allocated is less than 128 MB.
  # This formula gives slices with size:
  # ceil(64 MB / byte size per batch element) bytes.
  # Using ceil() guarantees we get a number >= 1.

  # Calculate the size of each batch element.
  byte_size_per_batch_element = (
      np.prod(input_array.shape[1:]) * input_var.dtype.size)

  limit = input_array.shape[0]
  if byte_size_per_batch_element > 0:
    # np.ceil returns a float. Slice bounds must be integers, both for numpy
    # indexing (`input_array[start:end]`) and for the int64 placeholders.
    batch_size_per_slice = int(
        np.ceil((64 << 20) / byte_size_per_batch_element))
  else:
    # Zero-byte batch elements (some trailing dimension is 0): there is
    # nothing to size slices against, so copy everything in one slice.
    batch_size_per_slice = max(limit, 1)

  # Copy slices of the above size starting at 0, except the last slice will be
  # smaller.
  start = 0
  while start < limit:
    end = min(start + batch_size_per_slice, limit)
    session.run(assign_slice_op, feed_dict={
        start_placeholder: start,
        end_placeholder: end,
        slice_placeholder: input_array[start:end]})
    start = end

  return input_var