aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/inputs/numpy_io_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/inputs/numpy_io_test.py')
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io_test.py87
1 files changed, 87 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py
index 02df22b632..65eae7a7dc 100644
--- a/tensorflow/python/estimator/inputs/numpy_io_test.py
+++ b/tensorflow/python/estimator/inputs/numpy_io_test.py
@@ -239,6 +239,40 @@ class NumpyIoTest(test.TestCase):
x, y, batch_size=2, shuffle=False, num_epochs=1)
failing_input_fn()
+ def testNumpyInputFnWithXIsEmptyDict(self):
+ x = {}
+ y = np.arange(4)
+ with self.test_session():
+ with self.assertRaisesRegexp(ValueError, 'x cannot be empty'):
+ failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
+ failing_input_fn()
+
+ def testNumpyInputFnWithYIsNone(self):
+ a = np.arange(4) * 1.0
+ b = np.arange(32, 36)
+ x = {'a': a, 'b': b}
+ y = None
+
+ with self.test_session() as session:
+ input_fn = numpy_io.numpy_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+ features_tensor = input_fn()
+
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(session, coord=coord)
+
+ feature = session.run(features_tensor)
+ self.assertEqual(len(feature), 2)
+ self.assertAllEqual(feature['a'], [0, 1])
+ self.assertAllEqual(feature['b'], [32, 33])
+
+ session.run([features_tensor])
+ with self.assertRaises(errors.OutOfRangeError):
+ session.run([features_tensor])
+
+ coord.request_stop()
+ coord.join(threads)
+
def testNumpyInputFnWithNonBoolShuffle(self):
x = np.arange(32, 36)
y = np.arange(4)
@@ -285,6 +319,59 @@ class NumpyIoTest(test.TestCase):
num_epochs=1)
failing_input_fn()
+ def testNumpyInputFnWithYAsDict(self):
+ a = np.arange(4) * 1.0
+ b = np.arange(32, 36)
+ x = {'a': a, 'b': b}
+ y = {'y1': np.arange(-32, -28), 'y2': np.arange(32, 28, -1)}
+
+ with self.test_session() as session:
+ input_fn = numpy_io.numpy_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+ features_tensor, targets_tensor = input_fn()
+
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(session, coord=coord)
+
+ features, targets = session.run([features_tensor, targets_tensor])
+ self.assertEqual(len(features), 2)
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertEqual(len(targets), 2)
+ self.assertAllEqual(targets['y1'], [-32, -31])
+ self.assertAllEqual(targets['y2'], [32, 31])
+
+ session.run([features_tensor, targets_tensor])
+ with self.assertRaises(errors.OutOfRangeError):
+ session.run([features_tensor, targets_tensor])
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def testNumpyInputFnWithYIsEmptyDict(self):
+ a = np.arange(4) * 1.0
+ b = np.arange(32, 36)
+ x = {'a': a, 'b': b}
+ y = {}
+ with self.test_session():
+ with self.assertRaisesRegexp(ValueError, 'y cannot be empty'):
+ failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
+ failing_input_fn()
+
+ def testNumpyInputFnWithDuplicateKeysInXAndY(self):
+ a = np.arange(4) * 1.0
+ b = np.arange(32, 36)
+ x = {'a': a, 'b': b}
+ y = {'y1': np.arange(-32, -28),
+ 'a': a,
+ 'y2': np.arange(32, 28, -1),
+ 'b': b}
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, '2 duplicate keys are found in both x and y'):
+ failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
+ failing_input_fn()
+
if __name__ == '__main__':
test.main()