aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/utils/data_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/utils/data_utils.py')
-rw-r--r--tensorflow/python/keras/utils/data_utils.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py
index c1ee34ae46..b736daa46d 100644
--- a/tensorflow/python/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/utils/data_utils.py
@@ -40,6 +40,7 @@ from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlopen
from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -93,6 +94,11 @@ else:
from six.moves.urllib.request import urlretrieve
+def is_generator_or_sequence(x):
+ """Check if `x` is a Keras generator type."""
+ return tf_inspect.isgenerator(x) or isinstance(x, Sequence)
+
+
def _extract_archive(file_path, path='.', archive_format='auto'):
"""Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
@@ -494,6 +500,7 @@ class SequenceEnqueuer(object):
raise NotImplementedError
+@tf_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
"""Builds a Enqueuer from a Sequence.
@@ -550,7 +557,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda
workers, initializer=init_pool, initargs=(seqs,))
else:
- # We do not need the init since it's threads.
+ # We do not need the init since it's threads.
self.executor_fn = lambda _: ThreadPool(workers)
self.workers = workers
self.queue = queue.Queue(max_queue_size)