aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2018-10-02 20:04:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 20:09:17 -0700
commit9f42ebd5982688511ecc0ef7d23de02b64d8dd1e (patch)
tree374d5b0e449bea65178b0b94fec711890d4d3adb /tensorflow/python/keras
parent2597b883a14749c77fffd7e5f9677107021ff40a (diff)
Improve error messages and doc strings for eager-mode tf.keras.Model.fit() + tf.data objects
- Previously, when validation_steps was missing, the error message incorrectly says "please provide either batch_size or steps_per_epoch". Now it reads "please provide either batch_size or validation_steps". - Some whitespace-related fixes. PiperOrigin-RevId: 215503991
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--tensorflow/python/keras/engine/training.py9
-rw-r--r--tensorflow/python/keras/engine/training_eager.py3
-rw-r--r--tensorflow/python/keras/engine/training_eager_test.py30
-rw-r--r--tensorflow/python/keras/engine/training_utils.py15
4 files changed, 49 insertions, 8 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index c842b8192e..85233de9b1 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1419,6 +1419,8 @@ class Model(Network):
- tuple `(x_val, y_val)` of Numpy arrays or tensors
- tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
- dataset or a dataset iterator
+ For the first two cases, `batch_size` must be provided.
+ For the last case, `validation_steps` must be provided.
shuffle: Boolean (whether to shuffle the training data
before each epoch) or str (for 'batch').
'batch' is a special option for dealing with the
@@ -1454,9 +1456,10 @@ class Model(Network):
TensorFlow data tensors, the default `None` is equal to
the number of samples in your dataset divided by
the batch size, or 1 if that cannot be determined.
- validation_steps: Only relevant if `steps_per_epoch`
- is specified. Total number of steps (batches of samples)
- to validate before stopping.
+ validation_steps: Only relevant if `validation_data` is provided and
+ is a dataset or dataset iterator. Total number of steps (batches of
+ samples) to draw before stopping when performing validation
+ at the end of every epoch.
max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
input only. Maximum size for the generator queue.
If unspecified, `max_queue_size` will default to 10.
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index fb71bf2596..2a62edd698 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -739,7 +739,8 @@ def test_loop(model, inputs, targets,
y=targets,
sample_weights=sample_weights,
batch_size=batch_size,
- steps_per_epoch=steps)
+ steps_per_epoch=steps,
+ is_validation=True)
with backend.learning_phase_scope(0):
return iterator_test_loop(model, inputs, steps, verbose=verbose)
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index 1f5176c4d7..943ede1be9 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -125,6 +125,36 @@ class TrainingTest(test.TestCase):
model.train_on_batch(inputs, targets)
model.test_on_batch(inputs, targets)
+ def test_model_fit_and_validation_with_missing_arg_errors(self):
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+ model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
+
+ x = keras.backend.zeros(shape=(10, 3))
+ y = keras.backend.zeros(shape=(10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat(10).batch(5)
+ iterator = dataset.make_one_shot_iterator()
+ validation_dataset = dataset_ops.Dataset.from_tensor_slices(
+ (x, y)).repeat(10).batch(5)
+ validation_iterator = validation_dataset.make_one_shot_iterator()
+
+ with self.assertRaisesRegexp(
+ ValueError, r'specify .* `steps_per_epoch`'):
+ model.fit(iterator, epochs=1, verbose=0)
+ with self.assertRaisesRegexp(
+ ValueError, r'provide either `batch_size` or `validation_steps`'):
+ model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
+ validation_data=(x, y))
+ with self.assertRaisesRegexp(
+ ValueError, r'provide either `batch_size` or `validation_steps`'):
+ model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
+ validation_data=validation_dataset)
+ with self.assertRaisesRegexp(
+ ValueError, r'provide either `batch_size` or `validation_steps`'):
+ model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
+ validation_data=validation_iterator)
+
def test_generator_methods(self):
model = keras.Sequential()
model.add(keras.layers.Dense(4, input_shape=(3,)))
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 9c303f4bed..dd2a7f16ec 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -106,7 +106,8 @@ def convert_to_iterator(x=None,
batch_size=None,
steps_per_epoch=None,
epochs=1,
- shuffle=False):
+ shuffle=False,
+ is_validation=False):
"""Converts NumPy arrays or EagerTensors to an EagerIterator.
Combines all provided data into a single EagerIterator.
@@ -124,6 +125,9 @@ def convert_to_iterator(x=None,
epoch.
epochs: Epochs to repeat iterator for.
shuffle: Whether to shuffle data after each epoch.
+ is_validation: Whether this call is for validation during a training
+ (e.g., `fit()`) call. This info is used to construct error messages
+ (if any).
Raises:
ValueError: if steps_per_epoch cannot be calculated from the data
@@ -151,9 +155,12 @@ def convert_to_iterator(x=None,
steps_per_epoch = int(math.ceil(num_samples / batch_size))
if steps_per_epoch is None:
- raise ValueError('Could not determine steps_per_epoch.'
- 'Please provide either batch_size or'
- 'steps_per_epoch.')
+ alternative_arg_name = (
+ 'validation_steps' if is_validation else 'steps_per_epoch')
+ raise ValueError(
+ 'Could not determine how to convert EagerTensors into EagerIterator. '
+ 'Please provide either `batch_size` or '
+ '`%s`.' % alternative_arg_name)
# TODO(omalleyt) for NumPy arrays in graph mode
# placeholder ops should be used