diff options
author | Priya Gupta <priyag@google.com> | 2018-09-24 20:22:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 20:29:54 -0700 |
commit | 6ba60e051409a5346c2aab21160c9c311de1cb03 (patch) | |
tree | 955be96a46d13601582343a25ae3612ad53179d7 /tensorflow/python/keras | |
parent | 4dc77744ff6a6854cf4aa2934eb4501bc22c3465 (diff) |
Add validation that input shapes should be fully defined when using TPU strategy with keras.
PiperOrigin-RevId: 214376435
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/engine/distributed_training_utils.py | 16 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 12 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_distributed.py | 2 |
3 files changed, 23 insertions, 7 deletions
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py index b28df75493..39341a931b 100644 --- a/tensorflow/python/keras/engine/distributed_training_utils.py +++ b/tensorflow/python/keras/engine/distributed_training_utils.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.client import session as session_module +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K @@ -293,12 +294,14 @@ def configure_and_create_session(distribution_strategy): K.set_session(session) -def validate_inputs(x, y): +def validate_inputs(x, y, distribution_strategy): """Validate inputs when using DistributionStrategy. Args: x: Model Inputs. y: Model Targets. + distribution_strategy: The DistributionStrategy with which the model is + compiled. Raises: ValueError: if input is not a Dataset or a numpy array. @@ -319,6 +322,17 @@ def validate_inputs(x, y): 'Iterator. You must pass a Dataset object or a numpy ' 'array as input.') + if distribution_strategy.__class__.__name__ == 'TPUStrategy': + for i in [x, y]: + if isinstance(i, dataset_ops.Dataset): + shapes = nest.flatten(i.output_shapes) + if any([not s.is_fully_defined() for s in shapes]): + raise ValueError( + 'Using TPUs currently requires fully defined shapes. Either use ' + 'set_shape() on the input tensors or use ' + 'dataset.batch(..., drop_remainder=True).' + 'Found unknown shape {} in input {}.'.format(s, i)) + def get_input_batch_params(first_x_value, batch_size, current_strategy): """Calculate the number of batches and steps/steps_per_epoch. diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 154c219dcc..ade8a4b32d 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -1521,7 +1521,8 @@ class Model(Network): if self._distribution_strategy: distributed_training_utils.validate_callbacks(callbacks) - distributed_training_utils.validate_inputs(x, y) + distributed_training_utils.validate_inputs( + x, y, self._distribution_strategy) first_x_value = nest.flatten(x)[0] if not steps_per_epoch and isinstance(first_x_value, np.ndarray): @@ -1563,7 +1564,8 @@ class Model(Network): # Validate and standardize validation data. if self._distribution_strategy: - distributed_training_utils.validate_inputs(val_x, val_y) + distributed_training_utils.validate_inputs( + val_x, val_y, self._distribution_strategy) first_valx_value = nest.flatten(val_x)[0] if not validation_steps and isinstance(first_valx_value, np.ndarray): validation_steps = distributed_training_utils.get_input_batch_params( @@ -1737,7 +1739,8 @@ class Model(Network): # Validate and standardize user data. if self._distribution_strategy: - distributed_training_utils.validate_inputs(x, y) + distributed_training_utils.validate_inputs( + x, y, self._distribution_strategy) first_x_value = nest.flatten(x)[0] if isinstance(first_x_value, np.ndarray) and not steps: steps = distributed_training_utils.get_input_batch_params( @@ -1852,7 +1855,8 @@ class Model(Network): # `MirroredStrategy`. if hasattr(self._distribution_strategy, '_prefetch_on_device'): self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access - distributed_training_utils.validate_inputs(x, None) + distributed_training_utils.validate_inputs( + x, None, self._distribution_strategy) first_x_value = nest.flatten(x)[0] if isinstance(first_x_value, np.ndarray) and not steps: steps = distributed_training_utils.get_input_batch_params( diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 26c5ec4efc..8b434ca444 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -233,8 +233,6 @@ def _experimental_fit_loop( """ current_strategy = model._distribution_strategy - # TODO(priyag): Add validation that shapes are fully defined for TPU case. - K.get_session().run(current_strategy.initialize()) def _per_device_train_function(model): |