-rw-r--r--  tensorflow/contrib/distribute/python/keras_test.py             23
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py             2
-rw-r--r--  tensorflow/python/keras/engine/distributed_training_utils.py   16
-rw-r--r--  tensorflow/python/keras/engine/training.py                      12
-rw-r--r--  tensorflow/python/keras/engine/training_distributed.py          2
5 files changed, 47 insertions, 8 deletions
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 8165a70743..2e6cd43fd4 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -635,6 +635,29 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
'expected input to have shape'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
+  @combinations.generate(combinations.combine(
+      distribution=[combinations.tpu_strategy_one_step],
+      mode=['graph']))
+  def test_dataset_input_shape_fully_defined(self, distribution):
+    with self.cached_session():
+      x = keras.layers.Input(shape=(3,), name='input')
+      y = keras.layers.Dense(4, name='dense')(x)
+      model = keras.Model(x, y)
+
+      optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+      loss = 'mse'
+      model.compile(optimizer, loss, distribute=distribution)
+
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      targets = np.zeros((10, 4), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+      # Input shapes are not fully known. Batch dimension is unknown as we are
+      # not using the drop_remainder argument.
+      dataset = dataset.repeat(100).batch(10)
+
+      with self.assertRaisesRegexp(ValueError,
+                                   'requires fully defined shapes'):
+        model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
+
def test_learning_phase_value(self):
# TODO(anjalisridhar): Modify this test to use Lambdas since we can compare
# meaningful values. Currently we don't pass the learning phase if the
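
A usage note on the new test: the dataset is deliberately batched without drop_remainder, so the batch dimension stays unknown and the validation raises. Below is a minimal sketch of the passing counterpart, reusing the model, optimizer, and numpy arrays defined in the test above (illustration only, not part of this change):

inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
# drop_remainder=True makes the batch dimension static (10), so every output
# shape is fully defined and the new shape validation does not raise.
dataset = dataset.repeat(100).batch(10, drop_remainder=True)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)

Whether fit() then completes depends on an actual TPU being available; the point is only that the shape check is satisfied.
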
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index ba2cc2e806..a6762e5e87 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -158,7 +158,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
raise ValueError(
'TPU currently requires fully defined shapes. Either use '
'set_shape() on the input tensors or use '
- 'dataset.apply(map_and_batch(..., drop_remainder=True)).')
+ 'dataset.batch(..., drop_remainder=True).')
types = nest.flatten(iterator.output_types)
enqueue_ops = [
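
The replacement error message points users at the drop_remainder argument of Dataset.batch() instead of the map_and_batch transformation. A small illustration of the difference, as a sketch using the TF 1.x dataset API that appears elsewhere in this change:

import numpy as np
from tensorflow.python.data.ops import dataset_ops

data = np.zeros((10, 3), dtype=np.float32)
ds = dataset_ops.Dataset.from_tensor_slices(data).batch(4)
print(ds.output_shapes)   # (?, 3): the last batch may hold fewer than 4 elements
ds = dataset_ops.Dataset.from_tensor_slices(data).batch(4, drop_remainder=True)
print(ds.output_shapes)   # (4, 3): fully defined, which is what the TPU path needs
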
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index b28df75493..39341a931b 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_module
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
@@ -293,12 +294,14 @@ def configure_and_create_session(distribution_strategy):
K.set_session(session)
-def validate_inputs(x, y):
+def validate_inputs(x, y, distribution_strategy):
"""Validate inputs when using DistributionStrategy.
Args:
x: Model Inputs.
y: Model Targets.
+ distribution_strategy: The DistributionStrategy with which the model is
+ compiled.
Raises:
ValueError: if input is not a Dataset or a numpy array.
@@ -319,6 +322,17 @@ def validate_inputs(x, y):
'Iterator. You must pass a Dataset object or a numpy '
'array as input.')
+  if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+    for i in [x, y]:
+      if isinstance(i, dataset_ops.Dataset):
+        shapes = nest.flatten(i.output_shapes)
+        if any(not s.is_fully_defined() for s in shapes):
+          raise ValueError(
+              'Using TPUs currently requires fully defined shapes. Either use '
+              'set_shape() on the input tensors or use '
+              'dataset.batch(..., drop_remainder=True). '
+              'Found unknown shapes {} in input {}.'.format(shapes, i))
+
def get_input_batch_params(first_x_value, batch_size, current_strategy):
"""Calculate the number of batches and steps/steps_per_epoch.
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 154c219dcc..ade8a4b32d 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1521,7 +1521,8 @@ class Model(Network):
if self._distribution_strategy:
distributed_training_utils.validate_callbacks(callbacks)
- distributed_training_utils.validate_inputs(x, y)
+ distributed_training_utils.validate_inputs(
+ x, y, self._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if not steps_per_epoch and isinstance(first_x_value, np.ndarray):
@@ -1563,7 +1564,8 @@ class Model(Network):
# Validate and standardize validation data.
if self._distribution_strategy:
- distributed_training_utils.validate_inputs(val_x, val_y)
+ distributed_training_utils.validate_inputs(
+ val_x, val_y, self._distribution_strategy)
first_valx_value = nest.flatten(val_x)[0]
if not validation_steps and isinstance(first_valx_value, np.ndarray):
validation_steps = distributed_training_utils.get_input_batch_params(
@@ -1737,7 +1739,8 @@ class Model(Network):
# Validate and standardize user data.
if self._distribution_strategy:
- distributed_training_utils.validate_inputs(x, y)
+ distributed_training_utils.validate_inputs(
+ x, y, self._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray) and not steps:
steps = distributed_training_utils.get_input_batch_params(
@@ -1852,7 +1855,8 @@ class Model(Network):
# `MirroredStrategy`.
if hasattr(self._distribution_strategy, '_prefetch_on_device'):
self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access
- distributed_training_utils.validate_inputs(x, None)
+ distributed_training_utils.validate_inputs(
+ x, None, self._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray) and not steps:
steps = distributed_training_utils.get_input_batch_params(
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 26c5ec4efc..8b434ca444 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -233,8 +233,6 @@ def _experimental_fit_loop(
"""
current_strategy = model._distribution_strategy
- # TODO(priyag): Add validation that shapes are fully defined for TPU case.
-
K.get_session().run(current_strategy.initialize())
def _per_device_train_function(model):