diff options
Diffstat (limited to 'tensorflow/python/keras/engine/distributed_training_utils.py')
-rw-r--r-- | tensorflow/python/keras/engine/distributed_training_utils.py | 77 |
1 files changed, 76 insertions, 1 deletions
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py index c1c4970025..b28df75493 100644 --- a/tensorflow/python/keras/engine/distributed_training_utils.py +++ b/tensorflow/python/keras/engine/distributed_training_utils.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.client import session as session_module +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks @@ -212,7 +213,10 @@ def validate_distributed_dataset_inputs(distribution_strategy, x, y): # validate the input and targets. x_values_list = validate_per_device_inputs(distribution_strategy, x) - y_values_list = validate_per_device_inputs(distribution_strategy, y) + if y is not None: + y_values_list = validate_per_device_inputs(distribution_strategy, y) + else: + y_values_list = None # Return the unwrapped values to avoid calling `unwrap` a second time. return x_values_list, y_values_list @@ -287,3 +291,74 @@ def configure_and_create_session(distribution_strategy): session = session_module.Session(config=session_config) K.set_session(session) + + +def validate_inputs(x, y): + """Validate inputs when using DistributionStrategy. + + Args: + x: Model Inputs. + y: Model Targets. + + Raises: + ValueError: if input is not a Dataset or a numpy array. + """ + if isinstance(x, list) or isinstance(y, list): + raise ValueError('DistributionStrategy does not support lists of numpy' + 'arrays. You must pass a Dataset object or a numpy array ' + 'as input.') + + if isinstance(x, dict) or isinstance(y, dict): + raise ValueError('DistributionStrategy does not support inputs of type ' + 'dict. You must pass a Dataset object or a numpy array as ' + 'input.') + + if isinstance(x, iterator_ops.Iterator) or \ + isinstance(y, iterator_ops.Iterator): + raise ValueError('DistributionStrategy does not support inputs of type ' + 'Iterator. You must pass a Dataset object or a numpy ' + 'array as input.') + + +def get_input_batch_params(first_x_value, batch_size, current_strategy): + """Calculate the number of batches and steps/steps_per_epoch. + + Args: + first_x_value: This is the first input numpy array that is passed in as the + model input. + batch_size: The specified batch_size or the default batch_size of 32. + current_strategy: The current DistributionStrategy used to compile the + model. + + Returns: + The steps or steps_per_epoch argument depending on if a user is + calling `fit`, `evaluate` or `predict`. + + Raises: + ValueError: If the number of batches or steps evaluates to 0. + + """ + num_batches = first_x_value.shape[0] // batch_size + if not num_batches: + raise ValueError('Please specify a batch_size that is smaller than' + 'the number of input samples %d.' % first_x_value.shape[0]) + # TODO(anjalisridhar): TPU currently supports using the num_towers property. + # We might want to look into implementing worker_devices. In multi worker + # strategy, perhaps num_towers works better? + steps = num_batches // current_strategy.num_towers + if not steps: + # TODO(anjalisridhar): Number of towers in the error message may not convey + # what we want to the user. Is there another terminology that we can use + # that is consistent across different strategies. + raise ValueError('The number of batches %d is smaller than the number ' + 'of towers %d used for DistributionStrategy. ' % + num_batches, current_strategy.num_towers) + return steps + + +def get_batch_dimension(iterator): + shapes = nest.flatten(iterator.output_shapes) + # Take the batch size from the first element, as it should be the same for + # all. + dims = shapes[0].dims + return dims[0] if dims else None |