diff options
author | Jianwei Xie <xiejw@google.com> | 2017-02-25 13:31:27 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-25 14:02:04 -0800 |
commit | 0431511cedec7b3173576399de951c7e11360c4a (patch) | |
tree | fe74210504bdb414867d38275025d05b1278492f /tensorflow/python/estimator/inputs | |
parent | 07427d1b51713a085f06c62b203799490591ed80 (diff) |
Move numpy_input_fn and pandas_input_fn from contrib to core.
Change: 148560715
Diffstat (limited to 'tensorflow/python/estimator/inputs')
11 files changed, 1754 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/inputs/__init__.py b/tensorflow/python/estimator/inputs/__init__.py new file mode 100644 index 0000000000..c7bfbf562d --- /dev/null +++ b/tensorflow/python/estimator/inputs/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Methods to create input_fn.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn +from tensorflow.python.estimator.inputs.pandas_import import HAS_PANDAS +from tensorflow.python.estimator.inputs.pandas_io import pandas_input_fn diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py new file mode 100644 index 0000000000..e25e29bd9c --- /dev/null +++ b/tensorflow/python/estimator/inputs/numpy_io.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Methods to allow dict of numpy arrays.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from tensorflow.python.estimator.inputs.queues import feeding_functions + +# Key name to pack the target into dict of `features`. See +# `_get_unique_target_key` for details. +_TARGET_KEY = '__target_key__' + + +def _get_unique_target_key(features): + """Returns a key not existed in the input dict `features`. + + Caller of `input_fn` usually provides `features` (dict of numpy arrays) and + `target`, but the underlying feeding module expects a single dict of numpy + arrays as input. So, the `target` needs to be packed into the `features` + temporarily and unpacked after calling the feeding function. Toward this goal, + this function returns a key not existed in the `features` to pack the + `target`. + """ + target_key = _TARGET_KEY + while target_key in features: + target_key += '_n' + return target_key + + +def numpy_input_fn(x, + y=None, + batch_size=128, + num_epochs=1, + shuffle=True, + queue_capacity=1000, + num_threads=1): + """Returns input function that would feed dict of numpy arrays into the model. + + This returns a function outputting `features` and `target` based on the dict + of numpy arrays. The dict `features` has the same keys as the `x`. + + Example: + ```python + age = np.arange(4) * 1.0 + height = np.arange(32, 36) + x = {'age': age, 'height': height} + y = np.arange(-32, -28) + + with tf.Session() as session: + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + ``` + + Args: + x: dict of numpy array object. + y: numpy array object. + batch_size: Integer, size of batches to return. + num_epochs: Integer, number of epochs to iterate over data. If `None` will + run forever. + shuffle: Boolean, if True shuffles the queue. Avoid shuffle at prediction + time. + queue_capacity: Integer, size of queue to accumulate. + num_threads: Integer, number of threads used for reading and enqueueing. + + Returns: + Function, that has signature of ()->(dict of `features`, `target`) + + Raises: + ValueError: if the shape of `y` mismatches the shape of values in `x` (i.e., + values in `x` have same shape). + TypeError: `x` is not a dict. + """ + + def input_fn(): + """Numpy input function.""" + if not isinstance(x, dict): + raise TypeError('x must be dict; got {}'.format(type(x).__name__)) + + unique_target_key = _get_unique_target_key(x) + if y is not None: + x[unique_target_key] = y + + if len(set(v.shape[0] for v in x.values())) != 1: + shape_dict_of_x = {k: x[k].shape for k in x.keys()} + shape_of_y = None if y is None else y.shape + raise ValueError('Length of tensors in x and y is mismatched. All ' + 'elements in x and y must have the same length.\n' + 'Shapes in x: {}\n' + 'Shape for y: {}\n'.format(shape_dict_of_x, shape_of_y)) + + # Ensure the order of iteration is consistent. + ordered_dict_x = collections.OrderedDict( + sorted(x.items(), key=lambda t: t[0])) + + queue = feeding_functions._enqueue_data( # pylint: disable=protected-access + ordered_dict_x, + queue_capacity, + shuffle=shuffle, + num_threads=num_threads, + enqueue_size=batch_size, + num_epochs=num_epochs) + + features = (queue.dequeue_many(batch_size) if num_epochs is None + else queue.dequeue_up_to(batch_size)) + + # Remove the first `Tensor` in `features`, which is the row number. + if len(features) > 0: + features.pop(0) + + features = dict(zip(ordered_dict_x.keys(), features)) + if y is not None: + target = features.pop(unique_target_key) + return features, target + return features + + return input_fn diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py new file mode 100644 index 0000000000..e30ce5515f --- /dev/null +++ b/tensorflow/python/estimator/inputs/numpy_io_test.py @@ -0,0 +1,280 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for numpy_io.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.framework import errors +from tensorflow.python.platform import test +from tensorflow.python.training import coordinator +from tensorflow.python.training import queue_runner_impl + + +class NumpyIoTest(test.TestCase): + + def testNumpyInputFn(self): + a = np.arange(4) * 1.0 + b = np.arange(32, 36) + x = {'a': a, 'b': b} + y = np.arange(-32, -28) + + with self.test_session() as session: + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + features, target = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [0, 1]) + self.assertAllEqual(res[0]['b'], [32, 33]) + self.assertAllEqual(res[1], [-32, -31]) + + session.run([features, target]) + with self.assertRaises(errors.OutOfRangeError): + session.run([features, target]) + + coord.request_stop() + coord.join(threads) + + def testNumpyInputFnWithVeryLargeBatchSizeAndMultipleEpochs(self): + a = np.arange(2) * 1.0 + b = np.arange(32, 34) + x = {'a': a, 'b': b} + y = np.arange(-32, -30) + + with self.test_session() as session: + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=128, shuffle=False, num_epochs=2) + features, target = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [0, 1, 0, 1]) + self.assertAllEqual(res[0]['b'], [32, 33, 32, 33]) + self.assertAllEqual(res[1], [-32, -31, -32, -31]) + + with self.assertRaises(errors.OutOfRangeError): + session.run([features, target]) + + coord.request_stop() + coord.join(threads) + + def testNumpyInputFnWithZeroEpochs(self): + a = np.arange(4) * 1.0 + b = np.arange(32, 36) + x = {'a': a, 'b': b} + y = np.arange(-32, -28) + + with self.test_session() as session: + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=0) + features, target = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + with self.assertRaises(errors.OutOfRangeError): + session.run([features, target]) + + coord.request_stop() + coord.join(threads) + + def testNumpyInputFnWithBatchSizeNotDividedByDataSize(self): + batch_size = 2 + a = np.arange(5) * 1.0 + b = np.arange(32, 37) + x = {'a': a, 'b': b} + y = np.arange(-32, -27) + + with self.test_session() as session: + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=batch_size, shuffle=False, num_epochs=1) + features, target = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [0, 1]) + self.assertAllEqual(res[0]['b'], [32, 33]) + self.assertAllEqual(res[1], [-32, -31]) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [2, 3]) + self.assertAllEqual(res[0]['b'], [34, 35]) + self.assertAllEqual(res[1], [-30, -29]) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [4]) + self.assertAllEqual(res[0]['b'], [36]) + self.assertAllEqual(res[1], [-28]) + + with self.assertRaises(errors.OutOfRangeError): + session.run([features, target]) + + coord.request_stop() + coord.join(threads) + + def testNumpyInputFnWithBatchSizeNotDividedByDataSizeAndMultipleEpochs(self): + batch_size = 2 + a = np.arange(3) * 1.0 + b = np.arange(32, 35) + x = {'a': a, 'b': b} + y = np.arange(-32, -29) + + with self.test_session() as session: + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=batch_size, shuffle=False, num_epochs=3) + features, target = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [0, 1]) + self.assertAllEqual(res[0]['b'], [32, 33]) + self.assertAllEqual(res[1], [-32, -31]) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [2, 0]) + self.assertAllEqual(res[0]['b'], [34, 32]) + self.assertAllEqual(res[1], [-30, -32]) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [1, 2]) + self.assertAllEqual(res[0]['b'], [33, 34]) + self.assertAllEqual(res[1], [-31, -30]) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [0, 1]) + self.assertAllEqual(res[0]['b'], [32, 33]) + self.assertAllEqual(res[1], [-32, -31]) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [2]) + self.assertAllEqual(res[0]['b'], [34]) + self.assertAllEqual(res[1], [-30]) + + with self.assertRaises(errors.OutOfRangeError): + session.run([features, target]) + + coord.request_stop() + coord.join(threads) + + def testNumpyInputFnWithBatchSizeLargerThanDataSize(self): + batch_size = 10 + a = np.arange(4) * 1.0 + b = np.arange(32, 36) + x = {'a': a, 'b': b} + y = np.arange(-32, -28) + + with self.test_session() as session: + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=batch_size, shuffle=False, num_epochs=1) + features, target = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [0, 1, 2, 3]) + self.assertAllEqual(res[0]['b'], [32, 33, 34, 35]) + self.assertAllEqual(res[1], [-32, -31, -30, -29]) + + with self.assertRaises(errors.OutOfRangeError): + session.run([features, target]) + + coord.request_stop() + coord.join(threads) + + def testNumpyInputFnWithDifferentDimensionsOfFeatures(self): + a = np.array([[1, 2], [3, 4]]) + b = np.array([5, 6]) + x = {'a': a, 'b': b} + y = np.arange(-32, -30) + + with self.test_session() as session: + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + features, target = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + res = session.run([features, target]) + self.assertAllEqual(res[0]['a'], [[1, 2], [3, 4]]) + self.assertAllEqual(res[0]['b'], [5, 6]) + self.assertAllEqual(res[1], [-32, -31]) + + coord.request_stop() + coord.join(threads) + + def testNumpyInputFnWithXAsNonDict(self): + x = np.arange(32, 36) + y = np.arange(4) + with self.test_session(): + with self.assertRaisesRegexp(TypeError, 'x must be dict'): + failing_input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + failing_input_fn() + + def testNumpyInputFnWithTargetKeyAlreadyInX(self): + array = np.arange(32, 36) + x = {'__target_key__': array} + y = np.arange(4) + + with self.test_session(): + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + input_fn() + self.assertAllEqual(x['__target_key__'], array) + self.assertAllEqual(x['__target_key___n'], y) + + def testNumpyInputFnWithMismatchLengthOfInputs(self): + a = np.arange(4) * 1.0 + b = np.arange(32, 36) + x = {'a': a, 'b': b} + x_mismatch_length = {'a': np.arange(1), 'b': b} + y_longer_length = np.arange(10) + + with self.test_session(): + with self.assertRaisesRegexp( + ValueError, 'Length of tensors in x and y is mismatched.'): + failing_input_fn = numpy_io.numpy_input_fn( + x, y_longer_length, batch_size=2, shuffle=False, num_epochs=1) + failing_input_fn() + + with self.assertRaisesRegexp( + ValueError, 'Length of tensors in x and y is mismatched.'): + failing_input_fn = numpy_io.numpy_input_fn( + x=x_mismatch_length, + y=None, + batch_size=2, + shuffle=False, + num_epochs=1) + failing_input_fn() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/estimator/inputs/pandas_import.py b/tensorflow/python/estimator/inputs/pandas_import.py new file mode 100644 index 0000000000..6f78a16847 --- /dev/null +++ b/tensorflow/python/estimator/inputs/pandas_import.py @@ -0,0 +1,32 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Handles pandas import for tensorflow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as _ # pylint: disable=unused-import + +try: + # pylint: disable=g-import-not-at-top + # pylint: disable=unused-import + import pandas as _ + HAS_PANDAS = True +except IOError: + # Pandas writes a temporary file during import. If it fails, don't use pandas. + HAS_PANDAS = False +except ImportError: + HAS_PANDAS = False diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py new file mode 100644 index 0000000000..914845ba6a --- /dev/null +++ b/tensorflow/python/estimator/inputs/pandas_io.py @@ -0,0 +1,104 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Methods to allow pandas.DataFrame.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.estimator.inputs.pandas_import import HAS_PANDAS +from tensorflow.python.estimator.inputs.queues import feeding_functions + + +def pandas_input_fn(x, + y=None, + batch_size=128, + num_epochs=1, + shuffle=True, + queue_capacity=1000, + num_threads=1, + target_column='target'): + """Returns input function that would feed Pandas DataFrame into the model. + + Note: `y`'s index must match `x`'s index. + + Args: + x: pandas `DataFrame` object. + y: pandas `Series` object. + batch_size: int, size of batches to return. + num_epochs: int, number of epochs to iterate over data. If not `None`, + read attempts that would exceed this value will raise `OutOfRangeError`. + shuffle: bool, whether to read the records in random order. + queue_capacity: int, size of the read queue. If `None`, it will be set + roughly to the size of `x`. + num_threads: int, number of threads used for reading and enqueueing. + target_column: str, name to give the target column `y`. + + Returns: + Function, that has signature of ()->(dict of `features`, `target`) + + Raises: + ValueError: if `x` already contains a column with the same name as `y`, or + if the indexes of `x` and `y` don't match. + """ + if not HAS_PANDAS: + raise TypeError( + 'pandas_input_fn should not be called without pandas installed') + + x = x.copy() + if y is not None: + if target_column in x: + raise ValueError( + 'Cannot use name %s for target column: DataFrame already has a ' + 'column with that name: %s' % (target_column, x.columns)) + if not np.array_equal(x.index, y.index): + raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n' + 'Index for y: %s\n' % (x.index, y.index)) + x[target_column] = y + + # TODO(mdan): These are memory copies. We probably don't need 4x slack space. + # The sizes below are consistent with what I've seen elsewhere. + if queue_capacity is None: + if shuffle: + queue_capacity = 4 * len(x) + else: + queue_capacity = len(x) + min_after_dequeue = max(queue_capacity / 4, 1) + + def input_fn(): + """Pandas input function.""" + queue = feeding_functions._enqueue_data( # pylint: disable=protected-access + x, + queue_capacity, + shuffle=shuffle, + min_after_dequeue=min_after_dequeue, + num_threads=num_threads, + enqueue_size=batch_size, + num_epochs=num_epochs) + if num_epochs is None: + features = queue.dequeue_many(batch_size) + else: + features = queue.dequeue_up_to(batch_size) + assert len(features) == len(x.columns) + 1, ('Features should have one ' + 'extra element for the index.') + features = features[1:] + features = dict(zip(list(x.columns), features)) + if y is not None: + target = features.pop(target_column) + return features, target + return features + return input_fn diff --git a/tensorflow/python/estimator/inputs/pandas_io_test.py b/tensorflow/python/estimator/inputs/pandas_io_test.py new file mode 100644 index 0000000000..2e1fee4dd8 --- /dev/null +++ b/tensorflow/python/estimator/inputs/pandas_io_test.py @@ -0,0 +1,234 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for pandas_io.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.estimator.inputs import pandas_io +from tensorflow.python.estimator.inputs.pandas_import import HAS_PANDAS +from tensorflow.python.framework import errors +from tensorflow.python.platform import test +from tensorflow.python.training import coordinator +from tensorflow.python.training import queue_runner_impl + +if HAS_PANDAS: + # pylint: disable=g-import-not-at-top + import pandas as pd + + +class PandasIoTest(test.TestCase): + + def makeTestDataFrame(self): + index = np.arange(100, 104) + a = np.arange(4) + b = np.arange(32, 36) + x = pd.DataFrame({'a': a, 'b': b}, index=index) + y = pd.Series(np.arange(-32, -28), index=index) + return x, y + + def callInputFnOnce(self, input_fn, session): + results = input_fn() + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + result_values = session.run(results) + coord.request_stop() + coord.join(threads) + return result_values + + def testPandasInputFn_IndexMismatch(self): + if not HAS_PANDAS: + return + x, _ = self.makeTestDataFrame() + y_noindex = pd.Series(np.arange(-32, -28)) + with self.assertRaises(ValueError): + pandas_io.pandas_input_fn( + x, y_noindex, batch_size=2, shuffle=False, num_epochs=1) + + def testPandasInputFn_ProducesExpectedOutputs(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrame() + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + + features, target = self.callInputFnOnce(input_fn, session) + + self.assertAllEqual(features['a'], [0, 1]) + self.assertAllEqual(features['b'], [32, 33]) + self.assertAllEqual(target, [-32, -31]) + + def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + index = np.arange(100, 102) + a = np.arange(2) + b = np.arange(32, 34) + x = pd.DataFrame({'a': a, 'b': b}, index=index) + y = pd.Series(np.arange(-32, -30), index=index) + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=128, shuffle=False, num_epochs=2) + + results = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + features, target = session.run(results) + self.assertAllEqual(features['a'], [0, 1, 0, 1]) + self.assertAllEqual(features['b'], [32, 33, 32, 33]) + self.assertAllEqual(target, [-32, -31, -32, -31]) + + with self.assertRaises(errors.OutOfRangeError): + session.run(results) + + coord.request_stop() + coord.join(threads) + + def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + index = np.arange(100, 105) + a = np.arange(5) + b = np.arange(32, 37) + x = pd.DataFrame({'a': a, 'b': b}, index=index) + y = pd.Series(np.arange(-32, -27), index=index) + + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + + results = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + features, target = session.run(results) + self.assertAllEqual(features['a'], [0, 1]) + self.assertAllEqual(features['b'], [32, 33]) + self.assertAllEqual(target, [-32, -31]) + + features, target = session.run(results) + self.assertAllEqual(features['a'], [2, 3]) + self.assertAllEqual(features['b'], [34, 35]) + self.assertAllEqual(target, [-30, -29]) + + features, target = session.run(results) + self.assertAllEqual(features['a'], [4]) + self.assertAllEqual(features['b'], [36]) + self.assertAllEqual(target, [-28]) + + with self.assertRaises(errors.OutOfRangeError): + session.run(results) + + coord.request_stop() + coord.join(threads) + + def testPandasInputFn_OnlyX(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, _ = self.makeTestDataFrame() + input_fn = pandas_io.pandas_input_fn( + x, y=None, batch_size=2, shuffle=False, num_epochs=1) + + features = self.callInputFnOnce(input_fn, session) + + self.assertAllEqual(features['a'], [0, 1]) + self.assertAllEqual(features['b'], [32, 33]) + + def testPandasInputFn_ExcludesIndex(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrame() + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + + features, _ = self.callInputFnOnce(input_fn, session) + + self.assertFalse('index' in features) + + def assertInputsCallableNTimes(self, input_fn, session, n): + inputs = input_fn() + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + for _ in range(n): + session.run(inputs) + with self.assertRaises(errors.OutOfRangeError): + session.run(inputs) + coord.request_stop() + coord.join(threads) + + def testPandasInputFn_RespectsEpoch_NoShuffle(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrame() + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=4, shuffle=False, num_epochs=1) + + self.assertInputsCallableNTimes(input_fn, session, 1) + + def testPandasInputFn_RespectsEpoch_WithShuffle(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrame() + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=4, shuffle=True, num_epochs=1) + + self.assertInputsCallableNTimes(input_fn, session, 1) + + def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrame() + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2) + + self.assertInputsCallableNTimes(input_fn, session, 4) + + def testPandasInputFn_RespectsEpochUnevenBatches(self): + if not HAS_PANDAS: + return + x, y = self.makeTestDataFrame() + with self.test_session() as session: + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=3, shuffle=False, num_epochs=1) + + # Before the last batch, only one element of the epoch should remain. + self.assertInputsCallableNTimes(input_fn, session, 2) + + def testPandasInputFn_Idempotent(self): + if not HAS_PANDAS: + return + x, y = self.makeTestDataFrame() + for _ in range(2): + pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1)() + for _ in range(2): + pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=True, num_epochs=1)() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/estimator/inputs/queues/__init__.py b/tensorflow/python/estimator/inputs/queues/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/python/estimator/inputs/queues/__init__.py diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py new file mode 100644 index 0000000000..aa39958559 --- /dev/null +++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py @@ -0,0 +1,345 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper functions for enqueuing data from arrays and pandas `DataFrame`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import random +import numpy as np + +from tensorflow.python.estimator.inputs.pandas_import import HAS_PANDAS +from tensorflow.python.estimator.inputs.queues import feeding_queue_runner as fqr +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.summary import summary +from tensorflow.python.training import queue_runner + +if HAS_PANDAS: + # pylint: disable=g-import-not-at-top + import pandas as pd + + +def _get_integer_indices_for_next_batch( + batch_indices_start, batch_size, epoch_end, array_length, + current_epoch, total_epochs): + """Returns the integer indices for next batch. + + If total epochs is not None and current epoch is the final epoch, the end + index of the next batch should not exceed the `epoch_end` (i.e., the final + batch might not have size `batch_size` to avoid overshooting the last epoch). + + Args: + batch_indices_start: Integer, the index to start next batch. + batch_size: Integer, size of batches to return. + epoch_end: Integer, the end index of the epoch. The epoch could start from a + random position, so `epoch_end` provides the end index for that. + array_length: Integer, the length of the array. + current_epoch: Integer, the epoch number has been emitted. + total_epochs: Integer or `None`, the total number of epochs to emit. If + `None` will run forever. + + Returns: + A tuple of a list with integer indices for next batch and `current_epoch` + value after the next batch. + + Raises: + OutOfRangeError if `current_epoch` is not less than `total_epochs`. + + """ + if total_epochs is not None and current_epoch >= total_epochs: + raise errors.OutOfRangeError(None, None, + "Already emitted %s epochs." % current_epoch) + + batch_indices_end = batch_indices_start + batch_size + batch_indices = [j % array_length for j in + range(batch_indices_start, batch_indices_end)] + epoch_end_indices = [i for i, x in enumerate(batch_indices) if x == epoch_end] + current_epoch += len(epoch_end_indices) + + if total_epochs is None or current_epoch < total_epochs: + return (batch_indices, current_epoch) + + # Now we might have emitted more data for expected epochs. Need to trim. + final_epoch_end_inclusive = epoch_end_indices[ + -(current_epoch - total_epochs + 1)] + batch_indices = batch_indices[:final_epoch_end_inclusive + 1] + + return (batch_indices, total_epochs) + + +class _ArrayFeedFn(object): + """Creates feed dictionaries from numpy arrays.""" + + def __init__(self, + placeholders, + array, + batch_size, + random_start=False, + seed=None, + num_epochs=None): + if len(placeholders) != 2: + raise ValueError("_array_feed_fn expects 2 placeholders; got {}.".format( + len(placeholders))) + self._placeholders = placeholders + self._array = array + self._max = len(array) + self._batch_size = batch_size + self._num_epochs = num_epochs + self._epoch = 0 + random.seed(seed) + self._trav = random.randrange(self._max) if random_start else 0 + self._epoch_end = (self._trav - 1) % self._max + + def __call__(self): + integer_indexes, self._epoch = _get_integer_indices_for_next_batch( + batch_indices_start=self._trav, + batch_size=self._batch_size, + epoch_end=self._epoch_end, + array_length=self._max, + current_epoch=self._epoch, + total_epochs=self._num_epochs) + + self._trav = (integer_indexes[-1] + 1) % self._max + return { + self._placeholders[0]: integer_indexes, + self._placeholders[1]: self._array[integer_indexes] + } + + +class _OrderedDictNumpyFeedFn(object): + """Creates feed dictionaries from `OrderedDict`s of numpy arrays.""" + + def __init__(self, + placeholders, + ordered_dict_of_arrays, + batch_size, + random_start=False, + seed=None, + num_epochs=None): + if len(placeholders) != len(ordered_dict_of_arrays) + 1: + raise ValueError("Expected {} placeholders; got {}.".format( + len(ordered_dict_of_arrays), len(placeholders))) + self._index_placeholder = placeholders[0] + self._col_placeholders = placeholders[1:] + self._ordered_dict_of_arrays = ordered_dict_of_arrays + self._max = len(next(iter(ordered_dict_of_arrays.values()))) + for _, v in ordered_dict_of_arrays.items(): + if len(v) != self._max: + raise ValueError("Array lengths must match.") + self._batch_size = batch_size + self._num_epochs = num_epochs + self._epoch = 0 + random.seed(seed) + self._trav = random.randrange(self._max) if random_start else 0 + self._epoch_end = (self._trav - 1) % self._max + + def __call__(self): + integer_indexes, self._epoch = _get_integer_indices_for_next_batch( + batch_indices_start=self._trav, + batch_size=self._batch_size, + epoch_end=self._epoch_end, + array_length=self._max, + current_epoch=self._epoch, + total_epochs=self._num_epochs) + + self._trav = (integer_indexes[-1] + 1) % self._max + feed_dict = {self._index_placeholder: integer_indexes} + cols = [ + column[integer_indexes] + for column in self._ordered_dict_of_arrays.values() + ] + feed_dict.update(dict(zip(self._col_placeholders, cols))) + return feed_dict + + +class _PandasFeedFn(object): + """Creates feed dictionaries from pandas `DataFrames`.""" + + def __init__(self, + placeholders, + dataframe, + batch_size, + random_start=False, + seed=None, + num_epochs=None): + if len(placeholders) != len(dataframe.columns) + 1: + raise ValueError("Expected {} placeholders; got {}.".format( + len(dataframe.columns), len(placeholders))) + self._index_placeholder = placeholders[0] + self._col_placeholders = placeholders[1:] + self._dataframe = dataframe + self._max = len(dataframe) + self._batch_size = batch_size + self._num_epochs = num_epochs + self._epoch = 0 + random.seed(seed) + self._trav = random.randrange(self._max) if random_start else 0 + self._epoch_end = (self._trav - 1) % self._max + + def __call__(self): + integer_indexes, self._epoch = _get_integer_indices_for_next_batch( + batch_indices_start=self._trav, + batch_size=self._batch_size, + epoch_end=self._epoch_end, + array_length=self._max, + current_epoch=self._epoch, + total_epochs=self._num_epochs) + + self._trav = (integer_indexes[-1] + 1) % self._max + result = self._dataframe.iloc[integer_indexes] + cols = [result[col].values for col in result.columns] + feed_dict = dict(zip(self._col_placeholders, cols)) + feed_dict[self._index_placeholder] = result.index.values + return feed_dict + + +def _enqueue_data(data, + capacity, + shuffle=False, + min_after_dequeue=None, + num_threads=1, + seed=None, + name="enqueue_input", + enqueue_size=1, + num_epochs=None): + """Creates a queue filled from a numpy array or pandas `DataFrame`. + + Returns a queue filled with the rows of the given (`OrderedDict` of) array + or `DataFrame`. In the case of a pandas `DataFrame`, the first enqueued + `Tensor` corresponds to the index of the `DataFrame`. For (`OrderedDict` of) + numpy arrays, the first enqueued `Tensor` contains the row number. + + Args: + data: a numpy `ndarray`, `OrderedDict` of numpy arrays, or pandas + `DataFrame` that will be read into the queue. + capacity: the capacity of the queue. + shuffle: whether or not to shuffle the rows of the array. + min_after_dequeue: minimum number of elements that can remain in the queue + after a dequeue operation. Only used when `shuffle` is true. If not set, + defaults to `capacity` / 4. + num_threads: number of threads used for reading and enqueueing. + seed: used to seed shuffling and reader starting points. + name: a scope name identifying the data. + enqueue_size: the number of rows to enqueue per step. + num_epochs: limit enqueuing to a specified number of epochs, if provided. + + Returns: + A queue filled with the rows of the given (`OrderedDict` of) array or + `DataFrame`. + + Raises: + TypeError: `data` is not a Pandas `DataFrame`, an `OrderedDict` of numpy + arrays or a numpy `ndarray`. + """ + with ops.name_scope(name): + if isinstance(data, np.ndarray): + types = [dtypes.int64, dtypes.as_dtype(data.dtype)] + queue_shapes = [(), data.shape[1:]] + get_feed_fn = _ArrayFeedFn + elif isinstance(data, collections.OrderedDict): + types = [dtypes.int64] + [ + dtypes.as_dtype(col.dtype) for col in data.values() + ] + queue_shapes = [()] + [col.shape[1:] for col in data.values()] + get_feed_fn = _OrderedDictNumpyFeedFn + elif HAS_PANDAS and isinstance(data, pd.DataFrame): + types = [ + dtypes.as_dtype(dt) for dt in [data.index.dtype] + list(data.dtypes) + ] + queue_shapes = [() for _ in types] + get_feed_fn = _PandasFeedFn + else: + raise TypeError( + "data must be either a numpy array or pandas DataFrame if pandas is " + "installed; got {}".format(type(data).__name__)) + + # TODO(jamieas): TensorBoard warnings for all warnings below once available. + + if num_threads > 1 and num_epochs is not None: + logging.warning( + "enqueue_data was called with num_epochs and num_threads > 1. " + "num_epochs is applied per thread, so this will produce more " + "epochs than you probably intend. " + "If you want to limit epochs, use one thread.") + + if shuffle and num_threads > 1 and num_epochs is not None: + logging.warning( + "enqueue_data was called with shuffle=True, num_threads > 1, and " + "num_epochs. This will create multiple threads, all reading the " + "array/dataframe in order adding to the same shuffling queue; the " + "results will likely not be sufficiently shuffled.") + + if not shuffle and num_threads > 1: + logging.warning( + "enqueue_data was called with shuffle=False and num_threads > 1. " + "This will create multiple threads, all reading the " + "array/dataframe in order. If you want examples read in order, use" + " one thread; if you want multiple threads, enable shuffling.") + + if shuffle: + min_after_dequeue = int(capacity / 4 if min_after_dequeue is None else + min_after_dequeue) + queue = data_flow_ops.RandomShuffleQueue( + capacity, + min_after_dequeue, + dtypes=types, + shapes=queue_shapes, + seed=seed) + else: + min_after_dequeue = 0 # just for the summary text + queue = data_flow_ops.FIFOQueue( + capacity, dtypes=types, shapes=queue_shapes) + + enqueue_ops = [] + feed_fns = [] + + for i in range(num_threads): + # Note the placeholders have no shapes, so they will accept any + # enqueue_size. enqueue_many below will break them up. + placeholders = [array_ops.placeholder(t) for t in types] + + enqueue_ops.append(queue.enqueue_many(placeholders)) + seed_i = None if seed is None else (i + 1) * seed + feed_fns.append( + get_feed_fn( + placeholders, + data, + enqueue_size, + random_start=shuffle, + seed=seed_i, + num_epochs=num_epochs)) + + runner = fqr._FeedingQueueRunner( # pylint: disable=protected-access + queue=queue, enqueue_ops=enqueue_ops, feed_fns=feed_fns) + queue_runner.add_queue_runner(runner) + + full = (math_ops.cast( + math_ops.maximum(0, queue.size() - min_after_dequeue), + dtypes.float32) * (1. / (capacity - min_after_dequeue))) + # Note that name contains a '/' at the end so we intentionally do not place + # a '/' after %s below. + summary_name = ("queue/%sfraction_over_%d_of_%d_full" % + (queue.name, min_after_dequeue, + capacity - min_after_dequeue)) + summary.scalar(summary_name, full) + return queue diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py b/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py new file mode 100644 index 0000000000..ad27d990ea --- /dev/null +++ b/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py @@ -0,0 +1,290 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests feeding functions using arrays and `DataFrames`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np + +from tensorflow.python.estimator.inputs.pandas_import import HAS_PANDAS +from tensorflow.python.estimator.inputs.queues import feeding_functions as ff +from tensorflow.python.platform import test + +if HAS_PANDAS: + # pylint: disable=g-import-not-at-top + import pandas as pd + + +def vals_to_list(a): + return { + key: val.tolist() if isinstance(val, np.ndarray) else val + for key, val in a.items() + } + + +class _FeedingFunctionsTestCase(test.TestCase): + """Tests for feeding functions.""" + + def testArrayFeedFnBatchOne(self): + array = np.arange(32).reshape([16, 2]) + placeholders = ["index_placeholder", "value_placeholder"] + aff = ff._ArrayFeedFn(placeholders, array, 1) + + # cycle around a couple times + for x in range(0, 100): + i = x % 16 + expected = { + "index_placeholder": [i], + "value_placeholder": [[2 * i, 2 * i + 1]] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testArrayFeedFnBatchFive(self): + array = np.arange(32).reshape([16, 2]) + placeholders = ["index_placeholder", "value_placeholder"] + aff = ff._ArrayFeedFn(placeholders, array, 5) + + # cycle around a couple times + for _ in range(0, 101, 2): + aff() + + expected = { + "index_placeholder": [15, 0, 1, 2, 3], + "value_placeholder": [[30, 31], [0, 1], [2, 3], [4, 5], [6, 7]] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testArrayFeedFnBatchTwoWithOneEpoch(self): + array = np.arange(5) + 10 + placeholders = ["index_placeholder", "value_placeholder"] + aff = ff._ArrayFeedFn(placeholders, array, batch_size=2, num_epochs=1) + + expected = { + "index_placeholder": [0, 1], + "value_placeholder": [10, 11] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + expected = { + "index_placeholder": [2, 3], + "value_placeholder": [12, 13] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + expected = { + "index_placeholder": [4], + "value_placeholder": [14] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testArrayFeedFnBatchOneHundred(self): + array = np.arange(32).reshape([16, 2]) + placeholders = ["index_placeholder", "value_placeholder"] + aff = ff._ArrayFeedFn(placeholders, array, 100) + + expected = { + "index_placeholder": + list(range(0, 16)) * 6 + list(range(0, 4)), + "value_placeholder": + np.arange(32).reshape([16, 2]).tolist() * 6 + + [[0, 1], [2, 3], [4, 5], [6, 7]] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testArrayFeedFnBatchOneHundredWithSmallerArrayAndMultipleEpochs(self): + array = np.arange(2) + 10 + placeholders = ["index_placeholder", "value_placeholder"] + aff = ff._ArrayFeedFn(placeholders, array, batch_size=100, num_epochs=2) + + expected = { + "index_placeholder": [0, 1, 0, 1], + "value_placeholder": [10, 11, 10, 11], + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testPandasFeedFnBatchOne(self): + if not HAS_PANDAS: + return + array1 = np.arange(32, 64) + array2 = np.arange(64, 96) + df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(96, 128)) + placeholders = ["index_placeholder", "a_placeholder", "b_placeholder"] + aff = ff._PandasFeedFn(placeholders, df, 1) + + # cycle around a couple times + for x in range(0, 100): + i = x % 32 + expected = { + "index_placeholder": [i + 96], + "a_placeholder": [32 + i], + "b_placeholder": [64 + i] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testPandasFeedFnBatchFive(self): + if not HAS_PANDAS: + return + array1 = np.arange(32, 64) + array2 = np.arange(64, 96) + df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(96, 128)) + placeholders = ["index_placeholder", "a_placeholder", "b_placeholder"] + aff = ff._PandasFeedFn(placeholders, df, 5) + + # cycle around a couple times + for _ in range(0, 101, 2): + aff() + + expected = { + "index_placeholder": [127, 96, 97, 98, 99], + "a_placeholder": [63, 32, 33, 34, 35], + "b_placeholder": [95, 64, 65, 66, 67] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testPandasFeedFnBatchTwoWithOneEpoch(self): + if not HAS_PANDAS: + return + array1 = np.arange(32, 37) + array2 = np.arange(64, 69) + df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(96, 101)) + placeholders = ["index_placeholder", "a_placeholder", "b_placeholder"] + aff = ff._PandasFeedFn(placeholders, df, batch_size=2, num_epochs=1) + + expected = { + "index_placeholder": [96, 97], + "a_placeholder": [32, 33], + "b_placeholder": [64, 65] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + expected = { + "index_placeholder": [98, 99], + "a_placeholder": [34, 35], + "b_placeholder": [66, 67] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + expected = { + "index_placeholder": [100], + "a_placeholder": [36], + "b_placeholder": [68] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testPandasFeedFnBatchOneHundred(self): + if not HAS_PANDAS: + return + array1 = np.arange(32, 64) + array2 = np.arange(64, 96) + df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(96, 128)) + placeholders = ["index_placeholder", "a_placeholder", "b_placeholder"] + aff = ff._PandasFeedFn(placeholders, df, 100) + + expected = { + "index_placeholder": list(range(96, 128)) * 3 + list(range(96, 100)), + "a_placeholder": list(range(32, 64)) * 3 + list(range(32, 36)), + "b_placeholder": list(range(64, 96)) * 3 + list(range(64, 68)) + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testPandasFeedFnBatchOneHundredWithSmallDataArrayAndMultipleEpochs(self): + if not HAS_PANDAS: + return + array1 = np.arange(32, 34) + array2 = np.arange(64, 66) + df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(96, 98)) + placeholders = ["index_placeholder", "a_placeholder", "b_placeholder"] + aff = ff._PandasFeedFn(placeholders, df, batch_size=100, num_epochs=2) + + expected = { + "index_placeholder": [96, 97, 96, 97], + "a_placeholder": [32, 33, 32, 33], + "b_placeholder": [64, 65, 64, 65] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testOrderedDictNumpyFeedFnBatchTwoWithOneEpoch(self): + a = np.arange(32, 37) + b = np.arange(64, 69) + x = {"a": a, "b": b} + ordered_dict_x = collections.OrderedDict( + sorted(x.items(), key=lambda t: t[0])) + placeholders = ["index_placeholder", "a_placeholder", "b_placeholder"] + aff = ff._OrderedDictNumpyFeedFn( + placeholders, ordered_dict_x, batch_size=2, num_epochs=1) + + expected = { + "index_placeholder": [0, 1], + "a_placeholder": [32, 33], + "b_placeholder": [64, 65] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + expected = { + "index_placeholder": [2, 3], + "a_placeholder": [34, 35], + "b_placeholder": [66, 67] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + expected = { + "index_placeholder": [4], + "a_placeholder": [36], + "b_placeholder": [68] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + def testOrderedDictNumpyFeedFnLargeBatchWithSmallArrayAndMultipleEpochs(self): + a = np.arange(32, 34) + b = np.arange(64, 66) + x = {"a": a, "b": b} + ordered_dict_x = collections.OrderedDict( + sorted(x.items(), key=lambda t: t[0])) + placeholders = ["index_placeholder", "a_placeholder", "b_placeholder"] + aff = ff._OrderedDictNumpyFeedFn( + placeholders, ordered_dict_x, batch_size=100, num_epochs=2) + + expected = { + "index_placeholder": [0, 1, 0, 1], + "a_placeholder": [32, 33, 32, 33], + "b_placeholder": [64, 65, 64, 65] + } + actual = aff() + self.assertEqual(expected, vals_to_list(actual)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/estimator/inputs/queues/feeding_queue_runner.py b/tensorflow/python/estimator/inputs/queues/feeding_queue_runner.py new file mode 100644 index 0000000000..afbcab596a --- /dev/null +++ b/tensorflow/python/estimator/inputs/queues/feeding_queue_runner.py @@ -0,0 +1,180 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""A `QueueRunner` that takes a feed function as an argument.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +from tensorflow.python.framework import errors +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import queue_runner as qr + + +class _FeedingQueueRunner(qr.QueueRunner): + """A queue runner that allows the feeding of values such as numpy arrays.""" + + def __init__(self, queue=None, enqueue_ops=None, close_op=None, + cancel_op=None, feed_fns=None, + queue_closed_exception_types=None): + """Initialize the queue runner. + + For further documentation, see `queue_runner.py`. Note that + `FeedingQueueRunner` does not support construction from protobuffer nor + serialization to protobuffer. + + Args: + queue: A `Queue`. + enqueue_ops: List of enqueue ops to run in threads later. + close_op: Op to close the queue. Pending enqueue ops are preserved. + cancel_op: Op to close the queue and cancel pending enqueue ops. + feed_fns: a list of functions that return a dictionary mapping fed + `Tensor`s to values. Must be the same length as `enqueue_ops`. + queue_closed_exception_types: Optional tuple of Exception types that + indicate that the queue has been closed when raised during an enqueue + operation. Defaults to + `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`. + + Raises: + ValueError: `feed_fns` is not `None` and has different length than + `enqueue_ops`. + """ + if queue_closed_exception_types is None: + queue_closed_exception_types = ( + errors.OutOfRangeError, errors.CancelledError) + super(_FeedingQueueRunner, self).__init__( + queue, enqueue_ops, close_op, + cancel_op, queue_closed_exception_types=queue_closed_exception_types) + if feed_fns is None: + self._feed_fns = [None for _ in enqueue_ops] + else: + if len(feed_fns) != len(enqueue_ops): + raise ValueError( + "If feed_fns is not None, it must have the same length as " + "enqueue_ops.") + self._feed_fns = feed_fns + + # pylint: disable=broad-except + def _run(self, sess, enqueue_op, feed_fn, coord=None): + """Execute the enqueue op in a loop, close the queue in case of error. + + Args: + sess: A `Session`. + enqueue_op: The `Operation` to run. + feed_fn: the feed function to pass to `sess.run`. + coord: Optional `Coordinator` object for reporting errors and checking + for stop conditions. + + """ + # TODO(jamieas): Reduce code duplication with `QueueRunner`. + if coord: + coord.register_thread(threading.current_thread()) + decremented = False + try: + while True: + if coord and coord.should_stop(): + break + try: + feed_dict = None if feed_fn is None else feed_fn() + sess.run(enqueue_op, feed_dict=feed_dict) + except (errors.OutOfRangeError, errors.CancelledError): + # This exception indicates that a queue was closed. + with self._lock: + self._runs_per_session[sess] -= 1 + decremented = True + if self._runs_per_session[sess] == 0: + try: + sess.run(self._close_op) + except Exception as e: + # Intentionally ignore errors from close_op. + logging.vlog(1, "Ignored exception: %s", str(e)) + return + except Exception as e: + # This catches all other exceptions. + if coord: + coord.request_stop(e) + else: + logging.error("Exception in QueueRunner: %s", str(e)) + with self._lock: + self._exceptions_raised.append(e) + raise + finally: + # Make sure we account for all terminations: normal or errors. + if not decremented: + with self._lock: + self._runs_per_session[sess] -= 1 + + def create_threads(self, sess, coord=None, daemon=False, start=False): + """Create threads to run the enqueue ops for the given session. + + This method requires a session in which the graph was launched. It creates + a list of threads, optionally starting them. There is one thread for each + op passed in `enqueue_ops`. + + The `coord` argument is an optional coordinator, that the threads will use + to terminate together and report exceptions. If a coordinator is given, + this method starts an additional thread to close the queue when the + coordinator requests a stop. + + If previously created threads for the given session are still running, no + new threads will be created. + + Args: + sess: A `Session`. + coord: Optional `Coordinator` object for reporting errors and checking + stop conditions. + daemon: Boolean. If `True` make the threads daemon threads. + start: Boolean. If `True` starts the threads. If `False` the + caller must call the `start()` method of the returned threads. + + Returns: + A list of threads. + """ + with self._lock: + try: + if self._runs_per_session[sess] > 0: + # Already started: no new threads to return. + return [] + except KeyError: + # We haven't seen this session yet. + pass + self._runs_per_session[sess] = len(self._enqueue_ops) + self._exceptions_raised = [] + + ret_threads = [threading.Thread(target=self._run, + args=(sess, op, feed_fn, coord)) + for op, feed_fn in zip(self._enqueue_ops, self._feed_fns)] + if coord: + ret_threads.append(threading.Thread(target=self._close_on_stop, + args=(sess, self._cancel_op, coord))) + for t in ret_threads: + if daemon: + t.daemon = True + if start: + t.start() + return ret_threads + + def _init_from_proto(self, queue_runner_def): + raise NotImplementedError( + "{} does not support initialization from proto.".format(type( + self).__name__)) + + def to_proto(self): + raise NotImplementedError( + "{} does not support serialization to proto.".format(type( + self).__name__)) diff --git a/tensorflow/python/estimator/inputs/queues/feeding_queue_runner_test.py b/tensorflow/python/estimator/inputs/queues/feeding_queue_runner_test.py new file mode 100644 index 0000000000..c8d20970c5 --- /dev/null +++ b/tensorflow/python/estimator/inputs/queues/feeding_queue_runner_test.py @@ -0,0 +1,135 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests `FeedingQueueRunner` using arrays and `DataFrames`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.estimator.inputs.pandas_import import HAS_PANDAS +from tensorflow.python.estimator.inputs.queues import feeding_functions as ff +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import coordinator +from tensorflow.python.training import queue_runner_impl + +if HAS_PANDAS: + # pylint: disable=g-import-not-at-top + import pandas as pd + + +def get_rows(array, row_indices): + rows = [array[i] for i in row_indices] + return np.vstack(rows) + + +class FeedingQueueRunnerTestCase(test.TestCase): + """Tests for `FeedingQueueRunner`.""" + + def testArrayFeeding(self): + with ops.Graph().as_default(): + array = np.arange(32).reshape([16, 2]) + q = ff._enqueue_data(array, capacity=100) + batch_size = 3 + dq_op = q.dequeue_many(batch_size) + with session.Session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) + for i in range(100): + indices = [ + j % array.shape[0] + for j in range(batch_size * i, batch_size * (i + 1)) + ] + expected_dq = get_rows(array, indices) + dq = sess.run(dq_op) + np.testing.assert_array_equal(indices, dq[0]) + np.testing.assert_array_equal(expected_dq, dq[1]) + coord.request_stop() + coord.join(threads) + + def testArrayFeedingMultiThread(self): + with ops.Graph().as_default(): + array = np.arange(256).reshape([128, 2]) + q = ff._enqueue_data(array, capacity=128, num_threads=8, shuffle=True) + batch_size = 3 + dq_op = q.dequeue_many(batch_size) + with session.Session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) + for _ in range(100): + dq = sess.run(dq_op) + indices = dq[0] + expected_dq = get_rows(array, indices) + np.testing.assert_array_equal(expected_dq, dq[1]) + coord.request_stop() + coord.join(threads) + + def testPandasFeeding(self): + if not HAS_PANDAS: + return + with ops.Graph().as_default(): + array1 = np.arange(32) + array2 = np.arange(32, 64) + df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(64, 96)) + q = ff._enqueue_data(df, capacity=100) + batch_size = 5 + dq_op = q.dequeue_many(5) + with session.Session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) + for i in range(100): + indices = [ + j % array1.shape[0] + for j in range(batch_size * i, batch_size * (i + 1)) + ] + expected_df_indices = df.index[indices] + expected_rows = df.iloc[indices] + dq = sess.run(dq_op) + np.testing.assert_array_equal(expected_df_indices, dq[0]) + for col_num, col in enumerate(df.columns): + np.testing.assert_array_equal(expected_rows[col].values, + dq[col_num + 1]) + coord.request_stop() + coord.join(threads) + + def testPandasFeedingMultiThread(self): + if not HAS_PANDAS: + return + with ops.Graph().as_default(): + array1 = np.arange(128, 256) + array2 = 2 * array1 + df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(128)) + q = ff._enqueue_data(df, capacity=128, num_threads=8, shuffle=True) + batch_size = 5 + dq_op = q.dequeue_many(batch_size) + with session.Session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) + for _ in range(100): + dq = sess.run(dq_op) + indices = dq[0] + expected_rows = df.iloc[indices] + for col_num, col in enumerate(df.columns): + np.testing.assert_array_equal(expected_rows[col].values, + dq[col_num + 1]) + coord.request_stop() + coord.join(threads) + + +if __name__ == "__main__": + test.main() |