aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/learn_io/generator_io.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/generator_io.py134
1 files changed, 134 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py
new file mode 100644
index 0000000000..5859bb6b47
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py
@@ -0,0 +1,134 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Methods to allow generator of dict with numpy arrays."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from types import FunctionType, GeneratorType
+from collections import Container
+
+from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_functions
+
+
+def generator_input_fn(x,
+ target_key=None,
+ batch_size=128,
+ num_epochs=1,
+ shuffle=True,
+ queue_capacity=1000,
+ num_threads=1):
+ """Returns input function that would dicts of numpy arrays
+ yielded from a generator.
+
+ It is assumed that every dict yielded from the dictionary represents
+ a single sample. The generator should consume a single epoch of the data.
+
+ This returns a function outputting `features` and `target` based on the dict
+ of numpy arrays. The dict `features` has the same keys as an element yielded
+ from x.
+
+ Example:
+ ```python
+ def generator():
+ for index in range(10):
+ yield {'height': np.random.randint(32,36),
+ 'age': np.random.randint(18, 80),
+ 'label': np.ones(1)}
+
+ with tf.Session() as session:
+ input_fn = generator_io.generator_input_fn(
+ generator, target_key="label", batch_size=2, shuffle=False,
+ num_epochs=1)
+ ```
+
+ Args:
+ x: Generator Function, returns a `Generator` that will yield the data
+ in `dict` of numpy arrays
+ target_key: String or Container of Strings, the key or Container of keys of
+ the numpy arrays in x dictionaries to use as target.
+ batch_size: Integer, size of batches to return.
+ num_epochs: Integer, number of epochs to iterate over data. If `None` will
+ run forever.
+ shuffle: Boolean, if True shuffles the queue. Avoid shuffle at prediction
+ time.
+ queue_capacity: Integer, size of queue to accumulate.
+ num_threads: Integer, number of threads used for reading and enqueueing.
+
+ Returns:
+ Function, that returns a feature `dict` with `Tensors` and an optional
+ label `dict` with `Tensors`, or if target_key is `str` label is a `Tensor`
+
+ Raises:
+ TypeError: `x` is not `FunctionType`.
+ TypeError: `x()` is not `GeneratorType`.
+ TypeError: `next(x())` is not `dict`.
+ TypeError: `target_key` is not `str` or `target_key` is not `Container`
+ of `str`.
+ KeyError: `target_key` not a key or `target_key[index]` not in next(`x()`).
+ KeyError: `key` mismatch between dicts emitted from `x()`
+ """
+ if not isinstance(x, FunctionType):
+ raise TypeError(
+ 'x must be generator function; got {}'.format(type(x).__name__))
+ generator = x()
+ if not isinstance(generator, GeneratorType):
+ raise TypeError(
+ 'x() must be generator; got {}'.format(type(generator).__name__))
+ data = next(generator)
+ if not isinstance(data, dict):
+ raise TypeError('x() must yield dict; got {}'.format(type(data).__name__))
+ input_keys = sorted(next(x()).keys())
+ if target_key is not None:
+ if isinstance(target_key, str):
+ target_key = [target_key]
+ elif isinstance(target_key, Container):
+ for item in target_key:
+ if not isinstance(item, str):
+ raise TypeError('target_key must be str or Container of str; got {}'.
+ format(type(item).__name__))
+ if item not in input_keys:
+ raise KeyError(
+ 'target_key not in yielded dict. Expected {} keys; got {}'.format(
+ input_keys, item))
+ else:
+ raise TypeError('target_key must be str or Container of str; got {}'.
+ format(type(target_key).__name__))
+
+ def _generator_input_fn():
+ """generator input function."""
+ queue = feeding_functions.enqueue_data(
+ x,
+ queue_capacity,
+ shuffle=shuffle,
+ num_threads=num_threads,
+ enqueue_size=batch_size,
+ num_epochs=num_epochs)
+
+ features = (queue.dequeue_many(batch_size)
+ if num_epochs is None else queue.dequeue_up_to(batch_size))
+ if not isinstance(features, list):
+ features = [features]
+ features = dict(zip(input_keys, features))
+ if target_key is not None:
+ if len(target_key) > 1:
+ target = {key: features.pop(key) for key in target_key}
+ else:
+ target = features.pop(target_key[0])
+ return features, target
+ return features
+
+ return _generator_input_fn