path: root/tensorflow/python/keras
author     Priya Gupta <priyag@google.com>                    2018-09-24 20:22:28 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-24 20:29:54 -0700
commit     6ba60e051409a5346c2aab21160c9c311de1cb03 (patch)
tree       955be96a46d13601582343a25ae3612ad53179d7 /tensorflow/python/keras
parent     4dc77744ff6a6854cf4aa2934eb4501bc22c3465 (diff)
Add validation that input shapes are fully defined when using TPU strategy with Keras.
PiperOrigin-RevId: 214376435
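The new check keys off whether every component of a tf.data Dataset's output_shapes is fully defined; batching with drop_remainder=True is what makes the batch dimension static. A minimal sketch of that distinction against the TF 1.x tf.data API (illustrative only, not part of this change):

import tensorflow as tf

# A dataset of 10 scalars, batched in groups of 3.
ds = tf.data.Dataset.range(10)

# Without drop_remainder the final batch may be smaller, so the batch
# dimension is unknown and the shape is not fully defined.
dynamic = ds.batch(3)
print(dynamic.output_shapes)                     # (?,)

# Dropping the remainder makes every batch exactly 3 elements, so the
# shape is static and satisfies the new TPUStrategy validation.
static = ds.batch(3, drop_remainder=True)
print(static.output_shapes)                      # (3,)
print(static.output_shapes.is_fully_defined())   # True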
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--   tensorflow/python/keras/engine/distributed_training_utils.py   16
-rw-r--r--   tensorflow/python/keras/engine/training.py                     12
-rw-r--r--   tensorflow/python/keras/engine/training_distributed.py          2
3 files changed, 23 insertions, 7 deletions
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index b28df75493..39341a931b 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_module
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
@@ -293,12 +294,14 @@ def configure_and_create_session(distribution_strategy):
  K.set_session(session)
-def validate_inputs(x, y):
+def validate_inputs(x, y, distribution_strategy):
  """Validate inputs when using DistributionStrategy.
  Args:
    x: Model Inputs.
    y: Model Targets.
+    distribution_strategy: The DistributionStrategy with which the model is
+      compiled.
  Raises:
    ValueError: if input is not a Dataset or a numpy array.
@@ -319,6 +322,17 @@ def validate_inputs(x, y):
                     'Iterator. You must pass a Dataset object or a numpy '
                     'array as input.')
+  if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+    for i in [x, y]:
+      if isinstance(i, dataset_ops.Dataset):
+        shapes = nest.flatten(i.output_shapes)
+        if any([not s.is_fully_defined() for s in shapes]):
+          raise ValueError(
+              'Using TPUs currently requires fully defined shapes. Either use '
+              'set_shape() on the input tensors or use '
+              'dataset.batch(..., drop_remainder=True). '
+              'Found unknown shapes {} in input {}.'.format(shapes, i))
+
def get_input_batch_params(first_x_value, batch_size, current_strategy):
"""Calculate the number of batches and steps/steps_per_epoch.
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 154c219dcc..ade8a4b32d 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1521,7 +1521,8 @@ class Model(Network):
    if self._distribution_strategy:
      distributed_training_utils.validate_callbacks(callbacks)
-      distributed_training_utils.validate_inputs(x, y)
+      distributed_training_utils.validate_inputs(
+          x, y, self._distribution_strategy)
      first_x_value = nest.flatten(x)[0]
      if not steps_per_epoch and isinstance(first_x_value, np.ndarray):
@@ -1563,7 +1564,8 @@ class Model(Network):
      # Validate and standardize validation data.
      if self._distribution_strategy:
-        distributed_training_utils.validate_inputs(val_x, val_y)
+        distributed_training_utils.validate_inputs(
+            val_x, val_y, self._distribution_strategy)
        first_valx_value = nest.flatten(val_x)[0]
        if not validation_steps and isinstance(first_valx_value, np.ndarray):
          validation_steps = distributed_training_utils.get_input_batch_params(
@@ -1737,7 +1739,8 @@ class Model(Network):
    # Validate and standardize user data.
    if self._distribution_strategy:
-      distributed_training_utils.validate_inputs(x, y)
+      distributed_training_utils.validate_inputs(
+          x, y, self._distribution_strategy)
      first_x_value = nest.flatten(x)[0]
      if isinstance(first_x_value, np.ndarray) and not steps:
        steps = distributed_training_utils.get_input_batch_params(
@@ -1852,7 +1855,8 @@ class Model(Network):
      # `MirroredStrategy`.
      if hasattr(self._distribution_strategy, '_prefetch_on_device'):
        self._distribution_strategy._prefetch_on_device = False  # pylint: disable=protected-access
-      distributed_training_utils.validate_inputs(x, None)
+      distributed_training_utils.validate_inputs(
+          x, None, self._distribution_strategy)
      first_x_value = nest.flatten(x)[0]
      if isinstance(first_x_value, np.ndarray) and not steps:
        steps = distributed_training_utils.get_input_batch_params(
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 26c5ec4efc..8b434ca444 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -233,8 +233,6 @@ def _experimental_fit_loop(
"""
current_strategy = model._distribution_strategy
- # TODO(priyag): Add validation that shapes are fully defined for TPU case.
-
K.get_session().run(current_strategy.initialize())
def _per_device_train_function(model):
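Stripped of the Keras plumbing, the validation added in the first file boils down to flattening a dataset's output_shapes and checking each TensorShape. A standalone sketch of that pattern under the TF 1.x API (the helper below is hypothetical and only mirrors the check above):

import tensorflow as tf
from tensorflow.python.util import nest  # same private util the patch relies on


def check_fully_defined(dataset):
  """Raise if any component shape of `dataset` is not fully defined."""
  for shape in nest.flatten(dataset.output_shapes):
    if not shape.is_fully_defined():
      raise ValueError(
          'Using TPUs currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.batch(..., drop_remainder=True). '
          'Found unknown shape {} in input {}.'.format(shape, dataset))


features = [[0.0] * 4] * 10   # 10 examples with 4 features each
labels = [0] * 10
ds = tf.data.Dataset.from_tensor_slices((features, labels))

check_fully_defined(ds.batch(3, drop_remainder=True))   # passes
check_fully_defined(ds.batch(3))                        # raises ValueError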