aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/training_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/training_utils.py')
-rw-r--r--tensorflow/python/keras/engine/training_utils.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 898e9223cb..8e9fab81d6 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -797,6 +797,18 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
'Received: x=%s, validation_split=%f' % (x, validation_split))
+def check_generator_arguments(y=None, sample_weight=None):
+ """Validates arguments passed when using a generator."""
+ if y is not None:
+ raise ValueError('`y` argument is not supported when data is'
+ 'a generator or Sequence instance. Instead pass targets'
+ ' as the second element of the generator.')
+ if sample_weight is not None:
+ raise ValueError('`sample_weight` argument is not supported when data is'
+ 'a generator or Sequence instance. Instead pass sample'
+ ' weights as the third element of the generator.')
+
+
def check_steps_argument(input_data, steps, steps_name):
"""Validates `steps` argument based on input data's type.