aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/inputs/pandas_io.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/inputs/pandas_io.py')
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io.py41
1 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py
index 57f8e5fd6a..616bcb410f 100644
--- a/tensorflow/python/estimator/inputs/pandas_io.py
+++ b/tensorflow/python/estimator/inputs/pandas_io.py
@@ -18,6 +18,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import six
+import uuid
import numpy as np
from tensorflow.python.estimator.inputs.queues import feeding_functions
@@ -35,6 +37,22 @@ except ImportError:
HAS_PANDAS = False
+def _get_unique_target_key(features, target_column_name):
+ """Returns a key that does not exist in the input DataFrame `features`.
+
+ Args:
+ features: DataFrame
+ target_column_name: Name of the target column as a `str`
+
+ Returns:
+ A unique key that can be used to insert the target into
+ features.
+ """
+ if target_column_name in features:
+ target_column_name += '_' + str(uuid.uuid4())
+ return target_column_name
+
+
@estimator_export('estimator.inputs.pandas_input_fn')
def pandas_input_fn(x,
y=None,
@@ -50,7 +68,7 @@ def pandas_input_fn(x,
Args:
x: pandas `DataFrame` object.
- y: pandas `Series` object. `None` if absent.
+ y: pandas `Series` object or `DataFrame`. `None` if absent.
batch_size: int, size of batches to return.
num_epochs: int, number of epochs to iterate over data. If not `None`,
read attempts that would exceed this value will raise `OutOfRangeError`.
@@ -60,7 +78,8 @@ def pandas_input_fn(x,
num_threads: Integer, number of threads used for reading and enqueueing. In
order to have predicted and repeatable order of reading and enqueueing,
such as in prediction and evaluation mode, `num_threads` should be 1.
- target_column: str, name to give the target column `y`.
+ target_column: str, name to give the target column `y`. This parameter
+ is not used when `y` is a `DataFrame`.
Returns:
Function, that has signature of ()->(dict of `features`, `target`)
@@ -79,6 +98,9 @@ def pandas_input_fn(x,
'(it is recommended to set it as True for training); '
'got {}'.format(shuffle))
+ if not isinstance(target_column, six.string_types):
+ raise TypeError('target_column must be a string type')
+
x = x.copy()
if y is not None:
if target_column in x:
@@ -88,7 +110,13 @@ def pandas_input_fn(x,
if not np.array_equal(x.index, y.index):
raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n'
'Index for y: %s\n' % (x.index, y.index))
- x[target_column] = y
+ if isinstance(y, pd.DataFrame):
+ y_columns = [(column, _get_unique_target_key(x, column))
+ for column in list(y)]
+ target_column = [v for _, v in y_columns]
+ x[target_column] = y
+ else:
+ x[target_column] = y
# TODO(mdan): These are memory copies. We probably don't need 4x slack space.
# The sizes below are consistent with what I've seen elsewhere.
@@ -118,7 +146,12 @@ def pandas_input_fn(x,
features = features[1:]
features = dict(zip(list(x.columns), features))
if y is not None:
- target = features.pop(target_column)
+ if isinstance(target_column, list):
+ keys = [k for k, _ in y_columns]
+ values = [features.pop(column) for column in target_column]
+ target = {k: v for k, v in zip(keys, values)}
+ else:
+ target = features.pop(target_column)
return features, target
return features
return input_fn