about summary refs log tree commit diff homepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2017-06-29 11:10:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-29 11:15:12 -0700
commit9b11f458196f6f0528c9974238497a6c8b6da547 (patch)
tree69a94a9f8ce23944aa6e4c47d00e98644e716e5b /tensorflow/contrib
parentc1087b3a0b851b62a027201c0c41c0bd4e44e303 (diff)
[tf.contrib.data] Fix the handling of dict-typed elements in functions.
Previously, we were treating a `dict` as a sequence, which led to incorrect behavior like passing all of the dict's keys rather than values as the arguments to a map or filter function. With this change, the dict is instead passed as a single argument to these functions. It additionally fixes the ported version of `nest.flatten_up_to()` so that `Dataset.padded_batch()` works with dict-typed elements. Fixes #11016. PiperOrigin-RevId: 160548475
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py24
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py17
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py17
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py16
-rw-r--r--tensorflow/contrib/data/python/ops/dataset_ops.py13
-rw-r--r--tensorflow/contrib/data/python/util/nest.py4
-rw-r--r--tensorflow/contrib/data/python/util/nest_test.py8
7 files changed, 86 insertions, 13 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 20d66d7f23..71df1ee0a5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class BucketingTest(test.TestCase):
+class GroupByWindowTest(test.TestCase):
def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
@@ -257,16 +257,24 @@ class BucketTest(test.TestCase):
def testEvenOddBucketsFilterOutAllOdd(self):
def _map_fn(v):
- return (v, array_ops.fill([v], v),
- array_ops.fill([3], string_ops.as_string(v)))
+ return {"x": v,
+ "y": array_ops.fill([v], v),
+ "z": array_ops.fill([3], string_ops.as_string(v))}
+
+ def _dynamic_pad_fn(bucket, window, _):
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(bucket), window.padded_batch(
+ 32, {"x": tensor_shape.TensorShape([]),
+ "y": tensor_shape.TensorShape([None]),
+ "z": tensor_shape.TensorShape([3])})))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
- .filter(lambda x, y, z: math_ops.equal(x % 2, 0)))
+ .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
bucketed_dataset = input_dataset.group_by_window(
- lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
- lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)
+ lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
+ lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
init_op = iterator.initializer
@@ -283,9 +291,9 @@ class BucketTest(test.TestCase):
self.assertAllEqual(0, which_bucket0)
self.assertAllEqual(0, which_bucket1)
self.assertAllEqual(
- np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0[0])
+ np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
self.assertAllEqual(
- np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1[0])
+ np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
index 19be94e174..e6d50dc154 100644
--- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
@@ -83,6 +83,23 @@ class FilterDatasetTest(test.TestCase):
self.assertEqual(1, sess.run(get_next))
self.assertEqual(3, sess.run(get_next))
+ def testFilterDict(self):
+ iterator = (dataset_ops.Dataset.range(10)
+ .map(lambda x: {"foo": x * 2, "bar": x ** 2})
+ .filter(lambda d: math_ops.equal(d["bar"] % 2, 0))
+ .map(lambda d: d["foo"] + d["bar"])
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for i in range(10):
+ if (i ** 2) % 2 == 0:
+ self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py
index 3c9c714bde..ace0dd3668 100644
--- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py
@@ -101,6 +101,23 @@ class FlatMapDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess = random.choice([sess1, sess2])
sess.run(get_next)
+
+ def testMapDict(self):
+ iterator = (dataset_ops.Dataset.range(10)
+ .map(lambda x: {"foo": x * 2, "bar": x ** 2})
+ .flat_map(lambda d: dataset_ops.Dataset.from_tensors(d["foo"])
+ .repeat(d["bar"]))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for i in range(10):
+ for _ in range(i ** 2):
+ self.assertEqual(i * 2, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
# pylint: enable=g-long-lambda
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index b5956ac49c..2c07248c54 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -324,5 +324,21 @@ class MapDatasetTest(test.TestCase):
# Randomness is repeatable given same seed
self.assertAllClose(random_values, random_values_2)
+ def testMapDict(self):
+ iterator = (dataset_ops.Dataset.range(10)
+ .map(lambda x: {"foo": x * 2, "bar": x ** 2})
+ .map(lambda d: d["foo"] + d["bar"])
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for i in range(10):
+ self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index 29f1209a58..a689bfc901 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -1357,6 +1357,11 @@ class DenseToSparseBatchDataset(Dataset):
return (dtypes.int64, self._input_dataset.output_types, dtypes.int64)
+def _should_unpack_args(args):
+ """Returns `True` if `args` should be `*args` when passed to a callable."""
+ return nest.is_sequence(args) and not isinstance(args, dict)
+
+
class _ResourceDataset(Dataset):
"""A Dataset wrapper for a tf.resource-typed function argument."""
@@ -1394,7 +1399,7 @@ class GroupByWindowDataset(Dataset):
for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
arg.set_shape(shape)
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- if nest.is_sequence(nested_args):
+ if _should_unpack_args(nested_args):
ret = key_func(*nested_args)
else:
ret = key_func(nested_args)
@@ -1483,7 +1488,7 @@ class MapDataset(Dataset):
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- if nest.is_sequence(nested_args):
+ if _should_unpack_args(nested_args):
ret = map_func(*nested_args)
else:
ret = map_func(nested_args)
@@ -1559,7 +1564,7 @@ class FlatMapDataset(Dataset):
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- if nest.is_sequence(nested_args):
+ if _should_unpack_args(nested_args):
dataset = map_func(*nested_args)
else:
dataset = map_func(nested_args)
@@ -1609,7 +1614,7 @@ class FilterDataset(Dataset):
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- if nest.is_sequence(nested_args):
+ if _should_unpack_args(nested_args):
ret = predicate(*nested_args)
else:
ret = predicate(nested_args)
diff --git a/tensorflow/contrib/data/python/util/nest.py b/tensorflow/contrib/data/python/util/nest.py
index 91c8416d5a..a29c3c562b 100644
--- a/tensorflow/contrib/data/python/util/nest.py
+++ b/tensorflow/contrib/data/python/util/nest.py
@@ -286,7 +286,8 @@ def map_structure(func, *structure, **check_types_dict):
def _yield_flat_up_to(shallow_tree, input_tree):
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
if is_sequence(shallow_tree):
- for shallow_branch, input_branch in zip(shallow_tree, input_tree):
+ for shallow_branch, input_branch in zip(_elements_of(shallow_tree),
+ _elements_of(input_tree)):
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
yield input_leaf
else:
@@ -495,6 +496,7 @@ def map_structure_up_to(shallow_tree, func, *inputs):
# then repack based on the structure of the first input.
all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree)
for input_tree in inputs]
+
results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
diff --git a/tensorflow/contrib/data/python/util/nest_test.py b/tensorflow/contrib/data/python/util/nest_test.py
index 7852e4f861..5132881afb 100644
--- a/tensorflow/contrib/data/python/util/nest_test.py
+++ b/tensorflow/contrib/data/python/util/nest_test.py
@@ -287,6 +287,14 @@ class NestTest(test.TestCase):
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, list(shallow_tree))
+ # Using dict.
+ input_tree = {"a": ((2, 2), (3, 3)), "b": ((4, 9), (5, 5))}
+ shallow_tree = {"a": (True, True), "b": (False, True)}
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)])
+ self.assertEqual(flattened_shallow_tree, [True, True, False, True])
+
def testMapStructureUpTo(self):
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")