aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/BUILD12
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/__init__.py1
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py58
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py73
4 files changed, 144 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 8161ac73c3..203bbb81c3 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -615,6 +615,18 @@ py_test(
)
py_test(
+ name = "pandas_io_test",
+ size = "small",
+ srcs = ["python/learn/learn_io/pandas_io_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":learn",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_test(
name = "export_test",
size = "small",
srcs = ["python/learn/utils/export_test.py"],
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py
index 70b420e10b..c3e2b56a6e 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py
@@ -34,3 +34,4 @@ from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import extract_pan
from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import extract_pandas_labels
from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import extract_pandas_matrix
from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import HAS_PANDAS
+from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import pandas_input_fn
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py
index 4500be5439..ee62e777cd 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py
@@ -19,6 +19,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_functions
+
try:
# pylint: disable=g-import-not-at-top
import pandas as pd
@@ -117,3 +120,58 @@ def extract_pandas_labels(labels):
'float, or bool. Found: ' + ', '.join(error_report))
else:
return labels
+
+
+def pandas_input_fn(x, y=None, batch_size=128, num_epochs=None, shuffle=True,
+ queue_capacity=1000, num_threads=1, target_column='target',
+ index_column='index'):
+ """Returns input function that would feed pandas DataFrame into the model.
+
+ Note: If y's index doesn't match x's index exception will be raised.
+
+ Args:
+ x: pandas `DataFrame` object.
+ y: pandas `Series` object.
+ batch_size: int, size of batches to return.
+ num_epochs: int, number of epochs to iterate over data. If `None` will
+ run indefinetly.
+ shuffle: int, if shuffle the queue. Please make sure you don't shuffle at
+ prediction time.
+ queue_capacity: int, size of queue to accumulate.
+ num_threads: int, number of threads used for reading and enqueueing.
+ target_column: str, used to pack `y` into `x` DataFrame under this column.
+ index_column: str, name of the feature return with index.
+
+ Returns:
+ Function, that has signature of ()->(dict of `features`, `target`)
+
+ Raises:
+ ValueError: if `target_column` column is already in `x` DataFrame.
+ """
+ def input_fn():
+ """Pandas input function."""
+ if y is not None:
+ if target_column in x:
+ raise ValueError('Found already column \'%s\' in x, please change '
+ 'target_column to something else. Current columns '
+ 'in x: %s', target_column, x.columns)
+ if not np.array_equal(x.index, y.index):
+ raise ValueError('Index for x and y are mismatch, this will lead '
+ 'to missing values. Please make sure they match or '
+ 'use .reset_index() method.\n'
+ 'Index for x: %s\n'
+ 'Index for y: %s\n', x.index, y.index)
+ x[target_column] = y
+ queue = feeding_functions.enqueue_data(
+ x, queue_capacity, shuffle=shuffle, num_threads=num_threads,
+ enqueue_size=batch_size, num_epochs=num_epochs)
+ if num_epochs is None:
+ features = queue.dequeue_many(batch_size)
+ else:
+ features = queue.dequeue_up_to(batch_size)
+ features = dict(zip([index_column] + list(x.columns), features))
+ if y is not None:
+ target = features.pop(target_column)
+ return features, target
+ return features
+ return input_fn
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
new file mode 100644
index 0000000000..b18ebae011
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
@@ -0,0 +1,73 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for pandas_io."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.learn.python.learn.learn_io import pandas_io
+from tensorflow.python.framework import errors
+
+# pylint: disable=g-import-not-at-top
+try:
+ import pandas as pd
+ HAS_PANDAS = True
+except ImportError:
+ HAS_PANDAS = False
+
+
+class PandasIoTest(tf.test.TestCase):
+
+ def testPandasInputFn(self):
+ if not HAS_PANDAS:
+ return
+ index = np.arange(100, 104)
+ a = np.arange(4)
+ b = np.arange(32, 36)
+ x = pd.DataFrame({'a': a, 'b': b}, index=index)
+ y_noindex = pd.Series(np.arange(-32, -28))
+ y = pd.Series(np.arange(-32, -28), index=index)
+ with self.test_session() as session:
+ with self.assertRaises(ValueError):
+ failing_input_fn = pandas_io.pandas_input_fn(
+ x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
+ failing_input_fn()
+ input_fn = pandas_io.pandas_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+ features, target = input_fn()
+
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(session, coord=coord)
+
+ res = session.run([features, target])
+ self.assertAllEqual(res[0]['index'], [100, 101])
+ self.assertAllEqual(res[0]['a'], [0, 1])
+ self.assertAllEqual(res[0]['b'], [32, 33])
+ self.assertAllEqual(res[1], [-32, -31])
+
+ session.run([features, target])
+ with self.assertRaises(errors.OutOfRangeError):
+ session.run([features, target])
+
+ coord.request_stop()
+ coord.join(threads)
+
+
+if __name__ == '__main__':
+ tf.test.main()