diff options
4 files changed, 144 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 8161ac73c3..203bbb81c3 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -615,6 +615,18 @@ py_test( ) py_test( + name = "pandas_io_test", + size = "small", + srcs = ["python/learn/learn_io/pandas_io_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + ], +) + +py_test( name = "export_test", size = "small", srcs = ["python/learn/utils/export_test.py"], diff --git a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py index 70b420e10b..c3e2b56a6e 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py @@ -34,3 +34,4 @@ from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import extract_pan from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import extract_pandas_labels from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import extract_pandas_matrix from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import HAS_PANDAS +from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import pandas_input_fn diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py index 4500be5439..ee62e777cd 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py @@ -19,6 +19,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np +from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_functions + try: # pylint: disable=g-import-not-at-top import pandas as pd @@ -117,3 +120,58 @@ def extract_pandas_labels(labels): 'float, or bool. 
def pandas_input_fn(x, y=None, batch_size=128, num_epochs=None, shuffle=True,
                    queue_capacity=1000, num_threads=1, target_column='target',
                    index_column='index'):
  """Returns input function that would feed pandas DataFrame into the model.

  Note: If y's index doesn't match x's index exception will be raised.

  The caller's `x` is never modified: when `y` is given, it is packed into a
  copy of `x`, so the returned function can safely be called more than once.

  Args:
    x: pandas `DataFrame` object.
    y: pandas `Series` object.
    batch_size: int, size of batches to return.
    num_epochs: int, number of epochs to iterate over data. If `None` will
      run indefinitely.
    shuffle: bool, whether to shuffle the queue. Please make sure you don't
      shuffle at prediction time.
    queue_capacity: int, size of queue to accumulate.
    num_threads: int, number of threads used for reading and enqueueing.
    target_column: str, used to pack `y` into `x` DataFrame under this column.
    index_column: str, name of the feature return with index.

  Returns:
    Function, that has signature of ()->(dict of `features`, `target`) when
    `y` is provided, or ()->dict of `features` when `y` is `None`.

  Raises:
    ValueError: if `target_column` column is already in `x` DataFrame, or if
      the indices of `x` and `y` differ. Raised when the returned function is
      called, not when it is created.
  """
  def input_fn():
    """Pandas input function."""
    frame = x
    if y is not None:
      if target_column in x:
        raise ValueError('Found already column \'%s\' in x, please change '
                         'target_column to something else. Current columns '
                         'in x: %s' % (target_column, x.columns))
      if not np.array_equal(x.index, y.index):
        raise ValueError('Index for x and y are mismatch, this will lead '
                         'to missing values. Please make sure they match or '
                         'use .reset_index() method.\n'
                         'Index for x: %s\n'
                         'Index for y: %s\n' % (x.index, y.index))
      # Work on a copy so the caller's DataFrame does not silently gain a
      # target column (which would also make the next call raise above).
      frame = x.copy()
      frame[target_column] = y
    queue = feeding_functions.enqueue_data(
        frame, queue_capacity, shuffle=shuffle, num_threads=num_threads,
        enqueue_size=batch_size, num_epochs=num_epochs)
    if num_epochs is None:
      features = queue.dequeue_many(batch_size)
    else:
      # With a finite number of epochs the queue is eventually closed;
      # dequeue_up_to allows the last batch to hold fewer than batch_size rows.
      features = queue.dequeue_up_to(batch_size)
    # Dequeued tensors are paired with the index name followed by the column
    # names, in that order.
    features = dict(zip([index_column] + list(frame.columns), features))
    if y is not None:
      target = features.pop(target_column)
      return features, target
    return features
  return input_fn
"""Tests for pandas_io."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.learn_io import pandas_io
from tensorflow.python.framework import errors

# pylint: disable=g-import-not-at-top
try:
  import pandas as pd
  HAS_PANDAS = True
except ImportError:
  HAS_PANDAS = False


class PandasIoTest(tf.test.TestCase):
  """Checks batching, index handling and epoch limits of pandas_input_fn."""

  def testPandasInputFn(self):
    """Feeds a 4-row DataFrame in batches of 2 for exactly one epoch."""
    if not HAS_PANDAS:
      # Report a skip rather than a silent pass when pandas is unavailable.
      self.skipTest('pandas is not installed.')
    index = np.arange(100, 104)
    a = np.arange(4)
    b = np.arange(32, 36)
    x = pd.DataFrame({'a': a, 'b': b}, index=index)
    # Same values as `y`, but with the default 0..3 index, which does not
    # match the 100..103 index of `x`.
    y_noindex = pd.Series(np.arange(-32, -28))
    y = pd.Series(np.arange(-32, -28), index=index)
    with self.test_session() as session:
      with self.assertRaises(ValueError):
        failing_input_fn = pandas_io.pandas_input_fn(
            x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
        failing_input_fn()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)
      features, target = input_fn()

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coord)
      try:
        # First batch: rows 0 and 1, in order because shuffle=False.
        res = session.run([features, target])
        self.assertAllEqual(res[0]['index'], [100, 101])
        self.assertAllEqual(res[0]['a'], [0, 1])
        self.assertAllEqual(res[0]['b'], [32, 33])
        self.assertAllEqual(res[1], [-32, -31])

        # Second batch drains the single epoch; a third read must fail.
        session.run([features, target])
        with self.assertRaises(errors.OutOfRangeError):
          session.run([features, target])
      finally:
        # Always stop the queue-runner threads, even if an assertion failed.
        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
  tf.test.main()