diff options
author | Shanqing Cai <cais@google.com> | 2018-10-02 20:04:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 20:09:17 -0700 |
commit | 9f42ebd5982688511ecc0ef7d23de02b64d8dd1e (patch) | |
tree | 374d5b0e449bea65178b0b94fec711890d4d3adb | |
parent | 2597b883a14749c77fffd7e5f9677107021ff40a (diff) |
Improve error messages and doc strings for eager-mode tf.keras.Model.fit() + tf.data objects
- Previously, when validation_steps was missing, the error message incorrectly says "please provide either batch_size or steps_per_epoch". Now it reads "please provide either batch_size or validation_steps".
- Some whitespace-related fixes.
PiperOrigin-RevId: 215503991
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 9 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_eager.py | 3 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_eager_test.py | 30 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_utils.py | 15 |
4 files changed, 49 insertions, 8 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index c842b8192e..85233de9b1 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -1419,6 +1419,8 @@ class Model(Network): - tuple `(x_val, y_val)` of Numpy arrays or tensors - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays - dataset or a dataset iterator + For the first two cases, `batch_size` must be provided. + For the last case, `validation_steps` must be provided. shuffle: Boolean (whether to shuffle the training data before each epoch) or str (for 'batch'). 'batch' is a special option for dealing with the @@ -1454,9 +1456,10 @@ class Model(Network): TensorFlow data tensors, the default `None` is equal to the number of samples in your dataset divided by the batch size, or 1 if that cannot be determined. - validation_steps: Only relevant if `steps_per_epoch` - is specified. Total number of steps (batches of samples) - to validate before stopping. + validation_steps: Only relevant if `validation_data` is provided and + is a dataset or dataset iterator. Total number of steps (batches of + samples) to draw before stopping when performing validation + at the end of every epoch. max_queue_size: Integer. Used for generator or `keras.utils.Sequence` input only. Maximum size for the generator queue. If unspecified, `max_queue_size` will default to 10. diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py index fb71bf2596..2a62edd698 100644 --- a/tensorflow/python/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -739,7 +739,8 @@ def test_loop(model, inputs, targets, y=targets, sample_weights=sample_weights, batch_size=batch_size, - steps_per_epoch=steps) + steps_per_epoch=steps, + is_validation=True) with backend.learning_phase_scope(0): return iterator_test_loop(model, inputs, steps, verbose=verbose) diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py index 1f5176c4d7..943ede1be9 100644 --- a/tensorflow/python/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/engine/training_eager_test.py @@ -125,6 +125,36 @@ class TrainingTest(test.TestCase): model.train_on_batch(inputs, targets) model.test_on_batch(inputs, targets) + def test_model_fit_and_validation_with_missing_arg_errors(self): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse') + + x = keras.backend.zeros(shape=(10, 3)) + y = keras.backend.zeros(shape=(10, 4)) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat(10).batch(5) + iterator = dataset.make_one_shot_iterator() + validation_dataset = dataset_ops.Dataset.from_tensor_slices( + (x, y)).repeat(10).batch(5) + validation_iterator = validation_dataset.make_one_shot_iterator() + + with self.assertRaisesRegexp( + ValueError, r'specify .* `steps_per_epoch`'): + model.fit(iterator, epochs=1, verbose=0) + with self.assertRaisesRegexp( + ValueError, r'provide either `batch_size` or `validation_steps`'): + model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0, + validation_data=(x, y)) + with self.assertRaisesRegexp( + ValueError, r'provide either `batch_size` or `validation_steps`'): + model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0, + validation_data=validation_dataset) + with self.assertRaisesRegexp( + ValueError, r'provide either `batch_size` or `validation_steps`'): + model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0, + validation_data=validation_iterator) + def test_generator_methods(self): model = keras.Sequential() model.add(keras.layers.Dense(4, input_shape=(3,))) diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 9c303f4bed..dd2a7f16ec 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -106,7 +106,8 @@ def convert_to_iterator(x=None, batch_size=None, steps_per_epoch=None, epochs=1, - shuffle=False): + shuffle=False, + is_validation=False): """Converts NumPy arrays or EagerTensors to an EagerIterator. Combines all provided data into a single EagerIterator. @@ -124,6 +125,9 @@ def convert_to_iterator(x=None, epoch. epochs: Epochs to repeat iterator for. shuffle: Whether to shuffle data after each epoch. + is_validation: Whether this call is for validation during a training + (e.g., `fit()`) call. This info is used to construct error messages + (if any). Raises: ValueError: if steps_per_epoch cannot be calculated from the data @@ -151,9 +155,12 @@ def convert_to_iterator(x=None, steps_per_epoch = int(math.ceil(num_samples / batch_size)) if steps_per_epoch is None: - raise ValueError('Could not determine steps_per_epoch.' - 'Please provide either batch_size or' - 'steps_per_epoch.') + alternative_arg_name = ( + 'validation_steps' if is_validation else 'steps_per_epoch') + raise ValueError( + 'Could not determine how to convert EagerTensors into EagerIterator. ' + 'Please provide either `batch_size` or ' + '`%s`.' % alternative_arg_name) # TODO(omalleyt) for NumPy arrays in graph mode # placeholder ops should be used |