diff options
Diffstat (limited to 'tensorflow/python/estimator/inputs/numpy_io_test.py')
-rw-r--r-- | tensorflow/python/estimator/inputs/numpy_io_test.py | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py index 02df22b632..65eae7a7dc 100644 --- a/tensorflow/python/estimator/inputs/numpy_io_test.py +++ b/tensorflow/python/estimator/inputs/numpy_io_test.py @@ -239,6 +239,40 @@ class NumpyIoTest(test.TestCase): x, y, batch_size=2, shuffle=False, num_epochs=1) failing_input_fn() + def testNumpyInputFnWithXIsEmptyDict(self): + x = {} + y = np.arange(4) + with self.test_session(): + with self.assertRaisesRegexp(ValueError, 'x cannot be empty'): + failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) + failing_input_fn() + + def testNumpyInputFnWithYIsNone(self): + a = np.arange(4) * 1.0 + b = np.arange(32, 36) + x = {'a': a, 'b': b} + y = None + + with self.test_session() as session: + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + features_tensor = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + feature = session.run(features_tensor) + self.assertEqual(len(feature), 2) + self.assertAllEqual(feature['a'], [0, 1]) + self.assertAllEqual(feature['b'], [32, 33]) + + session.run([features_tensor]) + with self.assertRaises(errors.OutOfRangeError): + session.run([features_tensor]) + + coord.request_stop() + coord.join(threads) + def testNumpyInputFnWithNonBoolShuffle(self): x = np.arange(32, 36) y = np.arange(4) @@ -285,6 +319,59 @@ class NumpyIoTest(test.TestCase): num_epochs=1) failing_input_fn() + def testNumpyInputFnWithYAsDict(self): + a = np.arange(4) * 1.0 + b = np.arange(32, 36) + x = {'a': a, 'b': b} + y = {'y1': np.arange(-32, -28), 'y2': np.arange(32, 28, -1)} + + with self.test_session() as session: + input_fn = numpy_io.numpy_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + features_tensor, targets_tensor = input_fn() + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + features, targets = session.run([features_tensor, targets_tensor]) + self.assertEqual(len(features), 2) + self.assertAllEqual(features['a'], [0, 1]) + self.assertAllEqual(features['b'], [32, 33]) + self.assertEqual(len(targets), 2) + self.assertAllEqual(targets['y1'], [-32, -31]) + self.assertAllEqual(targets['y2'], [32, 31]) + + session.run([features_tensor, targets_tensor]) + with self.assertRaises(errors.OutOfRangeError): + session.run([features_tensor, targets_tensor]) + + coord.request_stop() + coord.join(threads) + + def testNumpyInputFnWithYIsEmptyDict(self): + a = np.arange(4) * 1.0 + b = np.arange(32, 36) + x = {'a': a, 'b': b} + y = {} + with self.test_session(): + with self.assertRaisesRegexp(ValueError, 'y cannot be empty'): + failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) + failing_input_fn() + + def testNumpyInputFnWithDuplicateKeysInXAndY(self): + a = np.arange(4) * 1.0 + b = np.arange(32, 36) + x = {'a': a, 'b': b} + y = {'y1': np.arange(-32, -28), + 'a': a, + 'y2': np.arange(32, 28, -1), + 'b': b} + with self.test_session(): + with self.assertRaisesRegexp( + ValueError, '2 duplicate keys are found in both x and y'): + failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) + failing_input_fn() + if __name__ == '__main__': test.main() |