about summary refs log tree commit diff homepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- tensorflow/python/estimator/canned/head.py 2
-rw-r--r-- tensorflow/python/estimator/inputs/numpy_io.py 83
-rw-r--r-- tensorflow/python/estimator/inputs/numpy_io_test.py 87
3 files changed, 148 insertions, 24 deletions
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 62fea05867..fa5d02c476 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -117,7 +117,7 @@ class _Head(object):
update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
loss=estimator_spec.loss, ...)
hooks = [sync.make_session_run_hook(is_chief)]
- ... upate train_op and hooks in EstimatorSpec and return
+ ... update train_op and hooks in EstimatorSpec and return
```
"""
__metaclass__ = abc.ABCMeta
diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py
index c9f37f06e8..3512f66284 100644
--- a/tensorflow/python/estimator/inputs/numpy_io.py
+++ b/tensorflow/python/estimator/inputs/numpy_io.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
+from six import string_types
from tensorflow.python.estimator.inputs.queues import feeding_functions
# Key name to pack the target into dict of `features`. See
@@ -51,8 +52,9 @@ def numpy_input_fn(x,
num_threads=1):
"""Returns input function that would feed dict of numpy arrays into the model.
- This returns a function outputting `features` and `target` based on the dict
- of numpy arrays. The dict `features` has the same keys as the `x`.
+ This returns a function outputting `features` and `targets` based on the dict
+ of numpy arrays. The dict `features` has the same keys as the `x`. The dict
+ `targets` has the same keys as the `y` if `y` is a dict.
Example:
@@ -69,7 +71,7 @@ def numpy_input_fn(x,
Args:
x: dict of numpy array objects.
- y: numpy array object. `None` if absent.
+ y: numpy array object or dict of numpy array objects. `None` if absent.
batch_size: Integer, size of batches to return.
num_epochs: Integer, number of epochs to iterate over data. If `None` will
run forever.
@@ -81,11 +83,13 @@ def numpy_input_fn(x,
such as in prediction and evaluation mode, `num_threads` should be 1.
Returns:
- Function, that has signature of ()->(dict of `features`, `target`)
+ Function, that has signature of ()->(dict of `features`, `targets`)
Raises:
ValueError: if the length of `y` mismatches the length of values in `x` (i.e.,
all values in `x` and `y` must have the same length).
+ ValueError: if duplicate keys are in both `x` and `y` when `y` is a dict.
+ ValueError: if `x` or `y` is an empty dict.
TypeError: `x` is not a dict or `shuffle` is not bool.
"""
@@ -97,43 +101,76 @@ def numpy_input_fn(x,
"""Numpy input function."""
if not isinstance(x, dict):
raise TypeError('x must be dict; got {}'.format(type(x).__name__))
+ if not x:
+ raise ValueError('x cannot be empty')
# Make a shallow copy and also ensure the order of iteration is consistent.
- ordered_dict_x = collections.OrderedDict(
+ ordered_dict_data = collections.OrderedDict(
sorted(x.items(), key=lambda t: t[0]))
+ # Copy the keys into a list, since `keys()` returns a view in Python 3.
+ feature_keys = list(ordered_dict_data.keys())
+
+ if y is None:
+ target_keys = None
+ elif isinstance(y, dict):
+ if not y:
+ raise ValueError('y cannot be empty dict, use None instead.')
+
+ ordered_dict_y = collections.OrderedDict(
+ sorted(y.items(), key=lambda t: t[0]))
+ target_keys = list(ordered_dict_y.keys())
+
+ duplicate_keys = set(feature_keys).intersection(set(target_keys))
+ if len(duplicate_keys):
+ raise ValueError('{} duplicate keys are found in both x and y: '
+ '{}'.format(len(duplicate_keys), duplicate_keys))
+
+ ordered_dict_data.update(ordered_dict_y)
+ else:
+ target_keys = _get_unique_target_key(ordered_dict_data)
+ ordered_dict_data[target_keys] = y
+
+ if len(set(v.shape[0] for v in ordered_dict_data.values())) != 1:
+ shape_dict_of_x = {k: ordered_dict_data[k].shape
+ for k in feature_keys}
+
+ if target_keys is None:
+ shape_of_y = None
+ elif isinstance(target_keys, string_types):
+ shape_of_y = y.shape
+ else:
+ shape_of_y = {k: ordered_dict_data[k].shape
+ for k in target_keys}
- unique_target_key = _get_unique_target_key(ordered_dict_x)
- if y is not None:
- ordered_dict_x[unique_target_key] = y
-
- if len(set(v.shape[0] for v in ordered_dict_x.values())) != 1:
- shape_dict_of_x = {k: ordered_dict_x[k].shape
- for k in ordered_dict_x.keys()}
- shape_of_y = None if y is None else y.shape
raise ValueError('Length of tensors in x and y is mismatched. All '
'elements in x and y must have the same length.\n'
'Shapes in x: {}\n'
- 'Shape for y: {}\n'.format(shape_dict_of_x, shape_of_y))
+ 'Shapes in y: {}\n'.format(shape_dict_of_x, shape_of_y))
queue = feeding_functions._enqueue_data( # pylint: disable=protected-access
- ordered_dict_x,
+ ordered_dict_data,
queue_capacity,
shuffle=shuffle,
num_threads=num_threads,
enqueue_size=batch_size,
num_epochs=num_epochs)
- features = (queue.dequeue_many(batch_size) if num_epochs is None
+ batch = (queue.dequeue_many(batch_size) if num_epochs is None
else queue.dequeue_up_to(batch_size))
- # Remove the first `Tensor` in `features`, which is the row number.
- if len(features) > 0:
- features.pop(0)
+ # Remove the first `Tensor` in `batch`, which is the row number.
+ if len(batch) > 0:
+ batch.pop(0)
- features = dict(zip(ordered_dict_x.keys(), features))
- if y is not None:
- target = features.pop(unique_target_key)
+ features = dict(zip(feature_keys, batch[:len(feature_keys)]))
+ if target_keys is None:
+ # TODO(martinwicke), return consistent result
+ return features
+ elif isinstance(target_keys, string_types):
+ target = batch[-1]
+ return features, target
+ else:
+ target = dict(zip(target_keys, batch[-len(target_keys):]))
return features, target
- return features
return input_fn
diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py
index 02df22b632..65eae7a7dc 100644
--- a/tensorflow/python/estimator/inputs/numpy_io_test.py
+++ b/tensorflow/python/estimator/inputs/numpy_io_test.py
@@ -239,6 +239,40 @@ class NumpyIoTest(test.TestCase):
x, y, batch_size=2, shuffle=False, num_epochs=1)
failing_input_fn()
+ def testNumpyInputFnWithXIsEmptyDict(self):
+ x = {}
+ y = np.arange(4)
+ with self.test_session():
+ with self.assertRaisesRegexp(ValueError, 'x cannot be empty'):
+ failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
+ failing_input_fn()
+
+ def testNumpyInputFnWithYIsNone(self):
+ a = np.arange(4) * 1.0
+ b = np.arange(32, 36)
+ x = {'a': a, 'b': b}
+ y = None
+
+ with self.test_session() as session:
+ input_fn = numpy_io.numpy_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+ features_tensor = input_fn()
+
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(session, coord=coord)
+
+ feature = session.run(features_tensor)
+ self.assertEqual(len(feature), 2)
+ self.assertAllEqual(feature['a'], [0, 1])
+ self.assertAllEqual(feature['b'], [32, 33])
+
+ session.run([features_tensor])
+ with self.assertRaises(errors.OutOfRangeError):
+ session.run([features_tensor])
+
+ coord.request_stop()
+ coord.join(threads)
+
def testNumpyInputFnWithNonBoolShuffle(self):
x = np.arange(32, 36)
y = np.arange(4)
@@ -285,6 +319,59 @@ class NumpyIoTest(test.TestCase):
num_epochs=1)
failing_input_fn()
+ def testNumpyInputFnWithYAsDict(self):
+ a = np.arange(4) * 1.0
+ b = np.arange(32, 36)
+ x = {'a': a, 'b': b}
+ y = {'y1': np.arange(-32, -28), 'y2': np.arange(32, 28, -1)}
+
+ with self.test_session() as session:
+ input_fn = numpy_io.numpy_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+ features_tensor, targets_tensor = input_fn()
+
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(session, coord=coord)
+
+ features, targets = session.run([features_tensor, targets_tensor])
+ self.assertEqual(len(features), 2)
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertEqual(len(targets), 2)
+ self.assertAllEqual(targets['y1'], [-32, -31])
+ self.assertAllEqual(targets['y2'], [32, 31])
+
+ session.run([features_tensor, targets_tensor])
+ with self.assertRaises(errors.OutOfRangeError):
+ session.run([features_tensor, targets_tensor])
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def testNumpyInputFnWithYIsEmptyDict(self):
+ a = np.arange(4) * 1.0
+ b = np.arange(32, 36)
+ x = {'a': a, 'b': b}
+ y = {}
+ with self.test_session():
+ with self.assertRaisesRegexp(ValueError, 'y cannot be empty'):
+ failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
+ failing_input_fn()
+
+ def testNumpyInputFnWithDuplicateKeysInXAndY(self):
+ a = np.arange(4) * 1.0
+ b = np.arange(32, 36)
+ x = {'a': a, 'b': b}
+ y = {'y1': np.arange(-32, -28),
+ 'a': a,
+ 'y2': np.arange(32, 28, -1),
+ 'b': b}
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, '2 duplicate keys are found in both x and y'):
+ failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
+ failing_input_fn()
+
if __name__ == '__main__':
test.main()