aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/BUILD14
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py280
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io.py68
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io_test.py91
4 files changed, 145 insertions, 308 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 0e87c36dcd..ee3611ca93 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -751,20 +751,6 @@ tf_py_test(
)
py_test(
- name = "numpy_io_test",
- size = "small",
- srcs = ["python/learn/learn_io/numpy_io_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":learn",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:training",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
name = "pandas_io_test",
size = "small",
srcs = ["python/learn/learn_io/pandas_io_test.py"],
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py
deleted file mode 100644
index 6fe8de8705..0000000000
--- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py
+++ /dev/null
@@ -1,280 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for numpy_io."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.learn.python.learn.learn_io import numpy_io
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-from tensorflow.python.training import coordinator
-from tensorflow.python.training import queue_runner_impl
-
-
-class NumpyIoTest(test.TestCase):
-
- def testNumpyInputFn(self):
- a = np.arange(4) * 1.0
- b = np.arange(32, 36)
- x = {'a': a, 'b': b}
- y = np.arange(-32, -28)
-
- with self.test_session() as session:
- input_fn = numpy_io.numpy_input_fn(
- x, y, batch_size=2, shuffle=False, num_epochs=1)
- features, target = input_fn()
-
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(session, coord=coord)
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [0, 1])
- self.assertAllEqual(res[0]['b'], [32, 33])
- self.assertAllEqual(res[1], [-32, -31])
-
- session.run([features, target])
- with self.assertRaises(errors.OutOfRangeError):
- session.run([features, target])
-
- coord.request_stop()
- coord.join(threads)
-
- def testNumpyInputFnWithVeryLargeBatchSizeAndMultipleEpochs(self):
- a = np.arange(2) * 1.0
- b = np.arange(32, 34)
- x = {'a': a, 'b': b}
- y = np.arange(-32, -30)
-
- with self.test_session() as session:
- input_fn = numpy_io.numpy_input_fn(
- x, y, batch_size=128, shuffle=False, num_epochs=2)
- features, target = input_fn()
-
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(session, coord=coord)
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [0, 1, 0, 1])
- self.assertAllEqual(res[0]['b'], [32, 33, 32, 33])
- self.assertAllEqual(res[1], [-32, -31, -32, -31])
-
- with self.assertRaises(errors.OutOfRangeError):
- session.run([features, target])
-
- coord.request_stop()
- coord.join(threads)
-
- def testNumpyInputFnWithZeroEpochs(self):
- a = np.arange(4) * 1.0
- b = np.arange(32, 36)
- x = {'a': a, 'b': b}
- y = np.arange(-32, -28)
-
- with self.test_session() as session:
- input_fn = numpy_io.numpy_input_fn(
- x, y, batch_size=2, shuffle=False, num_epochs=0)
- features, target = input_fn()
-
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(session, coord=coord)
-
- with self.assertRaises(errors.OutOfRangeError):
- session.run([features, target])
-
- coord.request_stop()
- coord.join(threads)
-
- def testNumpyInputFnWithBatchSizeNotDividedByDataSize(self):
- batch_size = 2
- a = np.arange(5) * 1.0
- b = np.arange(32, 37)
- x = {'a': a, 'b': b}
- y = np.arange(-32, -27)
-
- with self.test_session() as session:
- input_fn = numpy_io.numpy_input_fn(
- x, y, batch_size=batch_size, shuffle=False, num_epochs=1)
- features, target = input_fn()
-
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(session, coord=coord)
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [0, 1])
- self.assertAllEqual(res[0]['b'], [32, 33])
- self.assertAllEqual(res[1], [-32, -31])
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [2, 3])
- self.assertAllEqual(res[0]['b'], [34, 35])
- self.assertAllEqual(res[1], [-30, -29])
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [4])
- self.assertAllEqual(res[0]['b'], [36])
- self.assertAllEqual(res[1], [-28])
-
- with self.assertRaises(errors.OutOfRangeError):
- session.run([features, target])
-
- coord.request_stop()
- coord.join(threads)
-
- def testNumpyInputFnWithBatchSizeNotDividedByDataSizeAndMultipleEpochs(self):
- batch_size = 2
- a = np.arange(3) * 1.0
- b = np.arange(32, 35)
- x = {'a': a, 'b': b}
- y = np.arange(-32, -29)
-
- with self.test_session() as session:
- input_fn = numpy_io.numpy_input_fn(
- x, y, batch_size=batch_size, shuffle=False, num_epochs=3)
- features, target = input_fn()
-
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(session, coord=coord)
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [0, 1])
- self.assertAllEqual(res[0]['b'], [32, 33])
- self.assertAllEqual(res[1], [-32, -31])
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [2, 0])
- self.assertAllEqual(res[0]['b'], [34, 32])
- self.assertAllEqual(res[1], [-30, -32])
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [1, 2])
- self.assertAllEqual(res[0]['b'], [33, 34])
- self.assertAllEqual(res[1], [-31, -30])
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [0, 1])
- self.assertAllEqual(res[0]['b'], [32, 33])
- self.assertAllEqual(res[1], [-32, -31])
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [2])
- self.assertAllEqual(res[0]['b'], [34])
- self.assertAllEqual(res[1], [-30])
-
- with self.assertRaises(errors.OutOfRangeError):
- session.run([features, target])
-
- coord.request_stop()
- coord.join(threads)
-
- def testNumpyInputFnWithBatchSizeLargerThanDataSize(self):
- batch_size = 10
- a = np.arange(4) * 1.0
- b = np.arange(32, 36)
- x = {'a': a, 'b': b}
- y = np.arange(-32, -28)
-
- with self.test_session() as session:
- input_fn = numpy_io.numpy_input_fn(
- x, y, batch_size=batch_size, shuffle=False, num_epochs=1)
- features, target = input_fn()
-
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(session, coord=coord)
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [0, 1, 2, 3])
- self.assertAllEqual(res[0]['b'], [32, 33, 34, 35])
- self.assertAllEqual(res[1], [-32, -31, -30, -29])
-
- with self.assertRaises(errors.OutOfRangeError):
- session.run([features, target])
-
- coord.request_stop()
- coord.join(threads)
-
- def testNumpyInputFnWithDifferentDimensionsOfFeatures(self):
- a = np.array([[1, 2], [3, 4]])
- b = np.array([5, 6])
- x = {'a': a, 'b': b}
- y = np.arange(-32, -30)
-
- with self.test_session() as session:
- input_fn = numpy_io.numpy_input_fn(
- x, y, batch_size=2, shuffle=False, num_epochs=1)
- features, target = input_fn()
-
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(session, coord=coord)
-
- res = session.run([features, target])
- self.assertAllEqual(res[0]['a'], [[1, 2], [3, 4]])
- self.assertAllEqual(res[0]['b'], [5, 6])
- self.assertAllEqual(res[1], [-32, -31])
-
- coord.request_stop()
- coord.join(threads)
-
- def testNumpyInputFnWithXAsNonDict(self):
- x = np.arange(32, 36)
- y = np.arange(4)
- with self.test_session():
- with self.assertRaisesRegexp(TypeError, 'x must be dict'):
- failing_input_fn = numpy_io.numpy_input_fn(
- x, y, batch_size=2, shuffle=False, num_epochs=1)
- failing_input_fn()
-
- def testNumpyInputFnWithTargetKeyAlreadyInX(self):
- array = np.arange(32, 36)
- x = {'__target_key__': array}
- y = np.arange(4)
-
- with self.test_session():
- input_fn = numpy_io.numpy_input_fn(
- x, y, batch_size=2, shuffle=False, num_epochs=1)
- input_fn()
- self.assertAllEqual(x['__target_key__'], array)
- self.assertItemsEqual(x.keys(), ['__target_key__'])
-
- def testNumpyInputFnWithMismatchLengthOfInputs(self):
- a = np.arange(4) * 1.0
- b = np.arange(32, 36)
- x = {'a': a, 'b': b}
- x_mismatch_length = {'a': np.arange(1), 'b': b}
- y_longer_length = np.arange(10)
-
- with self.test_session():
- with self.assertRaisesRegexp(
- ValueError, 'Length of tensors in x and y is mismatched.'):
- failing_input_fn = numpy_io.numpy_input_fn(
- x, y_longer_length, batch_size=2, shuffle=False, num_epochs=1)
- failing_input_fn()
-
- with self.assertRaisesRegexp(
- ValueError, 'Length of tensors in x and y is mismatched.'):
- failing_input_fn = numpy_io.numpy_input_fn(
- x=x_mismatch_length,
- y=None,
- batch_size=2,
- shuffle=False,
- num_epochs=1)
- failing_input_fn()
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py
index 750af20e8a..c4c2e30e87 100644
--- a/tensorflow/python/estimator/inputs/numpy_io.py
+++ b/tensorflow/python/estimator/inputs/numpy_io.py
@@ -19,7 +19,10 @@ from __future__ import division
from __future__ import print_function
import collections
+
+import numpy as np
from six import string_types
+
from tensorflow.python.estimator.inputs.queues import feeding_functions
# Key name to pack the target into dict of `features`. See
@@ -36,6 +39,13 @@ def _get_unique_target_key(features):
temporarily and unpacked after calling the feeding function. Toward this goal,
this function returns a key not existed in the `features` to pack the
`target`.
+
+ Args:
+ features: OrderedDict of numpy arrays
+
+ Returns:
+ A unique key that can be used to insert the subsequent target into
+ features dict.
"""
target_key = _TARGET_KEY
while target_key in features:
@@ -43,6 +53,39 @@ def _get_unique_target_key(features):
return target_key
+def _validate_and_convert_features(x):
+ """Type check input data and make a shadow copy as an ordered dict.
+
+ Args:
+ x: numpy array object or dict of numpy array objects. If an array,
+ the array will be treated as a single feature.
+
+ Returns:
+ OrderedDict copy of x.
+
+ Raises:
+ ValueError: if x is empty
+ TypeError: if x is an unknown type.
+ """
+ if isinstance(x, dict):
+ if not x:
+ raise ValueError('x cannot be an empty dict')
+ # Make a shadow copy and also ensure the order of iteration is consistent.
+ ordered_dict_data = collections.OrderedDict(
+ sorted(x.items(), key=lambda t: t[0]))
+ elif isinstance(x, np.ndarray):
+ if x.size == 0:
+ raise ValueError('x cannot be an empty array')
+
+ # Make a shadow copy and convert to dict to align with dict processing.
+ ordered_dict_data = collections.OrderedDict({'__direct_np_input__': x})
+ else:
+ x_type = type(x).__name__
+ raise TypeError('x must be a dict or array; got {}'.format(x_type))
+
+ return ordered_dict_data
+
+
def numpy_input_fn(x,
y=None,
batch_size=128,
@@ -70,7 +113,8 @@ def numpy_input_fn(x,
```
Args:
- x: dict of numpy array object.
+ x: numpy array object or dict of numpy array objects. If an array,
+ the array will be treated as a single feature.
y: numpy array object or dict of numpy array object. `None` if absent.
batch_size: Integer, size of batches to return.
num_epochs: Integer, number of epochs to iterate over data. If `None` will
@@ -90,23 +134,19 @@ def numpy_input_fn(x,
values in `x` have same shape).
ValueError: if duplicate keys are in both `x` and `y` when `y` is a dict.
ValueError: if x or y is an empty dict.
- TypeError: `x` is not a dict or `shuffle` is not bool.
+ TypeError: `x` is not a dict or array, or if `shuffle` is not bool.
"""
-
if not isinstance(shuffle, bool):
raise TypeError('shuffle must be explicitly set as boolean; '
'got {}'.format(shuffle))
def input_fn():
"""Numpy input function."""
- if not isinstance(x, dict):
- raise TypeError('x must be dict; got {}'.format(type(x).__name__))
- if not x:
- raise ValueError('x cannot be empty')
- # Make a shadow copy and also ensure the order of iteration is consistent.
- ordered_dict_data = collections.OrderedDict(
- sorted(x.items(), key=lambda t: t[0]))
+ # Note that `x` should not be used after conversion to ordered_dict_data,
+ # as type could be either dict or array.
+ ordered_dict_data = _validate_and_convert_features(x)
+
# Deep copy keys which is a view in python 3
feature_keys = list(ordered_dict_data.keys())
@@ -161,7 +201,13 @@ def numpy_input_fn(x,
if batch:
batch.pop(0)
- features = dict(zip(feature_keys, batch[:len(feature_keys)]))
+ if isinstance(x, np.ndarray):
+ # Return as the same type as original array.
+ features = batch[0]
+ else:
+ # Return as the original dict type
+ features = dict(zip(feature_keys, batch[:len(feature_keys)]))
+
if target_keys is None:
# TODO(martinwicke), return consistent result
return features
diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py
index 1374e3f7e1..92d057e25d 100644
--- a/tensorflow/python/estimator/inputs/numpy_io_test.py
+++ b/tensorflow/python/estimator/inputs/numpy_io_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import queue_runner_impl
@@ -231,10 +232,10 @@ class NumpyIoTest(test.TestCase):
coord.join(threads)
def testNumpyInputFnWithXAsNonDict(self):
- x = np.arange(32, 36)
+ x = list(range(32, 36))
y = np.arange(4)
with self.test_session():
- with self.assertRaisesRegexp(TypeError, 'x must be dict'):
+ with self.assertRaisesRegexp(TypeError, 'x must be a dict or array'):
failing_input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
failing_input_fn()
@@ -243,7 +244,15 @@ class NumpyIoTest(test.TestCase):
x = {}
y = np.arange(4)
with self.test_session():
- with self.assertRaisesRegexp(ValueError, 'x cannot be empty'):
+ with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):
+ failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
+ failing_input_fn()
+
+ def testNumpyInputFnWithXIsEmptyArray(self):
+ x = np.array([[], []])
+ y = np.arange(4)
+ with self.test_session():
+ with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):
failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
failing_input_fn()
@@ -369,6 +378,82 @@ class NumpyIoTest(test.TestCase):
failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
failing_input_fn()
+ def testNumpyInputFnWithXIsArray(self):
+ x = np.arange(4) * 1.0
+ y = np.arange(-32, -28)
+
+ input_fn = numpy_io.numpy_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+ features, target = input_fn()
+
+ with monitored_session.MonitoredSession() as session:
+ res = session.run([features, target])
+ self.assertAllEqual(res[0], [0, 1])
+ self.assertAllEqual(res[1], [-32, -31])
+
+ session.run([features, target])
+ with self.assertRaises(errors.OutOfRangeError):
+ session.run([features, target])
+
+ def testNumpyInputFnWithXIsNDArray(self):
+ x = np.arange(16).reshape(4, 2, 2) * 1.0
+ y = np.arange(-48, -32).reshape(4, 2, 2)
+
+ input_fn = numpy_io.numpy_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+ features, target = input_fn()
+
+ with monitored_session.MonitoredSession() as session:
+ res = session.run([features, target])
+ self.assertAllEqual(res[0], [[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
+ self.assertAllEqual(
+ res[1], [[[-48, -47], [-46, -45]], [[-44, -43], [-42, -41]]])
+
+ session.run([features, target])
+ with self.assertRaises(errors.OutOfRangeError):
+ session.run([features, target])
+
+ def testNumpyInputFnWithXIsArrayYIsDict(self):
+ x = np.arange(4) * 1.0
+ y = {'y1': np.arange(-32, -28)}
+
+ input_fn = numpy_io.numpy_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+ features_tensor, targets_tensor = input_fn()
+
+ with monitored_session.MonitoredSession() as session:
+ features, targets = session.run([features_tensor, targets_tensor])
+ self.assertEqual(len(features), 2)
+ self.assertAllEqual(features, [0, 1])
+ self.assertEqual(len(targets), 1)
+ self.assertAllEqual(targets['y1'], [-32, -31])
+
+ session.run([features_tensor, targets_tensor])
+ with self.assertRaises(errors.OutOfRangeError):
+ session.run([features_tensor, targets_tensor])
+
+ def testArrayAndDictGiveSameOutput(self):
+ a = np.arange(4) * 1.0
+ b = np.arange(32, 36)
+ x_arr = np.vstack((a, b))
+ x_dict = {'feature1': x_arr}
+ y = np.arange(-48, -40).reshape(2, 4)
+
+ input_fn_arr = numpy_io.numpy_input_fn(
+ x_arr, y, batch_size=2, shuffle=False, num_epochs=1)
+ features_arr, targets_arr = input_fn_arr()
+
+ input_fn_dict = numpy_io.numpy_input_fn(
+ x_dict, y, batch_size=2, shuffle=False, num_epochs=1)
+ features_dict, targets_dict = input_fn_dict()
+
+ with monitored_session.MonitoredSession() as session:
+ res_arr, res_dict = session.run([
+ (features_arr, targets_arr), (features_dict, targets_dict)])
+
+ self.assertAllEqual(res_arr[0], res_dict[0]['feature1'])
+ self.assertAllEqual(res_arr[1], res_dict[1])
+
if __name__ == '__main__':
test.main()