diff options
Diffstat (limited to 'tensorflow/python/keras/utils/data_utils.py')
-rw-r--r-- | tensorflow/python/keras/utils/data_utils.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index c1ee34ae46..b736daa46d 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -40,6 +40,7 @@ from six.moves.urllib.error import URLError from six.moves.urllib.request import urlopen from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export @@ -93,6 +94,11 @@ else: from six.moves.urllib.request import urlretrieve +def is_generator_or_sequence(x): + """Check if `x` is a Keras generator type.""" + return tf_inspect.isgenerator(x) or isinstance(x, Sequence) + + def _extract_archive(file_path, path='.', archive_format='auto'): """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats. @@ -494,6 +500,7 @@ class SequenceEnqueuer(object): raise NotImplementedError +@tf_export('keras.utils.OrderedEnqueuer') class OrderedEnqueuer(SequenceEnqueuer): """Builds a Enqueuer from a Sequence. @@ -550,7 +557,7 @@ class OrderedEnqueuer(SequenceEnqueuer): self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda workers, initializer=init_pool, initargs=(seqs,)) else: - # We do not need the init since it's threads. + # We do not need the init since it's threads. self.executor_fn = lambda _: ThreadPool(workers) self.workers = workers self.queue = queue.Queue(max_queue_size) |