aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-06-16 10:15:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-16 10:18:33 -0700
commit1111e06d9cd691cbdfcb67cf9f234a504f4e0f6d (patch)
treee3be723bdb4abf8c90eb60f69b735b8a7c2b4231
parenta6000da882aef2d8631c6c22bcfaa0db01d5336f (diff)
In filter_dataset, if the element doesn't match, clear the output.
Prior to this change, the output tensors list wouldn't be cleared if the element didn't match, so output_tensors would be appended to on the next iteration, violating the number of inputs contract for the filter function. The new test exhibited this problem. By clearing the output tensors list when the element doesn't match, we reset the state back to empty, maintaining the invariant at the start of the loop. Fixes #10725. PiperOrigin-RevId: 159242269
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py11
-rw-r--r--tensorflow/core/kernels/filter_dataset_op.cc4
2 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
index 3ea783ad89..19be94e174 100644
--- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
@@ -72,6 +72,17 @@ class FilterDatasetTest(test.TestCase):
# Test an empty dataset.
do_test(0, 1)
+ def testFilterRange(self):
+ dataset = dataset_ops.Dataset.range(100).filter(
+ lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertEqual(0, sess.run(get_next))
+ self.assertEqual(1, sess.run(get_next))
+ self.assertEqual(3, sess.run(get_next))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/core/kernels/filter_dataset_op.cc b/tensorflow/core/kernels/filter_dataset_op.cc
index 3503c45f9a..c0e909d73a 100644
--- a/tensorflow/core/kernels/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/filter_dataset_op.cc
@@ -124,6 +124,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
"Filter predicate `f` must return a scalar bool.");
}
matched = result[0].scalar<bool>()();
+ if (!matched) {
+ // Clear the output tensor list since it didn't match.
+ out_tensors->clear();
+ }
} while (!matched);
*end_of_sequence = false;
return Status::OK();