diff options
author | 2017-06-16 10:15:27 -0700 | |
---|---|---|
committer | 2017-06-16 10:18:33 -0700 | |
commit | 1111e06d9cd691cbdfcb67cf9f234a504f4e0f6d (patch) | |
tree | e3be723bdb4abf8c90eb60f69b735b8a7c2b4231 | |
parent | a6000da882aef2d8631c6c22bcfaa0db01d5336f (diff) |
In filter_dataset, if the element doesn't match, clear the output.
Prior to this change, the output tensors list wouldn't be cleared
if the element didn't match, so output_tensors would be appended to
on the next iteration, violating the number of inputs contract for
the filter function. The new test exhibited this problem. By
clearing the output tensors list when the element doesn't match, we
reset the state back to empty, maintaining the invariant at
the start of the loop.
Fixes #10725.
PiperOrigin-RevId: 159242269
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py | 11 | ||||
-rw-r--r-- | tensorflow/core/kernels/filter_dataset_op.cc | 4 |
2 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index 3ea783ad89..19be94e174 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -72,6 +72,17 @@ class FilterDatasetTest(test.TestCase): # Test an empty dataset. do_test(0, 1) + def testFilterRange(self): + dataset = dataset_ops.Dataset.range(100).filter( + lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2)) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(0, sess.run(get_next)) + self.assertEqual(1, sess.run(get_next)) + self.assertEqual(3, sess.run(get_next)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/core/kernels/filter_dataset_op.cc b/tensorflow/core/kernels/filter_dataset_op.cc index 3503c45f9a..c0e909d73a 100644 --- a/tensorflow/core/kernels/filter_dataset_op.cc +++ b/tensorflow/core/kernels/filter_dataset_op.cc @@ -124,6 +124,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { "Filter predicate `f` must return a scalar bool."); } matched = result[0].scalar<bool>()(); + if (!matched) { + // Clear the output tensor list since it didn't match. + out_tensors->clear(); + } } while (!matched); *end_of_sequence = false; return Status::OK(); |