Diffstat (limited to 'tensorflow/python/keras/engine/distributed_training_utils.py')
-rw-r--r--  tensorflow/python/keras/engine/distributed_training_utils.py | 77
1 file changed, 76 insertions(+), 1 deletion(-)
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index c1c4970025..b28df75493 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_module
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
@@ -212,7 +213,10 @@ def validate_distributed_dataset_inputs(distribution_strategy, x, y):
   # validate the input and targets.
   x_values_list = validate_per_device_inputs(distribution_strategy, x)
-  y_values_list = validate_per_device_inputs(distribution_strategy, y)
+  if y is not None:
+    y_values_list = validate_per_device_inputs(distribution_strategy, y)
+  else:
+    y_values_list = None
   # Return the unwrapped values to avoid calling `unwrap` a second time.
   return x_values_list, y_values_list
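
The hunk above makes the targets optional: predict-style calls have model inputs but no targets, so `y` may be None. A minimal, self-contained sketch of that pattern, where `validate_fn` is a hypothetical stand-in for `validate_per_device_inputs` rather than anything from this module:

def unwrap_inputs_and_targets(validate_fn, x, y):
  # Targets are optional: `predict` passes inputs only, so y may be None here.
  x_values = validate_fn(x)
  y_values = validate_fn(y) if y is not None else None
  return x_values, y_values

inputs, targets = unwrap_inputs_and_targets(list, ('x0', 'x1'), None)
assert targets is None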
@@ -287,3 +291,74 @@ def configure_and_create_session(distribution_strategy):
session = session_module.Session(config=session_config)
K.set_session(session)
+
+
+def validate_inputs(x, y):
+  """Validate inputs when using DistributionStrategy.
+
+  Args:
+    x: Model Inputs.
+    y: Model Targets.
+
+  Raises:
+    ValueError: if input is not a Dataset or a numpy array.
+  """
+  if isinstance(x, list) or isinstance(y, list):
+    raise ValueError('DistributionStrategy does not support lists of numpy '
+                     'arrays. You must pass a Dataset object or a numpy array '
+                     'as input.')
+
+  if isinstance(x, dict) or isinstance(y, dict):
+    raise ValueError('DistributionStrategy does not support inputs of type '
+                     'dict. You must pass a Dataset object or a numpy array '
+                     'as input.')
+
+  if isinstance(x, iterator_ops.Iterator) or isinstance(
+      y, iterator_ops.Iterator):
+    raise ValueError('DistributionStrategy does not support inputs of type '
+                     'Iterator. You must pass a Dataset object or a numpy '
+                     'array as input.')
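
A hedged usage sketch for `validate_inputs`, assuming the module is importable from a TensorFlow checkout that includes this change; the numpy arrays are made-up data:

import numpy as np
from tensorflow.python.keras.engine import distributed_training_utils as dtu

x = np.zeros((64, 10))
y = np.zeros((64, 1))
dtu.validate_inputs(x, y)         # plain numpy arrays pass silently

try:
  dtu.validate_inputs([x, x], y)  # lists of numpy arrays are rejected
except ValueError as err:
  print(err)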
+
+
+def get_input_batch_params(first_x_value, batch_size, current_strategy):
+  """Calculate the number of batches and steps/steps_per_epoch.
+
+  Args:
+    first_x_value: This is the first input numpy array that is passed in as the
+      model input.
+    batch_size: The specified batch_size or the default batch_size of 32.
+    current_strategy: The current DistributionStrategy used to compile the
+      model.
+
+  Returns:
+    The steps or steps_per_epoch argument depending on if a user is
+    calling `fit`, `evaluate` or `predict`.
+
+  Raises:
+    ValueError: If the number of batches or steps evaluates to 0.
+  """
+  num_batches = first_x_value.shape[0] // batch_size
+  if not num_batches:
+    raise ValueError('Please specify a batch_size that is smaller than the '
+                     'number of input samples %d.' % first_x_value.shape[0])
+  # TODO(anjalisridhar): TPU currently supports using the num_towers property.
+  # We might want to look into implementing worker_devices. In multi worker
+  # strategy, perhaps num_towers works better?
+  steps = num_batches // current_strategy.num_towers
+  if not steps:
+    # TODO(anjalisridhar): Number of towers in the error message may not convey
+    # what we want to the user. Is there another terminology that we can use
+    # that is consistent across different strategies?
+    raise ValueError('The number of batches %d is smaller than the number '
+                     'of towers %d used for DistributionStrategy.' %
+                     (num_batches, current_strategy.num_towers))
+  return steps
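
The steps computation above is plain integer arithmetic; a small worked example with made-up numbers (1000 samples, a batch_size of 32, and a strategy with 4 towers), not tied to a real DistributionStrategy object:

import numpy as np

first_x_value = np.zeros((1000, 10))
batch_size = 32
num_towers = 4  # stand-in for current_strategy.num_towers

num_batches = first_x_value.shape[0] // batch_size  # 1000 // 32 == 31
steps = num_batches // num_towers                   # 31 // 4 == 7
print(num_batches, steps)                           # prints: 31 7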
+
+
+def get_batch_dimension(iterator):
+  """Return the static batch (first) dimension of the iterator's outputs."""
+  shapes = nest.flatten(iterator.output_shapes)
+  # Take the batch size from the first element, as it should be the same for
+  # all.
+  dims = shapes[0].dims
+  return dims[0] if dims else None
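
A sketch of `get_batch_dimension` against a tf.data iterator, assuming the TF 1.x graph-mode API (`Dataset.make_one_shot_iterator` and `Iterator.output_shapes`); with `drop_remainder=True` the batch dimension is static, otherwise it comes back as None:

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.engine import distributed_training_utils as dtu

dataset = tf.data.Dataset.from_tensor_slices(np.zeros((64, 10), np.float32))
dataset = dataset.batch(32, drop_remainder=True)
iterator = dataset.make_one_shot_iterator()
print(dtu.get_batch_dimension(iterator))  # static batch dimension: 32 here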