aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2018-05-24 11:11:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-24 11:14:25 -0700
commit677b4cf7539af0cf5741d12dfe7e142c586d4567 (patch)
tree144325983f817aa29198cfd89d71f30b60fce723
parent015c1d84f714c651f401a19cdb709ad9c91561e1 (diff)
Add shape validation for symbolic tensors passed to fit (only graph mode).
PiperOrigin-RevId: 197921675
-rw-r--r--tensorflow/python/keras/engine/training.py3
-rw-r--r--tensorflow/python/keras/engine/training_test.py31
-rw-r--r--tensorflow/python/keras/engine/training_utils.py14
3 files changed, 43 insertions, 5 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 0db805cc84..6d625f16c2 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -846,7 +846,8 @@ class Model(Network):
# in the case where all inputs are value arrays.
if context.executing_eagerly():
- # In eager mode, do not do shape validation.
+ # In eager mode, do not do shape validation
+ # since the network has no input nodes (placeholders) to be fed.
feed_input_names = self.input_names
feed_input_shapes = None
elif not self._is_graph_network:
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 222e3496c1..5c02d36382 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -1917,6 +1917,37 @@ class TestTrainingWithDataset(test.TestCase):
'you should specify the `steps` argument'):
model.predict(dataset, verbose=0)
+ def test_dataset_input_shape_validation(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ model.compile(optimizer, loss)
+
+ # User forgets to batch the dataset
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'expected input to have 2 dimensions'):
+ model.train_on_batch(dataset)
+
+ # Wrong input shape
+ inputs = np.zeros((10, 5), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'expected input to have shape'):
+ model.train_on_batch(dataset)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index c53948b902..b93f999444 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -166,10 +166,16 @@ def standardize_input_data(data,
# Check shapes compatibility.
if shapes:
for i in range(len(names)):
- if shapes[i] is not None and not tensor_util.is_tensor(data[i]):
- data_shape = data[i].shape
+ if shapes[i] is not None:
+ if tensor_util.is_tensor(data[i]):
+ tensorshape = data[i].get_shape()
+ if not tensorshape:
+ continue
+ data_shape = tuple(tensorshape.as_list())
+ else:
+ data_shape = data[i].shape
shape = shapes[i]
- if data[i].ndim != len(shape):
+ if len(data_shape) != len(shape):
raise ValueError('Error when checking ' + exception_prefix +
': expected ' + names[i] + ' to have ' +
str(len(shape)) + ' dimensions, but got array '
@@ -178,7 +184,7 @@ def standardize_input_data(data,
data_shape = data_shape[1:]
shape = shape[1:]
for dim, ref_dim in zip(data_shape, shape):
- if ref_dim != dim and ref_dim:
+ if ref_dim != dim and ref_dim is not None and dim is not None:
raise ValueError(
'Error when checking ' + exception_prefix + ': expected ' +
names[i] + ' to have shape ' + str(shape) +