aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/inputs/pandas_io_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/inputs/pandas_io_test.py')
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io_test.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/inputs/pandas_io_test.py b/tensorflow/python/estimator/inputs/pandas_io_test.py
index dcecf6dd61..6f13bc95d2 100644
--- a/tensorflow/python/estimator/inputs/pandas_io_test.py
+++ b/tensorflow/python/estimator/inputs/pandas_io_test.py
@@ -47,6 +47,16 @@ class PandasIoTest(test.TestCase):
y = pd.Series(np.arange(-32, -28), index=index)
return x, y
+ def makeTestDataFrameWithYAsDataFrame(self):
+ index = np.arange(100, 104)
+ a = np.arange(4)
+ b = np.arange(32, 36)
+ a_label = np.arange(10, 14)
+ b_label = np.arange(50, 54)
+ x = pd.DataFrame({'a': a, 'b': b}, index=index)
+ y = pd.DataFrame({'a_target': a_label, 'b_target': b_label}, index=index)
+ return x, y
+
def callInputFnOnce(self, input_fn, session):
results = input_fn()
coord = coordinator.Coordinator()
@@ -65,6 +75,19 @@ class PandasIoTest(test.TestCase):
pandas_io.pandas_input_fn(
x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
+ def testPandasInputFn_RaisesWhenTargetColumnIsAList(self):
+ if not HAS_PANDAS:
+ return
+
+ x, y = self.makeTestDataFrame()
+
+ with self.assertRaisesRegexp(TypeError,
+ 'target_column must be a string type'):
+ pandas_io.pandas_input_fn(x, y, batch_size=2,
+ shuffle=False,
+ num_epochs=1,
+ target_column=['one', 'two'])
+
def testPandasInputFn_NonBoolShuffle(self):
if not HAS_PANDAS:
return
@@ -90,6 +113,53 @@ class PandasIoTest(test.TestCase):
self.assertAllEqual(features['b'], [32, 33])
self.assertAllEqual(target, [-32, -31])
+ def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self):
+ if not HAS_PANDAS:
+ return
+ with self.test_session() as session:
+ x, y = self.makeTestDataFrameWithYAsDataFrame()
+ input_fn = pandas_io.pandas_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+
+ features, targets = self.callInputFnOnce(input_fn, session)
+
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertAllEqual(targets['a_target'], [10, 11])
+ self.assertAllEqual(targets['b_target'], [50, 51])
+
+ def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self):
+ if not HAS_PANDAS:
+ return
+ with self.test_session() as session:
+ x, y = self.makeTestDataFrameWithYAsDataFrame()
+ y = y.rename(columns={'a_target': 'a', 'b_target': 'b'})
+ input_fn = pandas_io.pandas_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+
+ features, targets = self.callInputFnOnce(input_fn, session)
+
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertAllEqual(targets['a'], [10, 11])
+ self.assertAllEqual(targets['b'], [50, 51])
+
+ def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self):
+ if not HAS_PANDAS:
+ return
+ with self.test_session() as session:
+ x, y = self.makeTestDataFrameWithYAsDataFrame()
+ y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'})
+ input_fn = pandas_io.pandas_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+
+ features, targets = self.callInputFnOnce(input_fn, session)
+
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertAllEqual(targets['a'], [10, 11])
+ self.assertAllEqual(targets['a_n'], [50, 51])
+
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return