diff options
Diffstat (limited to 'tensorflow/python/keras/engine/distributed_training_utils.py')
-rw-r--r-- | tensorflow/python/keras/engine/distributed_training_utils.py | 249 |
1 files changed, 249 insertions, 0 deletions
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py new file mode 100644 index 0000000000..c78e6fe9ec --- /dev/null +++ b/tensorflow/python/keras/engine/distributed_training_utils.py @@ -0,0 +1,249 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities related to distributed training.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import tensor_util +from tensorflow.python.keras import backend +from tensorflow.python.keras import callbacks +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest + + +def set_weights(distribution_strategy, dist_model, weights): + """Sets the weights of the replicated models. + + The weights of the replicated models are set to the weights of the original + model. The weights of the replicated model are Mirrored variables and hence + we need to use the `update` call within a DistributionStrategy scope. + + Args: + distribution_strategy: DistributionStrategy used to distribute training + and validation. + dist_model: The replicated models on the different devices. + weights: The weights of the original model. + """ + assign_ops = [] + for layer in dist_model.layers: + num_param = len(layer.weights) + layer_weights = weights[:num_param] + for sw, w in zip(layer.weights, layer_weights): + assign_ops.append(distribution_strategy.unwrap(sw.assign(w))) + + weights = weights[num_param:] + backend.get_session().run(assign_ops) + + +def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs, + grouped_updates, grouped_session_args, + with_loss_tensor=False): + """Unwrap and return the list of values contained in the PerDevice parameters. + + This function calls `flatten_perdevice_values` to parse each of the input + parameters into a list of values on the different devices. If we set + `with_loss_tensor` to be True, we also call `reduce` on the list of losses on + the different devices to give us one loss tensor. + + Args: + distribution_strategy: DistributionStrategy used to distribute training and + validation. + grouped_inputs: PerDevice inputs returned from the train or test function + that we ran on each device. + grouped_outputs: PerDevice outputs returned from the train or test function + that we ran on each device. + grouped_updates: PerDevice updates returned from the train or test function + that we ran on each device. + grouped_session_args: PerDevice session args returned from the train or + test function that we ran on each device. + with_loss_tensor: Boolean that indicates if we need to add the reduced loss + tensor as one of the outputs. + + Returns: + Values of each of the PerDevice parameters. + + """ + # Unwrap per device values returned from each model's train function. + # This will be used to construct the main train function. + all_inputs = flatten_perdevice_values(distribution_strategy, + grouped_inputs) + if with_loss_tensor: + # reduce loss tensor before adding it to the list of fetches + loss = distribution_strategy.unwrap( + distribution_strategy.reduce(distribute_lib.get_loss_reduction(), + grouped_outputs[0], + destinations='/device:CPU:0'))[0] + + all_outputs = flatten_perdevice_values(distribution_strategy, + grouped_outputs[1:]) + all_outputs = [loss] + all_outputs + else: + all_outputs = flatten_perdevice_values(distribution_strategy, + grouped_outputs) + + all_updates = flatten_perdevice_values(distribution_strategy, + grouped_updates) + + all_session_args = {} + grouped_feed_dict = grouped_session_args.get('feed_dict') + if grouped_feed_dict: + all_session_args['feed_dict'] = flatten_perdevice_values( + distribution_strategy, grouped_feed_dict) + + grouped_fetches = grouped_session_args.get('fetches') + if grouped_fetches: + all_session_args['fetches'] = flatten_perdevice_values( + distribution_strategy, grouped_fetches) + + return all_inputs, all_outputs, all_updates, all_session_args + + +def flatten_perdevice_values(distribution_strategy, perdevice_values): + """Unwraps and flattens a nest of PerDevice parameters. + + PerDevice values have one value associated with each device. Each entry in + the PerDevice dict has a device `key` and the corresponding value on the + device as the `value`. In this function we take a PerDevice value or a list of + PerDevice values and return all the values in the PerDevice dict. + + Args: + distribution_strategy: DistributionStrategy used to distribute training and + validation. + perdevice_values: List of PerDevice object or a single PerDevice object. + + Returns: + List of values of all the PerDevice objects. + + """ + # This function takes a PerDevice object or a list of PerDevice objects and + # returns all the values associated with it. + return [e for flattened in nest.flatten(perdevice_values) + for e in distribution_strategy.unwrap(flattened)] + + +def validate_callbacks(input_callbacks): + """Validate whether given callbacks are supported by DistributionStrategy. + + Args: + input_callbacks: List of callbacks passed by the user to fit. + + Raises: + ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the + callbacks passed. + ValueError: If `histogram_freq` or `write_grads` is one of the parameters + passed as part of the TensorBoard callback. + """ + if input_callbacks: + for callback in input_callbacks: + if callback not in [callbacks.TensorBoard, callbacks.ReduceLROnPlateau, + callbacks.LearningRateScheduler, callbacks.CSVLogger, + callbacks.EarlyStopping, callbacks.ModelCheckpoint, + callbacks.TerminateOnNaN, callbacks.ProgbarLogger, + callbacks.History, callbacks.RemoteMonitor]: + logging.warning('Your input callback is not one of the predefined ' + 'Callbacks that supports DistributionStrategy. You ' + 'might encounter an error if you access one of the ' + 'model\'s attributes as part of the callback since ' + 'these attributes are not set. You can access each of ' + 'the individual distributed models using the ' + '`_grouped_model` attribute of your original model.') + if isinstance(callback, callbacks.LearningRateScheduler): + raise ValueError('LearningRateScheduler callback is not supported with ' + 'DistributionStrategy.') + if isinstance(callback, callbacks.ReduceLROnPlateau): + raise ValueError('ReduceLROnPlateau callback is not supported with ' + 'DistributionStrategy.') + + # If users want to use the TensorBoard callback they cannot use certain + # features of the callback that involve accessing model attributes and + # running ops. + if isinstance(callback, callbacks.TensorBoard): + if callback.__getattribute__('histogram_freq'): + raise ValueError('histogram_freq in the TensorBoard callback is not ' + 'supported when using DistributionStrategy.') + if callback.__getattribute__('write_grads'): + raise ValueError('write_grads in the TensorBoard callback is not ' + 'supported when using DistributionStrategy.') + + +def validate_distributed_dataset_inputs(distribution_strategy, x, y): + """Validate all the components of a DistributedValue Dataset input. + + Args: + distribution_strategy: The current DistributionStrategy using to call + `fit`/`evaluate`. + x: Input Dataset DistributedValue object. For example, when we use + `MirroredStrategy` this is a PerDevice object with a tensor for each + device set in the dict. + y: Target Dataset DistributedValue object. For example, when we use + `MirroredStrategy` this is a PerDevice object with a tensor for each + device set in the dict. + + Returns: + The unwrapped values list of the x and y DistributedValues inputs. + + Raises: + ValueError: If x and y do not have support for being evaluated as tensors. + or if x and y contain elements that are not tensors or if x and y + contain elements that have a shape or dtype mismatch. + """ + # If the input and target used to call the model are not dataset tensors, + # we need to raise an error. When using a DistributionStrategy, the input + # and targets to a model should be from a `tf.data.Dataset`. + + # If each element of x and y are not tensors, we cannot standardize and + # validate the input and targets.` + if not tensor_util.is_tensor(x): + raise ValueError('Dataset input to the model should be tensors instead they' + ' are of type {}'.format(type(x))) + + if not tensor_util.is_tensor(y): + raise ValueError('Dataset input to the model should be tensors instead they' + ' are of type {}'.format(type(y))) + + # At this point both x and y contain tensors in the `DistributedValues` + # structure. + x_values = distribution_strategy.unwrap(x) + y_values = distribution_strategy.unwrap(y) + + # Validate that the shape and dtype of all the elements in x are the same. + validate_all_tensor_shapes(x, x_values) + validate_all_tensor_types(x, x_values) + + # Similarly for y, we perform the same validation + validate_all_tensor_shapes(y, y_values) + validate_all_tensor_types(y, y_values) + + # Return the unwrapped values to avoid calling `unwrap` a second time. + return x_values, y_values + + +def validate_all_tensor_types(x, x_values): + x_dtype = x_values[0].dtype + for i in range(1, len(x_values)): + if x_dtype != x_values[i].dtype: + raise ValueError('Input tensor dtypes do not match for distributed tensor' + ' inputs {}'.format(x)) + + +def validate_all_tensor_shapes(x, x_values): + # Validate that the shape of all the elements in x have the same shape + x_shape = x_values[0].get_shape().as_list() + for i in range(1, len(x_values)): + if x_shape != x_values[i].get_shape().as_list(): + raise ValueError('Input tensor shapes do not match for distributed tensor' + ' inputs {}'.format(x)) |