aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io.py41
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io_test.py70
2 files changed, 107 insertions, 4 deletions
diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py
index 57f8e5fd6a..616bcb410f 100644
--- a/tensorflow/python/estimator/inputs/pandas_io.py
+++ b/tensorflow/python/estimator/inputs/pandas_io.py
@@ -18,6 +18,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import six
+import uuid
import numpy as np
from tensorflow.python.estimator.inputs.queues import feeding_functions
@@ -35,6 +37,22 @@ except ImportError:
HAS_PANDAS = False
+def _get_unique_target_key(features, target_column_name):
+ """Returns a key that does not exist in the input DataFrame `features`.
+
+ Args:
+ features: DataFrame
+ target_column_name: Name of the target column as a `str`
+
+ Returns:
+ A unique key that can be used to insert the target into
+ features.
+ """
+ if target_column_name in features:
+ target_column_name += '_' + str(uuid.uuid4())
+ return target_column_name
+
+
@estimator_export('estimator.inputs.pandas_input_fn')
def pandas_input_fn(x,
y=None,
@@ -50,7 +68,7 @@ def pandas_input_fn(x,
Args:
x: pandas `DataFrame` object.
- y: pandas `Series` object. `None` if absent.
+ y: pandas `Series` object or `DataFrame`. `None` if absent.
batch_size: int, size of batches to return.
num_epochs: int, number of epochs to iterate over data. If not `None`,
read attempts that would exceed this value will raise `OutOfRangeError`.
@@ -60,7 +78,8 @@ def pandas_input_fn(x,
num_threads: Integer, number of threads used for reading and enqueueing. In
order to have predicted and repeatable order of reading and enqueueing,
such as in prediction and evaluation mode, `num_threads` should be 1.
- target_column: str, name to give the target column `y`.
+ target_column: str, name to give the target column `y`. This parameter
+ is not used when `y` is a `DataFrame`.
Returns:
Function, that has signature of ()->(dict of `features`, `target`)
@@ -79,6 +98,9 @@ def pandas_input_fn(x,
'(it is recommended to set it as True for training); '
'got {}'.format(shuffle))
+ if not isinstance(target_column, six.string_types):
+ raise TypeError('target_column must be a string type')
+
x = x.copy()
if y is not None:
if target_column in x:
@@ -88,7 +110,13 @@ def pandas_input_fn(x,
if not np.array_equal(x.index, y.index):
raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n'
'Index for y: %s\n' % (x.index, y.index))
- x[target_column] = y
+ if isinstance(y, pd.DataFrame):
+ y_columns = [(column, _get_unique_target_key(x, column))
+ for column in list(y)]
+ target_column = [v for _, v in y_columns]
+ x[target_column] = y
+ else:
+ x[target_column] = y
# TODO(mdan): These are memory copies. We probably don't need 4x slack space.
# The sizes below are consistent with what I've seen elsewhere.
@@ -118,7 +146,12 @@ def pandas_input_fn(x,
features = features[1:]
features = dict(zip(list(x.columns), features))
if y is not None:
- target = features.pop(target_column)
+ if isinstance(target_column, list):
+ keys = [k for k, _ in y_columns]
+ values = [features.pop(column) for column in target_column]
+ target = {k: v for k, v in zip(keys, values)}
+ else:
+ target = features.pop(target_column)
return features, target
return features
return input_fn
diff --git a/tensorflow/python/estimator/inputs/pandas_io_test.py b/tensorflow/python/estimator/inputs/pandas_io_test.py
index dcecf6dd61..6f13bc95d2 100644
--- a/tensorflow/python/estimator/inputs/pandas_io_test.py
+++ b/tensorflow/python/estimator/inputs/pandas_io_test.py
@@ -47,6 +47,16 @@ class PandasIoTest(test.TestCase):
y = pd.Series(np.arange(-32, -28), index=index)
return x, y
+ def makeTestDataFrameWithYAsDataFrame(self):
+ index = np.arange(100, 104)
+ a = np.arange(4)
+ b = np.arange(32, 36)
+ a_label = np.arange(10, 14)
+ b_label = np.arange(50, 54)
+ x = pd.DataFrame({'a': a, 'b': b}, index=index)
+ y = pd.DataFrame({'a_target': a_label, 'b_target': b_label}, index=index)
+ return x, y
+
def callInputFnOnce(self, input_fn, session):
results = input_fn()
coord = coordinator.Coordinator()
@@ -65,6 +75,19 @@ class PandasIoTest(test.TestCase):
pandas_io.pandas_input_fn(
x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
+ def testPandasInputFn_RaisesWhenTargetColumnIsAList(self):
+ if not HAS_PANDAS:
+ return
+
+ x, y = self.makeTestDataFrame()
+
+ with self.assertRaisesRegexp(TypeError,
+ 'target_column must be a string type'):
+ pandas_io.pandas_input_fn(x, y, batch_size=2,
+ shuffle=False,
+ num_epochs=1,
+ target_column=['one', 'two'])
+
def testPandasInputFn_NonBoolShuffle(self):
if not HAS_PANDAS:
return
@@ -90,6 +113,53 @@ class PandasIoTest(test.TestCase):
self.assertAllEqual(features['b'], [32, 33])
self.assertAllEqual(target, [-32, -31])
+ def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self):
+ if not HAS_PANDAS:
+ return
+ with self.test_session() as session:
+ x, y = self.makeTestDataFrameWithYAsDataFrame()
+ input_fn = pandas_io.pandas_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+
+ features, targets = self.callInputFnOnce(input_fn, session)
+
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertAllEqual(targets['a_target'], [10, 11])
+ self.assertAllEqual(targets['b_target'], [50, 51])
+
+ def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self):
+ if not HAS_PANDAS:
+ return
+ with self.test_session() as session:
+ x, y = self.makeTestDataFrameWithYAsDataFrame()
+ y = y.rename(columns={'a_target': 'a', 'b_target': 'b'})
+ input_fn = pandas_io.pandas_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+
+ features, targets = self.callInputFnOnce(input_fn, session)
+
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertAllEqual(targets['a'], [10, 11])
+ self.assertAllEqual(targets['b'], [50, 51])
+
+ def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self):
+ if not HAS_PANDAS:
+ return
+ with self.test_session() as session:
+ x, y = self.makeTestDataFrameWithYAsDataFrame()
+ y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'})
+ input_fn = pandas_io.pandas_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+
+ features, targets = self.callInputFnOnce(input_fn, session)
+
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertAllEqual(targets['a'], [10, 11])
+ self.assertAllEqual(targets['a_n'], [50, 51])
+
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return