diff options
author | 2018-01-18 10:10:07 -0800 | |
---|---|---|
committer | 2018-01-18 10:14:04 -0800 | |
commit | 54eb77f5411a75e43f5306113574d39381050b88 (patch) | |
tree | 77957e3c54824057af1d09e8bf7013953e0ef336 /tensorflow/contrib/learn | |
parent | b6ef205887bdf8794b8fb38056bec8706c123e58 (diff) |
Allow numpy_input_fn to take an ndarray representing a single input feature.
PiperOrigin-RevId: 182397583
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r-- | tensorflow/contrib/learn/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py | 280 |
2 files changed, 0 insertions, 294 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 0e87c36dcd..ee3611ca93 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -751,20 +751,6 @@ tf_py_test( ) py_test( - name = "numpy_io_test", - size = "small", - srcs = ["python/learn/learn_io/numpy_io_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":learn", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], -) - -py_test( name = "pandas_io_test", size = "small", srcs = ["python/learn/learn_io/pandas_io_test.py"], diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py deleted file mode 100644 index 6fe8de8705..0000000000 --- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for numpy_io.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.learn.python.learn.learn_io import numpy_io -from tensorflow.python.framework import errors -from tensorflow.python.platform import test -from tensorflow.python.training import coordinator -from tensorflow.python.training import queue_runner_impl - - -class NumpyIoTest(test.TestCase): - - def testNumpyInputFn(self): - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - y = np.arange(-32, -28) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - session.run([features, target]) - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithVeryLargeBatchSizeAndMultipleEpochs(self): - a = np.arange(2) * 1.0 - b = np.arange(32, 34) - x = {'a': a, 'b': b} - y = np.arange(-32, -30) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=128, shuffle=False, num_epochs=2) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1, 0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33, 32, 33]) - self.assertAllEqual(res[1], [-32, -31, -32, -31]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithZeroEpochs(self): - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - y = np.arange(-32, -28) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=0) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithBatchSizeNotDividedByDataSize(self): - batch_size = 2 - a = np.arange(5) * 1.0 - b = np.arange(32, 37) - x = {'a': a, 'b': b} - y = np.arange(-32, -27) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=batch_size, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [2, 3]) - self.assertAllEqual(res[0]['b'], [34, 35]) - self.assertAllEqual(res[1], [-30, -29]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [4]) - self.assertAllEqual(res[0]['b'], [36]) - self.assertAllEqual(res[1], [-28]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithBatchSizeNotDividedByDataSizeAndMultipleEpochs(self): - batch_size = 2 - a = np.arange(3) * 1.0 - b = np.arange(32, 35) - x = {'a': a, 'b': b} - y = np.arange(-32, -29) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=batch_size, shuffle=False, num_epochs=3) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [2, 0]) - self.assertAllEqual(res[0]['b'], [34, 32]) - self.assertAllEqual(res[1], [-30, -32]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [1, 2]) - self.assertAllEqual(res[0]['b'], [33, 34]) - self.assertAllEqual(res[1], [-31, -30]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [2]) - self.assertAllEqual(res[0]['b'], [34]) - self.assertAllEqual(res[1], [-30]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithBatchSizeLargerThanDataSize(self): - batch_size = 10 - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - y = np.arange(-32, -28) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=batch_size, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1, 2, 3]) - self.assertAllEqual(res[0]['b'], [32, 33, 34, 35]) - self.assertAllEqual(res[1], [-32, -31, -30, -29]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithDifferentDimensionsOfFeatures(self): - a = np.array([[1, 2], [3, 4]]) - b = np.array([5, 6]) - x = {'a': a, 'b': b} - y = np.arange(-32, -30) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [[1, 2], [3, 4]]) - self.assertAllEqual(res[0]['b'], [5, 6]) - self.assertAllEqual(res[1], [-32, -31]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithXAsNonDict(self): - x = np.arange(32, 36) - y = np.arange(4) - with self.test_session(): - with self.assertRaisesRegexp(TypeError, 'x must be dict'): - failing_input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - failing_input_fn() - - def testNumpyInputFnWithTargetKeyAlreadyInX(self): - array = np.arange(32, 36) - x = {'__target_key__': array} - y = np.arange(4) - - with self.test_session(): - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - input_fn() - self.assertAllEqual(x['__target_key__'], array) - self.assertItemsEqual(x.keys(), ['__target_key__']) - - def testNumpyInputFnWithMismatchLengthOfInputs(self): - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - x_mismatch_length = {'a': np.arange(1), 'b': b} - y_longer_length = np.arange(10) - - with self.test_session(): - with self.assertRaisesRegexp( - ValueError, 'Length of tensors in x and y is mismatched.'): - failing_input_fn = numpy_io.numpy_input_fn( - x, y_longer_length, batch_size=2, shuffle=False, num_epochs=1) - failing_input_fn() - - with self.assertRaisesRegexp( - ValueError, 'Length of tensors in x and y is mismatched.'): - failing_input_fn = numpy_io.numpy_input_fn( - x=x_mismatch_length, - y=None, - batch_size=2, - shuffle=False, - num_epochs=1) - failing_input_fn() - - -if __name__ == '__main__': - test.main() |