diff options
author | 2017-09-06 20:22:16 -0700 | |
---|---|---|
committer | 2017-09-06 20:26:27 -0700 | |
commit | b5fdf92df80f895c29258dcabfec7b41bf3604ae (patch) | |
tree | 50a7af20c0e17d68dd4d82275dc1e64d7f9bcc66 | |
parent | ed498a28c3720c01f96790c646e84764c54d9310 (diff) |
Add sloppy_interleave dataset operator.
When feeding data at high speed into a model from variable-latency data
sources, head-of-line blocking can be a significant concern when using a
deterministic input pipeline, such as interleave.
This change introduces a new non-deterministic dataset operator that avoids
head-of-line blocking.
PiperOrigin-RevId: 167810743
-rw-r--r-- | tensorflow/contrib/data/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/data/__init__.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/BUILD | 21 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py | 475 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/ops/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/ops/sloppy_ops.py | 120 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 32 | ||||
-rw-r--r-- | tensorflow/core/kernels/dataset_utils.cc | 78 | ||||
-rw-r--r-- | tensorflow/core/kernels/dataset_utils.h | 35 | ||||
-rw-r--r-- | tensorflow/core/kernels/flat_map_dataset_op.cc | 56 | ||||
-rw-r--r-- | tensorflow/core/kernels/interleave_dataset_op.cc | 62 | ||||
-rw-r--r-- | tensorflow/core/kernels/sloppy_interleave_dataset_op.cc | 370 | ||||
-rw-r--r-- | tensorflow/core/ops/dataset_ops.cc | 27 |
13 files changed, 1184 insertions, 111 deletions
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 7b916d82c1..c417650a96 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -10,6 +10,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:sloppy_ops", "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 1c0a5288f7..c74e1369d5 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -23,6 +23,8 @@ @@read_batch_features @@rejection_resample @@group_by_window +@@sloppy_interleave +@@sloppy_map """ from __future__ import absolute_import @@ -38,6 +40,7 @@ from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample from tensorflow.contrib.data.python.ops.dataset_ops import TextLineDataset from tensorflow.contrib.data.python.ops.dataset_ops import TFRecordDataset +from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index fb2740ffef..2f93c34502 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -147,6 +147,25 @@ py_test( ) py_test( + name = "sloppy_transformation_dataset_op_test", + size = "small", + srcs = ["sloppy_transformation_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:sloppy_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + 
"//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + +py_test( name = "list_files_dataset_op_test", size = "small", srcs = ["list_files_dataset_op_test.py"], @@ -228,7 +247,7 @@ py_test( srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data", + "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", diff --git a/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py new file mode 100644 index 0000000000..f9198bacfb --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py @@ -0,0 +1,475 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import math +import threading +import time + +from six.moves import zip_longest + +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import sloppy_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import script_ops +from tensorflow.python.platform import test + + +class SloppyInterleaveDatasetTest(test.TestCase): + + def setUp(self): + self.input_values = array_ops.placeholder(dtypes.int64, shape=[None]) + self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) + self.block_length = array_ops.placeholder(dtypes.int64, shape=[]) + + self.repeat_count = 2 + + # Set up threading events used to sequence when items are produced that + # are subsequently interleaved. These events allow us to deterministically + # simulate slowdowns and force sloppiness. 
+ self.read_coordination_events = {} + self.write_coordination_events = {} + # input values [4, 5, 6] are the common case for the tests; set defaults + for i in range(4, 7): + self.read_coordination_events[i] = threading.Semaphore(0) + self.write_coordination_events[i] = threading.Event() + + def map_py_fn(x): + self.write_coordination_events[x].wait() + self.write_coordination_events[x].clear() + self.read_coordination_events[x].release() + return x * x + + def map_fn(x): + return script_ops.py_func(map_py_fn, [x], x.dtype) + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + dataset = dataset.repeat(x) + return dataset.map(map_fn) + + self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values) + .repeat(self.repeat_count).apply( + sloppy_ops.sloppy_interleave, + args=(interleave_fn, self.cycle_length, + self.block_length))) + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + def _interleave(self, lists, cycle_length, block_length): + """Python implementation of interleave used for testing.""" + num_open = 0 + + # `all_iterators` acts as a queue of iterators over each element of `lists`. + all_iterators = [iter(l) for l in lists] + + # `open_iterators` are the iterators whose elements are currently being + # interleaved. 
+ open_iterators = [] + for i in range(cycle_length): + if all_iterators: + open_iterators.append(all_iterators.pop(0)) + num_open += 1 + else: + open_iterators.append(None) + + while num_open or all_iterators: + for i in range(cycle_length): + if open_iterators[i] is None: + if all_iterators: + open_iterators[i] = all_iterators.pop(0) + num_open += 1 + else: + continue + for _ in range(block_length): + try: + yield next(open_iterators[i]) + except StopIteration: + open_iterators[i] = None + num_open -= 1 + break + + def testPythonImplementation(self): + input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], + [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] + + # Cycle length 1 acts like `Dataset.flat_map()`. + expected_elements = itertools.chain(*input_lists) + for expected, produced in zip(expected_elements, + self._interleave(input_lists, 1, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1. + expected_elements = [ + 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, + 6, 5, 6, 5, 6, 6 + ] + for index, (expected, produced) in enumerate( + zip_longest(expected_elements, self._interleave(input_lists, 2, 1))): + self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % + (index, expected, produced)) + + def testPythonImplementationBlockLength(self): + input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2 + expected_elements = [ + 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5, + 5, 6, 6, 5, 6, 6 + ] + for index, (expected, produced) in enumerate( + zip_longest(expected_elements, self._interleave(input_lists, 2, 2))): + self.assertEqual(expected, produced, "Values differ at %s. 
%s != %s" % + (index, expected, produced)) + + def testPythonImplementationEmptyLists(self): + input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [], + [6, 6, 6, 6, 6, 6]] + + expected_elements = [ + 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6 + ] + for index, (expected, produced) in enumerate( + zip_longest(expected_elements, self._interleave(input_lists, 2, 1))): + self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % + (index, expected, produced)) + + def _clear_coordination_events(self): + for i in range(4, 7): + self.read_coordination_events[i] = threading.Semaphore(0) + self.write_coordination_events[i].clear() + + def _allow_all_map_threads(self): + for i in range(4, 7): + self.write_coordination_events[i].set() + + def testSingleThreaded(self): + # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and + # `Dataset.flat_map()` and is single-threaded. No synchronization required. + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 1, + self.block_length: 1 + }) + + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1): + self.write_coordination_events[expected_element].set() + self.assertEqual(expected_element * expected_element, + sess.run(self.next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContention(self): + # num_threads > 1. 
+ # Explicit coordination should result in `Dataset.interleave()` behavior + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads. + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + self.read_coordination_events[expected_element].acquire() + done_first_event = True + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionWithRaces(self): + """Tests where all the workers race in producing elements. + + Note: this is in contrast with the prevous test which carefully sequences + the execution of the map functions. + """ + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + if done_first_event: # First event starts the worker threads. + self._allow_all_map_threads() + self.read_coordination_events[expected_element].acquire() + else: + self.write_coordination_events[expected_element].set() + time.sleep(0.01) # Sleep to consistently "avoid" the race condition. 
+ actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.assertTrue( + self.read_coordination_events[expected_element].acquire(False)) + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionBlockLength(self): + # num_threads > 1. + # Explicit coordination should result in `Dataset.interleave()` behavior + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 2 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 2)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads. + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.read_coordination_events[expected_element].acquire() + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionWithRacesAndBlocking(self): + """Tests where all the workers race in producing elements. + + Note: this is in contrast with the prevous test which carefully sequences + the execution of the map functions. 
+ """ + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 2 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 2)): + if done_first_event: # First event starts the worker threads. + self._allow_all_map_threads() + self.read_coordination_events[expected_element].acquire() + else: + self.write_coordination_events[expected_element].set() + time.sleep(0.01) # Sleep to consistently "avoid" the race condition. + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.assertTrue( + self.read_coordination_events[expected_element].acquire(False)) + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testEmptyInput(self): + with self.test_session() as sess: + # Empty input. + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [], + self.cycle_length: 2, + self.block_length: 3 + }) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testNonEmptyInputIntoEmptyOutputs(self): + # Non-empty input leading to empty output. + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [0, 0, 0], + self.cycle_length: 2, + self.block_length: 3 + }) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testPartiallyEmptyOutputs(self): + # Mixture of non-empty and empty interleaved datasets. 
+ with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 0, 6], + self.cycle_length: 2, + self.block_length: 1 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.read_coordination_events[expected_element].acquire() + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testDelayedOutput(self): + # Explicitly control the sequence of events to ensure we correctly avoid + # head-of-line blocking. 
+ with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1 + }) + + mis_ordering = [ + 4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6, + 6, 5, 5, 5, 5, 6, 6 + ] + for element in mis_ordering: + self.write_coordination_events[element].set() + self.assertEqual(element * element, sess.run(self.next_element)) + self.assertTrue(self.read_coordination_events[element].acquire(False)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testBlockLengthWithContention(self): + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 3 + }) + # Test against a generating sequence that differs from the uncontended + # case, in order to prove sloppy correctness. + for i, expected_element in enumerate( + self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, + cycle_length=2, + block_length=2)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads. 
+ self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + self.read_coordination_events[expected_element].acquire() + done_first_event = True + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testEarlyExit(self): + # Exiting without consuming all input should not block + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 3, + self.block_length: 2 + }) + for i in range(4, 7): + self.write_coordination_events[i].set() + elem = sess.run(self.next_element) # Start all workers + # Allow the one successful worker to progress beyond the py_func again. + elem = int(math.sqrt(elem)) + self.write_coordination_events[elem].set() + self.read_coordination_events[elem].acquire() + # Allow the prefetch to succeed + for i in range(4, 7): + self.read_coordination_events[i].acquire() + self.write_coordination_events[i].set() + + def testTooManyReaders(self): + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64)) + return dataset + + dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6]) + dataset = dataset.repeat(self.repeat_count) + dataset = dataset.apply( + sloppy_ops.sloppy_interleave, + args=(interleave_fn,), + kwargs={"cycle_length": 16, + "block_length": 2}) + iterator = dataset.make_one_shot_iterator() + + with self.test_session() as sess: + output_values = [] + for _ in range(30): + output_values.append(sess.run(iterator.get_next())) + + expected_values = self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2) + self.assertItemsEqual(output_values, expected_values) + + +if __name__ == 
"__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 8afd122d82..94969c1c70 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -32,6 +32,21 @@ py_library( ], ) +py_library( + name = "sloppy_ops", + srcs = ["sloppy_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + "//tensorflow/contrib/data/python/framework:function", + "//tensorflow/contrib/data/python/util:nest", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/data/python/ops/sloppy_ops.py b/tensorflow/contrib/data/python/ops/sloppy_ops.py new file mode 100644 index 0000000000..010bd31161 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/sloppy_ops.py @@ -0,0 +1,120 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Non-deterministic dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.framework import function +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +class SloppyInterleaveDataset(dataset_ops.Dataset): + """A `Dataset` that maps a function over its input and flattens the result.""" + + def __init__(self, input_dataset, map_func, cycle_length, block_length): + """See `tf.contrib.data.sloppy_interleave()` for details.""" + super(SloppyInterleaveDataset, self).__init__() + self._input_dataset = input_dataset + + @function.Defun(*nest.flatten(input_dataset.output_types)) + def tf_map_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the input_dataset. 
+ for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + + if nest.is_sequence(nested_args): + dataset = map_func(*nested_args) + else: + dataset = map_func(nested_args) + + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`map_func` must return a `Dataset` object.") + + self._output_types = dataset.output_types + self._output_shapes = dataset.output_shapes + + return dataset.make_dataset_resource() + + self._map_func = tf_map_func + self._map_func.add_to_graph(ops.get_default_graph()) + + self._cycle_length = ops.convert_to_tensor( + cycle_length, dtype=dtypes.int64, name="cycle_length") + self._block_length = ops.convert_to_tensor( + block_length, dtype=dtypes.int64, name="block_length") + + def make_dataset_resource(self): + return gen_dataset_ops.sloppy_interleave_dataset( + self._input_dataset.make_dataset_resource(), + self._map_func.captured_inputs, + self._cycle_length, + self._block_length, + f=self._map_func, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +def sloppy_interleave(dataset, map_func, cycle_length, block_length): + """Maps `map_func` across `dataset`, and interleaves the results. + + The resulting dataset is almost identical to `interleave`. The key + difference being that if retrieving a value from a given output iterator would + cause `get_next` to block, that iterator will be skipped, and consumed + when next available. If consuming from all iterators would cause the + `get_next` call to block, the `get_next` call blocks until the first value is + available. + + If the underlying datasets produce elements as fast as they are consumed, the + `sloppy_interleave` dataset behaves identically to the `interleave` dataset. 
+ However, if an underlying dataset would block the consumer, the + `sloppy_interleave` dataset can violate to the round-robin order (respected by + the `interleave` dataset), producing an element from a different underlying + dataset instead. + + WARNING: The order of elements in the resulting dataset is not + deterministic. Use `Dataset.interleave()` if you want the elements to have a + deterministic order. + + Args: + dataset: A `Dataset` that produces elements to feed to `map_func`. + map_func: A function mapping a nested structure of tensors (having shapes + and types defined by `self.output_shapes` and `self.output_types`) to a + `Dataset`. + cycle_length: The number of threads to interleave from in parallel. + block_length: The number of consecutive elements to pull from a thread + before advancing to the next thread. Note: sloppy_interleave will + skip the remainder of elements in the block_length in order to avoid + blocking. + + Returns: + A `Dataset`. + """ + return SloppyInterleaveDataset(dataset, map_func, cycle_length, block_length) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 8dd8900f28..32a1b2c84d 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5588,6 +5588,20 @@ cc_library( ) cc_library( + name = "dataset_utils", + srcs = ["dataset_utils.cc"], + hdrs = ["dataset_utils.h"], + deps = [ + ":captured_function", + ":dataset", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/util/tensor_bundle", + ], +) + +cc_library( name = "captured_function", srcs = ["captured_function.cc"], hdrs = ["captured_function.h"], @@ -5713,6 +5727,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -5727,6 +5742,22 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + 
":dataset_utils", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_kernel_library( + name = "sloppy_interleave_dataset_op", + srcs = ["sloppy_interleave_dataset_op.cc"], + deps = [ + ":captured_function", + ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -5963,6 +5994,7 @@ tf_kernel_library( ":repeat_dataset_op", ":shuffle_dataset_op", ":skip_dataset_op", + ":sloppy_interleave_dataset_op", ":sparse_tensor_slice_dataset_op", ":sql_dataset_ops", ":take_dataset_op", diff --git a/tensorflow/core/kernels/dataset_utils.cc b/tensorflow/core/kernels/dataset_utils.cc new file mode 100644 index 0000000000..f320b3b09c --- /dev/null +++ b/tensorflow/core/kernels/dataset_utils.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/dataset_utils.h" + +namespace tensorflow { + +namespace dataset { + +Status MakeIteratorFromInputElement( + IteratorContext* ctx, const std::vector<Tensor>& input_element, + int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, + std::unique_ptr<IteratorBase>* out_iterator) { + FunctionLibraryRuntime::Options opts; + opts.runner = ctx->runner(); + // Choose a step ID that is guaranteed not to clash with any + // Session-generated step ID. DirectSession only generates + // non-negative step IDs (contiguous, starting from 0), and + // MasterSession generates 56-bit random step IDs whose MSB + // is always 0, so a negative random step ID should suffice. + opts.step_id = CapturedFunction::generate_step_id(); + ScopedStepContainer step_container( + opts.step_id, [captured_func, ctx](const string& name) { + captured_func->resource_manager()->Cleanup(name).IgnoreError(); + }); + opts.step_container = &step_container; + std::vector<Tensor> return_values; + TF_RETURN_IF_ERROR(captured_func->Run(opts, input_element, &return_values)); + + if (!(return_values.size() == 1 && return_values[0].dtype() == DT_RESOURCE && + TensorShapeUtils::IsScalar(return_values[0].shape()))) { + return errors::InvalidArgument( + "Function must return a single scalar of dtype DT_RESOURCE."); + } + + // Retrieve the dataset that was created in `f`. + DatasetBase* returned_dataset; + const ResourceHandle& dataset_resource = + return_values[0].scalar<ResourceHandle>()(); + + // NOTE(mrry): We cannot use the core `LookupResource()` or + // `DeleteResource()` functions, because we have an + // `IteratorContext*` and not an `OpKernelContext*`, so we + // replicate the necessary functionality here. 
+ auto type_index = MakeTypeIndex<DatasetBase>(); + if (type_index.hash_code() != dataset_resource.hash_code()) { + return errors::InvalidArgument("Function must return a Dataset resource."); + } + TF_RETURN_IF_ERROR(captured_func->resource_manager()->Lookup( + dataset_resource.container(), dataset_resource.name(), + &returned_dataset)); + core::ScopedUnref unref_dataset(returned_dataset); + + // Create an iterator for the dataset that was returned by + // `f`. This transfers ownership of the dataset to the + // iterator, so we can delete it from the resource manager. + *out_iterator = returned_dataset->MakeIterator( + strings::StrCat(prefix, "[", thread_index, "]")); + TF_RETURN_IF_ERROR(captured_func->resource_manager()->Delete<DatasetBase>( + dataset_resource.container(), dataset_resource.name())); + return Status::OK(); +} + +} // namespace dataset + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dataset_utils.h b/tensorflow/core/kernels/dataset_utils.h new file mode 100644 index 0000000000..eea2b8802b --- /dev/null +++ b/tensorflow/core/kernels/dataset_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/captured_function.h" +#include "tensorflow/core/kernels/dataset.h" + +namespace tensorflow { + +namespace dataset { + +Status MakeIteratorFromInputElement( + IteratorContext* ctx, const std::vector<Tensor>& input_element, + int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, + std::unique_ptr<IteratorBase>* out_iterator); + +} // namespace dataset + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_ diff --git a/tensorflow/core/kernels/flat_map_dataset_op.cc b/tensorflow/core/kernels/flat_map_dataset_op.cc index e2310fecc7..a87e54bf31 100644 --- a/tensorflow/core/kernels/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/flat_map_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/kernels/captured_function.h" +#include "tensorflow/core/kernels/dataset_utils.h" namespace tensorflow { @@ -125,58 +126,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - FunctionLibraryRuntime::Options opts; - opts.runner = ctx->runner(); - opts.step_id = CapturedFunction::generate_step_id(); - ScopedStepContainer step_container( - opts.step_id, [this, ctx](const string& name) { - dataset() - ->captured_func_->resource_manager() - ->Cleanup(name) - .IgnoreError(); - }); - opts.step_container = &step_container; - std::vector<Tensor> return_values; - TF_RETURN_IF_ERROR( - dataset()->captured_func_->Run(opts, args, &return_values)); - - if (!(return_values.size() == 1 && - return_values[0].dtype() == DT_RESOURCE && - TensorShapeUtils::IsScalar(return_values[0].shape()))) { - return errors::InvalidArgument( - "`f` must return a single scalar of dtype DT_RESOURCE."); - } - - // Retrieve the dataset that was created in `f`. - DatasetBase* returned_dataset; - const ResourceHandle& dataset_resource = - return_values[0].scalar<ResourceHandle>()(); - - // NOTE(mrry): We cannot use the core `LookupResource()` or - // `DeleteResource()` functions, because we have an - // `IteratorContext*` and not an `OpKernelContext*`, so we - // replicate the necessary functionality here. - auto type_index = MakeTypeIndex<DatasetBase>(); - if (type_index.hash_code() != dataset_resource.hash_code()) { - return errors::InvalidArgument( - "`f` must return a Dataset resource."); - } - TF_RETURN_IF_ERROR( - dataset()->captured_func_->resource_manager()->Lookup( - dataset_resource.container(), dataset_resource.name(), - &returned_dataset)); - core::ScopedUnref unref_dataset(returned_dataset); - - // Create an iterator for the dataset that was returned by - // `f`. This transfers ownership of the dataset to the - // iterator, so we can delete it from the resource manager. 
- current_element_iterator_ = returned_dataset->MakeIterator( - strings::StrCat(prefix(), "[", element_index_++, "]")); - TF_RETURN_IF_ERROR( - dataset() - ->captured_func_->resource_manager() - ->Delete<DatasetBase>(dataset_resource.container(), - dataset_resource.name())); + TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + ctx, args, element_index_++, dataset()->captured_func_.get(), + prefix(), &current_element_iterator_)); } while (true); } diff --git a/tensorflow/core/kernels/interleave_dataset_op.cc b/tensorflow/core/kernels/interleave_dataset_op.cc index dce4f88101..7b148b74c9 100644 --- a/tensorflow/core/kernels/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/interleave_dataset_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/kernels/captured_function.h" +#include "tensorflow/core/kernels/dataset_utils.h" namespace tensorflow { @@ -168,8 +169,9 @@ class InterleaveDatasetOp : public OpKernel { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, &args, &end_of_input_)); if (!end_of_input_) { - TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( - ctx, args, &current_elements_[cycle_index_])); + TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + ctx, args, cycle_index_, dataset()->captured_func_.get(), + prefix(), &current_elements_[cycle_index_])); ++num_open_; } } else { @@ -182,62 +184,6 @@ class InterleaveDatasetOp : public OpKernel { } private: - Status MakeIteratorFromInputElement( - IteratorContext* ctx, const std::vector<Tensor>& input_element, - std::unique_ptr<IteratorBase>* out_iterator) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - FunctionLibraryRuntime::Options opts; - opts.runner = ctx->runner(); - opts.step_id = CapturedFunction::generate_step_id(); - ScopedStepContainer step_container( - opts.step_id, [this, ctx](const string& name) { - dataset() - ->captured_func_->resource_manager() - ->Cleanup(name) - .IgnoreError(); - }); - opts.step_container = &step_container; - 
std::vector<Tensor> return_values; - TF_RETURN_IF_ERROR(dataset()->captured_func_->Run(opts, input_element, - &return_values)); - - if (!(return_values.size() == 1 && - return_values[0].dtype() == DT_RESOURCE && - TensorShapeUtils::IsScalar(return_values[0].shape()))) { - return errors::InvalidArgument( - "`f` must return a single scalar of dtype DT_RESOURCE."); - } - - // Retrieve the dataset that was created in `f`. - DatasetBase* returned_dataset; - const ResourceHandle& dataset_resource = - return_values[0].scalar<ResourceHandle>()(); - - // NOTE(mrry): We cannot use the core `LookupResource()` or - // `DeleteResource()` functions, because we have an - // `IteratorContext*` and not an `OpKernelContext*`, so we - // replicate the necessary functionality here. - auto type_index = MakeTypeIndex<DatasetBase>(); - if (type_index.hash_code() != dataset_resource.hash_code()) { - return errors::InvalidArgument("`f` must return a Dataset resource."); - } - TF_RETURN_IF_ERROR( - dataset()->captured_func_->resource_manager()->Lookup( - dataset_resource.container(), dataset_resource.name(), - &returned_dataset)); - core::ScopedUnref unref_dataset(returned_dataset); - - // Create an iterator for the dataset that was returned by - // `f`. This transfers ownership of the dataset to the - // iterator, so we can delete it from the resource manager. 
- *out_iterator = returned_dataset->MakeIterator( - strings::StrCat(prefix(), "[", cycle_index_, "]")); - TF_RETURN_IF_ERROR( - dataset()->captured_func_->resource_manager()->Delete<DatasetBase>( - dataset_resource.container(), dataset_resource.name())); - return Status::OK(); - } - mutex mu_; const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); std::vector<std::unique_ptr<IteratorBase>> current_elements_ diff --git a/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc b/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc new file mode 100644 index 0000000000..d95f51f0f2 --- /dev/null +++ b/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc @@ -0,0 +1,370 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/dataset.h" + +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/dataset_utils.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/random/random.h" + +#include "tensorflow/core/kernels/captured_function.h" + +namespace tensorflow { + +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. 
+ +class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel { + public: + explicit SloppyInterleaveDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); + std::vector<Tensor> other_arguments; + other_arguments.reserve(inputs.size()); + for (const Tensor& t : inputs) { + other_arguments.push_back(t); + } + + int64 cycle_length; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "cycle_length", &cycle_length)); + OP_REQUIRES(ctx, cycle_length > 0, + errors::InvalidArgument("`cycle_length` must be > 0")); + + int64 block_length; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_length", &block_length)); + OP_REQUIRES(ctx, block_length > 0, + errors::InvalidArgument("`block_length` must be > 0")); + + std::unique_ptr<CapturedFunction> captured_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_, + std::move(other_arguments), + &captured_func)); + + *output = new Dataset(input, std::move(captured_func), cycle_length, + block_length, output_types_, output_shapes_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(const DatasetBase* input, + std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length, + int64 block_length, const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes) + : input_(input), + captured_func_(std::move(captured_func)), + cycle_length_(cycle_length), + block_length_(block_length), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { 
input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIterator( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::SloppyInterleave")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() override { + return "SloppyInterleaveDatasetOp::Dataset"; + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + input_impl_(params.dataset->input_->MakeIterator(params.prefix)), + output_elements_(params.dataset->cycle_length_) {} + + ~Iterator() override { + mutex_lock l(mu_); + cancelled_ = true; + // Notify all workers in case they are blocked. + for (int64 i = 0; i < dataset()->cycle_length_; ++i) { + output_elements_[i].cond_var.notify_all(); + } + } + + // It is implemented so that it matches the deterministic interleave + // unless we would block waiting for an element, at which point it skips + // along to the next available value. + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx)); + // Search for available items, blocking if necessary. 
+ while (!cancelled_) { + for (size_t i = 0; i < dataset()->cycle_length_; ++i) { + size_t index = (next_index_ + i) % dataset()->cycle_length_; + if (output_elements_[index].is_produced) { + next_index_ = index; + if (i == 0) { + block_count_++; + if (block_count_ == dataset()->block_length_) { + next_index_ = (index + 1) % dataset()->cycle_length_; + block_count_ = 0; + } + } else { + block_count_ = 0; + } + // If we encounter an EoF, advance to the next iterator + if (output_elements_[index].end_of_sequence) { + output_elements_[index].is_produced = false; + output_elements_[index].cond_var.notify_one(); + next_index_ = (index + 1) % dataset()->cycle_length_; + block_count_ = 0; + i = -1; // Restart the inner loop + continue; + } + *end_of_sequence = false; + if (output_elements_[index].output_status.ok()) { + output_elements_[index].output_value.swap(*out_tensors); + } + output_elements_[index].is_produced = false; + output_elements_[index].cond_var.notify_one(); + return output_elements_[index].output_status; + } + } + + if (num_active_threads_ == 0) { + // No potential for future values. + // + // Note: this condition check must occur after checking the output + // buffer, as its possible for there to be values in the output + // buffer, even if the number of live threads is zero. + *end_of_sequence = true; + return Status::OK(); + } + // No values available; wait until woken up. + cond_var_.wait(l); + } + return errors::Cancelled( + "SloppyInterleaveDatasetOp::Dataset::Iterator::GetNext"); + } + + private: + // Internal structure to manage thread coordination. All values are + // guarded by the enclosing Iterator's mu_. + struct OutputBufferElement { + // The producer must set `is_produced` to `true` after + // `output_status` or `output_value` has been written. + bool is_produced = false; + // The producer sets `output_status` if either getting the input element + // or applying the function to it fails. 
+ Status output_status; + // Reached end of sequence for the underlying iterator. + bool end_of_sequence = false; + // The output data element. + std::vector<Tensor> output_value; + // The producer thread waits on this condition variable after having + // produced an element. The reader thread notifies this condition + // variable after reading the value. + condition_variable cond_var; + }; + + Status EnsureWorkerThreadsStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (worker_threads_.empty()) { + for (int64 i = 0; i < dataset()->cycle_length_; ++i) { + // Serialize the creation of the workers and their corresponding + // input elements to ensure we match the standard interleave when + // the underlying iterators induce no delay. + std::vector<Tensor> args; + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &args, &end_of_input_)); + if (end_of_input_) { + LOG(WARNING) << "Input iterator exhausted after " << i + << " elements; cannot start all " + << dataset()->cycle_length_ << " worker threads."; + return Status::OK(); + } + std::unique_ptr<IteratorBase> itr; + TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr)); + worker_threads_.emplace_back( + std::unique_ptr<Thread>(ctx->env()->StartThread( + {}, "worker_thread", + std::bind(&Iterator::WorkerThread, this, + new IteratorContext(*ctx), i, itr.release())))); + num_active_threads_ = i + 1; + } + } + return Status::OK(); + } + + void BlockAndUpdateOutputBuffer(mutex_lock* l, const int64 thread_index, + const Status& status, + bool end_of_sequence, + std::vector<Tensor>* out_tensors) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // We have produced an element; push it into the output buffer + // when space is available. 
+ while (!cancelled_ && output_elements_[thread_index].is_produced) { + output_elements_[thread_index].cond_var.wait(*l); + } + if (cancelled_) { + return; + } + output_elements_[thread_index].is_produced = true; + output_elements_[thread_index].output_status = status; + output_elements_[thread_index].end_of_sequence = end_of_sequence; + if (status.ok()) { + output_elements_[thread_index].output_value.swap(*out_tensors); + } else { + output_elements_[thread_index].output_value.clear(); + } + cond_var_.notify_one(); + } + + // Races to produce elements into the output queue buffers. + void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index, + IteratorBase* out_iterator_ptr) { + // std::function arguments are copy-constructable, so we pass raw + // pointers, and then immediately wrap them to ensure correct ownership. + std::unique_ptr<IteratorContext> ctx(ctx_ptr); + std::unique_ptr<IteratorBase> out_iterator(out_iterator_ptr); + auto cleanup = gtl::MakeCleanup([this, thread_index] { + mutex_lock l(mu_); + num_active_threads_--; + cond_var_.notify_all(); + }); + while (true) { + // Attempt to produce an element. + bool end_of_out_itr_input = false; + std::vector<Tensor> out_tensors; + Status element_status = out_iterator->GetNext(ctx.get(), &out_tensors, + &end_of_out_itr_input); + // Handle output. + { + mutex_lock l(mu_); + BlockAndUpdateOutputBuffer(&l, thread_index, element_status, + end_of_out_itr_input, &out_tensors); + if (end_of_out_itr_input) { + // We have exhausted our current iterator; get a new iterator; + // loop to handle errors. + while (!cancelled_) { + if (end_of_input_) { + // No more iterator inputs; we're done! + return; + } + std::vector<Tensor> args; + // BlockAndUpdateOutputBuffer() sequences calls to + // input_impl_->GetNext when the out_iterator doesn't cause + // slopping. 
+ Status input_status = + input_impl_->GetNext(ctx.get(), &args, &end_of_input_); + if (end_of_input_) { + // No more elements to produce, stop the worker thread. + return; + } + if (input_status.ok()) { + input_status = dataset::MakeIteratorFromInputElement( + ctx.get(), args, thread_index, + dataset()->captured_func_.get(), prefix(), &out_iterator); + } + if (input_status.ok()) { + // Successfully have a new out_iterator; restart the outer + // loop to produce an element. + break; + } + + // We encountered an error; push the error to the output buffer. + BlockAndUpdateOutputBuffer(&l, thread_index, input_status, + /* end_of_sequence = */ false, + &out_tensors); + } + } + + // Check if we should exit. + if (cancelled_) { + return; + } + } + } + } + + // Mutex & condition variable to guard mutable iterator internals and + // coordinate among worker threads and client thread[s]. + mutex mu_; + condition_variable cond_var_; + // The iterator producing elements which are converted to datasets by + // the dataset()->captured_func_ then interleaved together. + const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + // Whether the input_impl_ can produce future elements. + bool end_of_input_ GUARDED_BY(mu_) = false; + // The buffer of elements to be produced. Each worker thread operates + // on a single OutputBufferElement. + std::vector<OutputBufferElement> output_elements_ GUARDED_BY(mu_); + // The index into output_elements_ for next element to produce. + size_t next_index_ GUARDED_BY(mu_) = 0; + // The number of items produced so far within the block + size_t block_count_ GUARDED_BY(mu_) = 0; + // Number of active threads. + size_t num_active_threads_ GUARDED_BY(mu_) = 0; + // Flag to instruct the worker threads to exit. + bool cancelled_ GUARDED_BY(mu_) = false; + // Pointers to the worker threads. This must be last to ensure the + // threads have exited before any other members are deallocated. + // TODO(b/65178177): Avoid allocating additional threads. 
+ std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + const std::unique_ptr<CapturedFunction> captured_func_; + const int64 cycle_length_; + const int64 block_length_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + }; + + const int graph_def_version_; + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + const NameAttrList* func_; +}; + +REGISTER_KERNEL_BUILDER(Name("SloppyInterleaveDataset").Device(DEVICE_CPU), + SloppyInterleaveDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 37d9a737e2..7cc8dccb95 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -233,6 +233,33 @@ f: A function mapping elements of `input_dataset`, concatenated with `output_types` and `output_shapes`. )doc"); +REGISTER_OP("SloppyInterleaveDataset") + .Input("input_dataset: resource") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Output("handle: resource") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset`. + +The resulting dataset is similar to the `InterleaveDataset`, with the exception +that if retrieving the next value from a dataset would cause the requester to +block, it will skip that input dataset. This dataset is especially useful +when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it +allows the training step to proceed so long as some data is available. + +!! WARNING !! This dataset is not deterministic! 
+ +f: A function mapping elements of `input_dataset`, concatenated with + `other_arguments`, to a Dataset resource that contains elements matching + `output_types` and `output_shapes`. +)doc"); + REGISTER_OP("GroupByWindowDataset") .Input("input_dataset: resource") .Input("key_func_other_arguments: Tkey_func_other_arguments") |