aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2017-09-06 20:22:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-06 20:26:27 -0700
commitb5fdf92df80f895c29258dcabfec7b41bf3604ae (patch)
tree50a7af20c0e17d68dd4d82275dc1e64d7f9bcc66
parented498a28c3720c01f96790c646e84764c54d9310 (diff)
Add sloppy_interleave dataset operator.
When feeding data at high speed into a model from variable-latency data sources, head-of-line blocking can be a significant concern when using a deterministic input pipeline, such as interleave. This change introduces a new non-deterministic dataset operator that avoids head-of-line blocking. PiperOrigin-RevId: 167810743
-rw-r--r--tensorflow/contrib/data/BUILD1
-rw-r--r--tensorflow/contrib/data/__init__.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD21
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py475
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD15
-rw-r--r--tensorflow/contrib/data/python/ops/sloppy_ops.py120
-rw-r--r--tensorflow/core/kernels/BUILD32
-rw-r--r--tensorflow/core/kernels/dataset_utils.cc78
-rw-r--r--tensorflow/core/kernels/dataset_utils.h35
-rw-r--r--tensorflow/core/kernels/flat_map_dataset_op.cc56
-rw-r--r--tensorflow/core/kernels/interleave_dataset_op.cc62
-rw-r--r--tensorflow/core/kernels/sloppy_interleave_dataset_op.cc370
-rw-r--r--tensorflow/core/ops/dataset_ops.cc27
13 files changed, 1184 insertions, 111 deletions
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 7b916d82c1..c417650a96 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -10,6 +10,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
+ "//tensorflow/contrib/data/python/ops:sloppy_ops",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 1c0a5288f7..c74e1369d5 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -23,6 +23,8 @@
@@read_batch_features
@@rejection_resample
@@group_by_window
+@@sloppy_interleave
+@@sloppy_map
"""
from __future__ import absolute_import
@@ -38,6 +40,7 @@ from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features
from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample
from tensorflow.contrib.data.python.ops.dataset_ops import TextLineDataset
from tensorflow.contrib.data.python.ops.dataset_ops import TFRecordDataset
+from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index fb2740ffef..2f93c34502 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -147,6 +147,25 @@ py_test(
)
py_test(
+ name = "sloppy_transformation_dataset_op_test",
+ size = "small",
+ srcs = ["sloppy_transformation_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:dataset_ops",
+ "//tensorflow/contrib/data/python/ops:sloppy_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:training",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "list_files_dataset_op_test",
size = "small",
srcs = ["list_files_dataset_op_test.py"],
@@ -228,7 +247,7 @@ py_test(
srcs = ["sql_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data",
+ "//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
diff --git a/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py
new file mode 100644
index 0000000000..f9198bacfb
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py
@@ -0,0 +1,475 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import math
+import threading
+import time
+
+from six.moves import zip_longest
+
+from tensorflow.contrib.data.python.ops import dataset_ops
+from tensorflow.contrib.data.python.ops import sloppy_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class SloppyInterleaveDatasetTest(test.TestCase):
+
+ def setUp(self):
+ self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
+ self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
+ self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
+
+ self.repeat_count = 2
+
+ # Set up threading events used to sequence when items are produced that
+ # are subsequently interleaved. These events allow us to deterministically
+ # simulate slowdowns and force sloppiness.
+ self.read_coordination_events = {}
+ self.write_coordination_events = {}
+ # input values [4, 5, 6] are the common case for the tests; set defaults
+ for i in range(4, 7):
+ self.read_coordination_events[i] = threading.Semaphore(0)
+ self.write_coordination_events[i] = threading.Event()
+
+ def map_py_fn(x):
+ self.write_coordination_events[x].wait()
+ self.write_coordination_events[x].clear()
+ self.read_coordination_events[x].release()
+ return x * x
+
+ def map_fn(x):
+ return script_ops.py_func(map_py_fn, [x], x.dtype)
+
+ def interleave_fn(x):
+ dataset = dataset_ops.Dataset.from_tensors(x)
+ dataset = dataset.repeat(x)
+ return dataset.map(map_fn)
+
+ self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values)
+ .repeat(self.repeat_count).apply(
+ sloppy_ops.sloppy_interleave,
+ args=(interleave_fn, self.cycle_length,
+ self.block_length)))
+ self.iterator = self.dataset.make_initializable_iterator()
+ self.init_op = self.iterator.initializer
+ self.next_element = self.iterator.get_next()
+
+ def _interleave(self, lists, cycle_length, block_length):
+ """Python implementation of interleave used for testing."""
+ num_open = 0
+
+ # `all_iterators` acts as a queue of iterators over each element of `lists`.
+ all_iterators = [iter(l) for l in lists]
+
+ # `open_iterators` are the iterators whose elements are currently being
+ # interleaved.
+ open_iterators = []
+ for i in range(cycle_length):
+ if all_iterators:
+ open_iterators.append(all_iterators.pop(0))
+ num_open += 1
+ else:
+ open_iterators.append(None)
+
+ while num_open or all_iterators:
+ for i in range(cycle_length):
+ if open_iterators[i] is None:
+ if all_iterators:
+ open_iterators[i] = all_iterators.pop(0)
+ num_open += 1
+ else:
+ continue
+ for _ in range(block_length):
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ open_iterators[i] = None
+ num_open -= 1
+ break
+
+ def testPythonImplementation(self):
+ input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
+ [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
+
+ # Cycle length 1 acts like `Dataset.flat_map()`.
+ expected_elements = itertools.chain(*input_lists)
+ for expected, produced in zip(expected_elements,
+ self._interleave(input_lists, 1, 1)):
+ self.assertEqual(expected, produced)
+
+ # Cycle length > 1.
+ expected_elements = [
+ 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5,
+ 6, 5, 6, 5, 6, 6
+ ]
+ for index, (expected, produced) in enumerate(
+ zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
+ self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+ (index, expected, produced))
+
+ def testPythonImplementationBlockLength(self):
+ input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
+ expected_elements = [
+ 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5,
+ 5, 6, 6, 5, 6, 6
+ ]
+ for index, (expected, produced) in enumerate(
+ zip_longest(expected_elements, self._interleave(input_lists, 2, 2))):
+ self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+ (index, expected, produced))
+
+ def testPythonImplementationEmptyLists(self):
+ input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
+ [6, 6, 6, 6, 6, 6]]
+
+ expected_elements = [
+ 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6
+ ]
+ for index, (expected, produced) in enumerate(
+ zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
+ self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+ (index, expected, produced))
+
+ def _clear_coordination_events(self):
+ for i in range(4, 7):
+ self.read_coordination_events[i] = threading.Semaphore(0)
+ self.write_coordination_events[i].clear()
+
+ def _allow_all_map_threads(self):
+ for i in range(4, 7):
+ self.write_coordination_events[i].set()
+
+ def testSingleThreaded(self):
+ # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
+ # `Dataset.flat_map()` and is single-threaded. No synchronization required.
+ with self.test_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 1,
+ self.block_length: 1
+ })
+
+ for expected_element in self._interleave(
+ [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1):
+ self.write_coordination_events[expected_element].set()
+ self.assertEqual(expected_element * expected_element,
+ sess.run(self.next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContention(self):
+ # num_threads > 1.
+ # Explicit coordination should result in `Dataset.interleave()` behavior
+ with self.test_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 1)):
+ self.write_coordination_events[expected_element].set()
+ if done_first_event: # First event starts the worker threads.
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ self.read_coordination_events[expected_element].acquire()
+ done_first_event = True
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContentionWithRaces(self):
+ """Tests where all the workers race in producing elements.
+
+ Note: this is in contrast with the prevous test which carefully sequences
+ the execution of the map functions.
+ """
+ with self.test_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 1)):
+ if done_first_event: # First event starts the worker threads.
+ self._allow_all_map_threads()
+ self.read_coordination_events[expected_element].acquire()
+ else:
+ self.write_coordination_events[expected_element].set()
+ time.sleep(0.01) # Sleep to consistently "avoid" the race condition.
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ done_first_event = True
+ self.assertTrue(
+ self.read_coordination_events[expected_element].acquire(False))
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContentionBlockLength(self):
+ # num_threads > 1.
+ # Explicit coordination should result in `Dataset.interleave()` behavior
+ with self.test_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 2
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 2)):
+ self.write_coordination_events[expected_element].set()
+ if done_first_event: # First event starts the worker threads.
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ done_first_event = True
+ self.read_coordination_events[expected_element].acquire()
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContentionWithRacesAndBlocking(self):
+ """Tests where all the workers race in producing elements.
+
+ Note: this is in contrast with the prevous test which carefully sequences
+ the execution of the map functions.
+ """
+ with self.test_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 2
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 2)):
+ if done_first_event: # First event starts the worker threads.
+ self._allow_all_map_threads()
+ self.read_coordination_events[expected_element].acquire()
+ else:
+ self.write_coordination_events[expected_element].set()
+ time.sleep(0.01) # Sleep to consistently "avoid" the race condition.
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ done_first_event = True
+ self.assertTrue(
+ self.read_coordination_events[expected_element].acquire(False))
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testEmptyInput(self):
+ with self.test_session() as sess:
+ # Empty input.
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [],
+ self.cycle_length: 2,
+ self.block_length: 3
+ })
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testNonEmptyInputIntoEmptyOutputs(self):
+ # Non-empty input leading to empty output.
+ with self.test_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [0, 0, 0],
+ self.cycle_length: 2,
+ self.block_length: 3
+ })
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testPartiallyEmptyOutputs(self):
+ # Mixture of non-empty and empty interleaved datasets.
+ with self.test_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 0, 6],
+ self.cycle_length: 2,
+ self.block_length: 1
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
+ self.write_coordination_events[expected_element].set()
+ if done_first_event: # First event starts the worker threads
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ done_first_event = True
+ self.read_coordination_events[expected_element].acquire()
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testDelayedOutput(self):
+ # Explicitly control the sequence of events to ensure we correctly avoid
+ # head-of-line blocking.
+ with self.test_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1
+ })
+
+ mis_ordering = [
+ 4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6,
+ 6, 5, 5, 5, 5, 6, 6
+ ]
+ for element in mis_ordering:
+ self.write_coordination_events[element].set()
+ self.assertEqual(element * element, sess.run(self.next_element))
+ self.assertTrue(self.read_coordination_events[element].acquire(False))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testBlockLengthWithContention(self):
+ with self.test_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 3
+ })
+ # Test against a generating sequence that differs from the uncontended
+ # case, in order to prove sloppy correctness.
+ for i, expected_element in enumerate(
+ self._interleave(
+ [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count,
+ cycle_length=2,
+ block_length=2)):
+ self.write_coordination_events[expected_element].set()
+ if done_first_event: # First event starts the worker threads.
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ self.read_coordination_events[expected_element].acquire()
+ done_first_event = True
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testEarlyExit(self):
+ # Exiting without consuming all input should not block
+ with self.test_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 3,
+ self.block_length: 2
+ })
+ for i in range(4, 7):
+ self.write_coordination_events[i].set()
+ elem = sess.run(self.next_element) # Start all workers
+ # Allow the one successful worker to progress beyond the py_func again.
+ elem = int(math.sqrt(elem))
+ self.write_coordination_events[elem].set()
+ self.read_coordination_events[elem].acquire()
+ # Allow the prefetch to succeed
+ for i in range(4, 7):
+ self.read_coordination_events[i].acquire()
+ self.write_coordination_events[i].set()
+
+ def testTooManyReaders(self):
+
+ def interleave_fn(x):
+ dataset = dataset_ops.Dataset.from_tensors(x)
+ dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64))
+ return dataset
+
+ dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
+ dataset = dataset.repeat(self.repeat_count)
+ dataset = dataset.apply(
+ sloppy_ops.sloppy_interleave,
+ args=(interleave_fn,),
+ kwargs={"cycle_length": 16,
+ "block_length": 2})
+ iterator = dataset.make_one_shot_iterator()
+
+ with self.test_session() as sess:
+ output_values = []
+ for _ in range(30):
+ output_values.append(sess.run(iterator.get_next()))
+
+ expected_values = self._interleave(
+ [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
+ self.assertItemsEqual(output_values, expected_values)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 8afd122d82..94969c1c70 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -32,6 +32,21 @@ py_library(
],
)
+py_library(
+ name = "sloppy_ops",
+ srcs = ["sloppy_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ "//tensorflow/contrib/data/python/framework:function",
+ "//tensorflow/contrib/data/python/util:nest",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/data/python/ops/sloppy_ops.py b/tensorflow/contrib/data/python/ops/sloppy_ops.py
new file mode 100644
index 0000000000..010bd31161
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/sloppy_ops.py
@@ -0,0 +1,120 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Non-deterministic dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.framework import function
+from tensorflow.contrib.data.python.ops import dataset_ops
+from tensorflow.contrib.data.python.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+
+
+class SloppyInterleaveDataset(dataset_ops.Dataset):
+ """A `Dataset` that maps a function over its input and flattens the result."""
+
+ def __init__(self, input_dataset, map_func, cycle_length, block_length):
+ """See `tf.contrib.data.sloppy_interleave()` for details."""
+ super(SloppyInterleaveDataset, self).__init__()
+ self._input_dataset = input_dataset
+
+ @function.Defun(*nest.flatten(input_dataset.output_types))
+ def tf_map_func(*args):
+ """A wrapper for Defun that facilitates shape inference."""
+ # Pass in shape information from the input_dataset.
+ for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
+ arg.set_shape(shape)
+
+ nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
+
+ if nest.is_sequence(nested_args):
+ dataset = map_func(*nested_args)
+ else:
+ dataset = map_func(nested_args)
+
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`map_func` must return a `Dataset` object.")
+
+ self._output_types = dataset.output_types
+ self._output_shapes = dataset.output_shapes
+
+ return dataset.make_dataset_resource()
+
+ self._map_func = tf_map_func
+ self._map_func.add_to_graph(ops.get_default_graph())
+
+ self._cycle_length = ops.convert_to_tensor(
+ cycle_length, dtype=dtypes.int64, name="cycle_length")
+ self._block_length = ops.convert_to_tensor(
+ block_length, dtype=dtypes.int64, name="block_length")
+
+ def make_dataset_resource(self):
+ return gen_dataset_ops.sloppy_interleave_dataset(
+ self._input_dataset.make_dataset_resource(),
+ self._map_func.captured_inputs,
+ self._cycle_length,
+ self._block_length,
+ f=self._map_func,
+ output_types=nest.flatten(self.output_types),
+ output_shapes=nest.flatten(self.output_shapes))
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+def sloppy_interleave(dataset, map_func, cycle_length, block_length):
+ """Maps `map_func` across `dataset`, and interleaves the results.
+
+ The resulting dataset is almost identical to `interleave`. The key
+ difference being that if retrieving a value from a given output iterator would
+ cause `get_next` to block, that iterator will be skipped, and consumed
+ when next available. If consuming from all iterators would cause the
+ `get_next` call to block, the `get_next` call blocks until the first value is
+ available.
+
+ If the underlying datasets produce elements as fast as they are consumed, the
+ `sloppy_interleave` dataset behaves identically to the `interleave` dataset.
+ However, if an underlying dataset would block the consumer, the
+ `sloppy_interleave` dataset can violate to the round-robin order (respected by
+ the `interleave` dataset), producing an element from a different underlying
+ dataset instead.
+
+ WARNING: The order of elements in the resulting dataset is not
+ deterministic. Use `Dataset.interleave()` if you want the elements to have a
+ deterministic order.
+
+ Args:
+ dataset: A `Dataset` that produces elements to feed to `map_func`.
+ map_func: A function mapping a nested structure of tensors (having shapes
+ and types defined by `self.output_shapes` and `self.output_types`) to a
+ `Dataset`.
+ cycle_length: The number of threads to interleave from in parallel.
+ block_length: The number of consecutive elements to pull from a thread
+ before advancing to the next thread. Note: sloppy_interleave will
+ skip the remainder of elements in the block_length in order to avoid
+ blocking.
+
+ Returns:
+ A `Dataset`.
+ """
+ return SloppyInterleaveDataset(dataset, map_func, cycle_length, block_length)
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 8dd8900f28..32a1b2c84d 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5588,6 +5588,20 @@ cc_library(
)
cc_library(
+ name = "dataset_utils",
+ srcs = ["dataset_utils.cc"],
+ hdrs = ["dataset_utils.h"],
+ deps = [
+ ":captured_function",
+ ":dataset",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/util/tensor_bundle",
+ ],
+)
+
+cc_library(
name = "captured_function",
srcs = ["captured_function.cc"],
hdrs = ["captured_function.h"],
@@ -5713,6 +5727,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -5727,6 +5742,22 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
+ name = "sloppy_interleave_dataset_op",
+ srcs = ["sloppy_interleave_dataset_op.cc"],
+ deps = [
+ ":captured_function",
+ ":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -5963,6 +5994,7 @@ tf_kernel_library(
":repeat_dataset_op",
":shuffle_dataset_op",
":skip_dataset_op",
+ ":sloppy_interleave_dataset_op",
":sparse_tensor_slice_dataset_op",
":sql_dataset_ops",
":take_dataset_op",
diff --git a/tensorflow/core/kernels/dataset_utils.cc b/tensorflow/core/kernels/dataset_utils.cc
new file mode 100644
index 0000000000..f320b3b09c
--- /dev/null
+++ b/tensorflow/core/kernels/dataset_utils.cc
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/dataset_utils.h"
+
+namespace tensorflow {
+
+namespace dataset {
+
+Status MakeIteratorFromInputElement(
+ IteratorContext* ctx, const std::vector<Tensor>& input_element,
+ int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
+ std::unique_ptr<IteratorBase>* out_iterator) {
+ FunctionLibraryRuntime::Options opts;
+ opts.runner = ctx->runner();
+ // Choose a step ID that is guaranteed not to clash with any
+ // Session-generated step ID. DirectSession only generates
+ // non-negative step IDs (contiguous, starting from 0), and
+ // MasterSession generates 56-bit random step IDs whose MSB
+ // is always 0, so a negative random step ID should suffice.
+ opts.step_id = CapturedFunction::generate_step_id();
+ ScopedStepContainer step_container(
+ opts.step_id, [captured_func, ctx](const string& name) {
+ captured_func->resource_manager()->Cleanup(name).IgnoreError();
+ });
+ opts.step_container = &step_container;
+ std::vector<Tensor> return_values;
+ TF_RETURN_IF_ERROR(captured_func->Run(opts, input_element, &return_values));
+
+ if (!(return_values.size() == 1 && return_values[0].dtype() == DT_RESOURCE &&
+ TensorShapeUtils::IsScalar(return_values[0].shape()))) {
+ return errors::InvalidArgument(
+ "Function must return a single scalar of dtype DT_RESOURCE.");
+ }
+
+ // Retrieve the dataset that was created in `f`.
+ DatasetBase* returned_dataset;
+ const ResourceHandle& dataset_resource =
+ return_values[0].scalar<ResourceHandle>()();
+
+ // NOTE(mrry): We cannot use the core `LookupResource()` or
+ // `DeleteResource()` functions, because we have an
+ // `IteratorContext*` and not an `OpKernelContext*`, so we
+ // replicate the necessary functionality here.
+ auto type_index = MakeTypeIndex<DatasetBase>();
+ if (type_index.hash_code() != dataset_resource.hash_code()) {
+ return errors::InvalidArgument("Function must return a Dataset resource.");
+ }
+ TF_RETURN_IF_ERROR(captured_func->resource_manager()->Lookup(
+ dataset_resource.container(), dataset_resource.name(),
+ &returned_dataset));
+ core::ScopedUnref unref_dataset(returned_dataset);
+
+ // Create an iterator for the dataset that was returned by
+ // `f`. This transfers ownership of the dataset to the
+ // iterator, so we can delete it from the resource manager.
+ *out_iterator = returned_dataset->MakeIterator(
+ strings::StrCat(prefix, "[", thread_index, "]"));
+ TF_RETURN_IF_ERROR(captured_func->resource_manager()->Delete<DatasetBase>(
+ dataset_resource.container(), dataset_resource.name()));
+ return Status::OK();
+}
+
+} // namespace dataset
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dataset_utils.h b/tensorflow/core/kernels/dataset_utils.h
new file mode 100644
index 0000000000..eea2b8802b
--- /dev/null
+++ b/tensorflow/core/kernels/dataset_utils.h
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/captured_function.h"
+#include "tensorflow/core/kernels/dataset.h"
+
+namespace tensorflow {
+
+namespace dataset {
+
+Status MakeIteratorFromInputElement(
+ IteratorContext* ctx, const std::vector<Tensor>& input_element,
+ int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
+ std::unique_ptr<IteratorBase>* out_iterator);
+
+} // namespace dataset
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_
diff --git a/tensorflow/core/kernels/flat_map_dataset_op.cc b/tensorflow/core/kernels/flat_map_dataset_op.cc
index e2310fecc7..a87e54bf31 100644
--- a/tensorflow/core/kernels/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/flat_map_dataset_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/kernels/captured_function.h"
+#include "tensorflow/core/kernels/dataset_utils.h"
namespace tensorflow {
@@ -125,58 +126,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- FunctionLibraryRuntime::Options opts;
- opts.runner = ctx->runner();
- opts.step_id = CapturedFunction::generate_step_id();
- ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
- dataset()
- ->captured_func_->resource_manager()
- ->Cleanup(name)
- .IgnoreError();
- });
- opts.step_container = &step_container;
- std::vector<Tensor> return_values;
- TF_RETURN_IF_ERROR(
- dataset()->captured_func_->Run(opts, args, &return_values));
-
- if (!(return_values.size() == 1 &&
- return_values[0].dtype() == DT_RESOURCE &&
- TensorShapeUtils::IsScalar(return_values[0].shape()))) {
- return errors::InvalidArgument(
- "`f` must return a single scalar of dtype DT_RESOURCE.");
- }
-
- // Retrieve the dataset that was created in `f`.
- DatasetBase* returned_dataset;
- const ResourceHandle& dataset_resource =
- return_values[0].scalar<ResourceHandle>()();
-
- // NOTE(mrry): We cannot use the core `LookupResource()` or
- // `DeleteResource()` functions, because we have an
- // `IteratorContext*` and not an `OpKernelContext*`, so we
- // replicate the necessary functionality here.
- auto type_index = MakeTypeIndex<DatasetBase>();
- if (type_index.hash_code() != dataset_resource.hash_code()) {
- return errors::InvalidArgument(
- "`f` must return a Dataset resource.");
- }
- TF_RETURN_IF_ERROR(
- dataset()->captured_func_->resource_manager()->Lookup(
- dataset_resource.container(), dataset_resource.name(),
- &returned_dataset));
- core::ScopedUnref unref_dataset(returned_dataset);
-
- // Create an iterator for the dataset that was returned by
- // `f`. This transfers ownership of the dataset to the
- // iterator, so we can delete it from the resource manager.
- current_element_iterator_ = returned_dataset->MakeIterator(
- strings::StrCat(prefix(), "[", element_index_++, "]"));
- TF_RETURN_IF_ERROR(
- dataset()
- ->captured_func_->resource_manager()
- ->Delete<DatasetBase>(dataset_resource.container(),
- dataset_resource.name()));
+ TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+ ctx, args, element_index_++, dataset()->captured_func_.get(),
+ prefix(), &current_element_iterator_));
} while (true);
}
diff --git a/tensorflow/core/kernels/interleave_dataset_op.cc b/tensorflow/core/kernels/interleave_dataset_op.cc
index dce4f88101..7b148b74c9 100644
--- a/tensorflow/core/kernels/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/interleave_dataset_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/kernels/captured_function.h"
+#include "tensorflow/core/kernels/dataset_utils.h"
namespace tensorflow {
@@ -168,8 +169,9 @@ class InterleaveDatasetOp : public OpKernel {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, &args, &end_of_input_));
if (!end_of_input_) {
- TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
- ctx, args, &current_elements_[cycle_index_]));
+ TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+ ctx, args, cycle_index_, dataset()->captured_func_.get(),
+ prefix(), &current_elements_[cycle_index_]));
++num_open_;
}
} else {
@@ -182,62 +184,6 @@ class InterleaveDatasetOp : public OpKernel {
}
private:
- Status MakeIteratorFromInputElement(
- IteratorContext* ctx, const std::vector<Tensor>& input_element,
- std::unique_ptr<IteratorBase>* out_iterator)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- FunctionLibraryRuntime::Options opts;
- opts.runner = ctx->runner();
- opts.step_id = CapturedFunction::generate_step_id();
- ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
- dataset()
- ->captured_func_->resource_manager()
- ->Cleanup(name)
- .IgnoreError();
- });
- opts.step_container = &step_container;
- std::vector<Tensor> return_values;
- TF_RETURN_IF_ERROR(dataset()->captured_func_->Run(opts, input_element,
- &return_values));
-
- if (!(return_values.size() == 1 &&
- return_values[0].dtype() == DT_RESOURCE &&
- TensorShapeUtils::IsScalar(return_values[0].shape()))) {
- return errors::InvalidArgument(
- "`f` must return a single scalar of dtype DT_RESOURCE.");
- }
-
- // Retrieve the dataset that was created in `f`.
- DatasetBase* returned_dataset;
- const ResourceHandle& dataset_resource =
- return_values[0].scalar<ResourceHandle>()();
-
- // NOTE(mrry): We cannot use the core `LookupResource()` or
- // `DeleteResource()` functions, because we have an
- // `IteratorContext*` and not an `OpKernelContext*`, so we
- // replicate the necessary functionality here.
- auto type_index = MakeTypeIndex<DatasetBase>();
- if (type_index.hash_code() != dataset_resource.hash_code()) {
- return errors::InvalidArgument("`f` must return a Dataset resource.");
- }
- TF_RETURN_IF_ERROR(
- dataset()->captured_func_->resource_manager()->Lookup(
- dataset_resource.container(), dataset_resource.name(),
- &returned_dataset));
- core::ScopedUnref unref_dataset(returned_dataset);
-
- // Create an iterator for the dataset that was returned by
- // `f`. This transfers ownership of the dataset to the
- // iterator, so we can delete it from the resource manager.
- *out_iterator = returned_dataset->MakeIterator(
- strings::StrCat(prefix(), "[", cycle_index_, "]"));
- TF_RETURN_IF_ERROR(
- dataset()->captured_func_->resource_manager()->Delete<DatasetBase>(
- dataset_resource.container(), dataset_resource.name()));
- return Status::OK();
- }
-
mutex mu_;
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
std::vector<std::unique_ptr<IteratorBase>> current_elements_
diff --git a/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc b/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc
new file mode 100644
index 0000000000..d95f51f0f2
--- /dev/null
+++ b/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc
@@ -0,0 +1,370 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/dataset.h"
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/dataset_utils.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/random/random.h"
+
+#include "tensorflow/core/kernels/captured_function.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit SloppyInterleaveDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ OpInputList inputs;
+ OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
+ std::vector<Tensor> other_arguments;
+ other_arguments.reserve(inputs.size());
+ for (const Tensor& t : inputs) {
+ other_arguments.push_back(t);
+ }
+
+ int64 cycle_length;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "cycle_length", &cycle_length));
+ OP_REQUIRES(ctx, cycle_length > 0,
+ errors::InvalidArgument("`cycle_length` must be > 0"));
+
+ int64 block_length;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "block_length", &block_length));
+ OP_REQUIRES(ctx, block_length > 0,
+ errors::InvalidArgument("`block_length` must be > 0"));
+
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_,
+ std::move(other_arguments),
+ &captured_func));
+
+ *output = new Dataset(input, std::move(captured_func), cycle_length,
+ block_length, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(const DatasetBase* input,
+ std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
+ int64 block_length, const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : input_(input),
+ captured_func_(std::move(captured_func)),
+ cycle_length_(cycle_length),
+ block_length_(block_length),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::SloppyInterleave")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() override {
+ return "SloppyInterleaveDatasetOp::Dataset";
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
+ output_elements_(params.dataset->cycle_length_) {}
+
+ ~Iterator() override {
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ // Notify all workers in case they are blocked.
+ for (int64 i = 0; i < dataset()->cycle_length_; ++i) {
+ output_elements_[i].cond_var.notify_all();
+ }
+ }
+
+ // It is implemented so that it matches the deterministic interleave
+ // unless we would block waiting for an element, at which point it skips
+ // along to the next available value.
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
+ // Search for available items, blocking if necessary.
+ while (!cancelled_) {
+ for (size_t i = 0; i < dataset()->cycle_length_; ++i) {
+ size_t index = (next_index_ + i) % dataset()->cycle_length_;
+ if (output_elements_[index].is_produced) {
+ next_index_ = index;
+ if (i == 0) {
+ block_count_++;
+ if (block_count_ == dataset()->block_length_) {
+ next_index_ = (index + 1) % dataset()->cycle_length_;
+ block_count_ = 0;
+ }
+ } else {
+ block_count_ = 0;
+ }
+ // If we encounter an EoF, advance to the next iterator
+ if (output_elements_[index].end_of_sequence) {
+ output_elements_[index].is_produced = false;
+ output_elements_[index].cond_var.notify_one();
+ next_index_ = (index + 1) % dataset()->cycle_length_;
+ block_count_ = 0;
+ i = -1; // Restart the inner loop
+ continue;
+ }
+ *end_of_sequence = false;
+ if (output_elements_[index].output_status.ok()) {
+ output_elements_[index].output_value.swap(*out_tensors);
+ }
+ output_elements_[index].is_produced = false;
+ output_elements_[index].cond_var.notify_one();
+ return output_elements_[index].output_status;
+ }
+ }
+
+ if (num_active_threads_ == 0) {
+ // No potential for future values.
+ //
+ // Note: this condition check must occur after checking the output
+ // buffer, as its possible for there to be values in the output
+ // buffer, even if the number of live threads is zero.
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ // No values available; wait until woken up.
+ cond_var_.wait(l);
+ }
+ return errors::Cancelled(
+ "SloppyInterleaveDatasetOp::Dataset::Iterator::GetNext");
+ }
+
+ private:
+ // Internal structure to manage thread coordination. All values are
+ // guarded by the enclosing Iterator's mu_.
+ struct OutputBufferElement {
+ // The producer must set `is_produced` to `true` after
+ // `output_status` or `output_value` has been written.
+ bool is_produced = false;
+ // The producer sets `output_status` if either getting the input element
+ // or applying the function to it fails.
+ Status output_status;
+ // Reached end of sequence for the underlying iterator.
+ bool end_of_sequence = false;
+ // The output data element.
+ std::vector<Tensor> output_value;
+ // The producer thread waits on this condition variable after having
+ // produced an element. The reader thread notifies this condition
+ // variable after reading the value.
+ condition_variable cond_var;
+ };
+
+ Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (worker_threads_.empty()) {
+ for (int64 i = 0; i < dataset()->cycle_length_; ++i) {
+ // Serialize the creation of the workers and their corresponding
+ // input elements to ensure we match the standard interleave when
+ // the underlying iterators induce no delay.
+ std::vector<Tensor> args;
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &args, &end_of_input_));
+ if (end_of_input_) {
+ LOG(WARNING) << "Input iterator exhausted after " << i
+ << " elements; cannot start all "
+ << dataset()->cycle_length_ << " worker threads.";
+ return Status::OK();
+ }
+ std::unique_ptr<IteratorBase> itr;
+ TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+ ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr));
+ worker_threads_.emplace_back(
+ std::unique_ptr<Thread>(ctx->env()->StartThread(
+ {}, "worker_thread",
+ std::bind(&Iterator::WorkerThread, this,
+ new IteratorContext(*ctx), i, itr.release()))));
+ num_active_threads_ = i + 1;
+ }
+ }
+ return Status::OK();
+ }
+
+ void BlockAndUpdateOutputBuffer(mutex_lock* l, const int64 thread_index,
+ const Status& status,
+ bool end_of_sequence,
+ std::vector<Tensor>* out_tensors)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ // We have produced an element; push it into the output buffer
+ // when space is available.
+ while (!cancelled_ && output_elements_[thread_index].is_produced) {
+ output_elements_[thread_index].cond_var.wait(*l);
+ }
+ if (cancelled_) {
+ return;
+ }
+ output_elements_[thread_index].is_produced = true;
+ output_elements_[thread_index].output_status = status;
+ output_elements_[thread_index].end_of_sequence = end_of_sequence;
+ if (status.ok()) {
+ output_elements_[thread_index].output_value.swap(*out_tensors);
+ } else {
+ output_elements_[thread_index].output_value.clear();
+ }
+ cond_var_.notify_one();
+ }
+
+ // Races to produce elements into the output queue buffers.
+ void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index,
+ IteratorBase* out_iterator_ptr) {
+ // std::function arguments are copy-constructable, so we pass raw
+ // pointers, and then immediately wrap them to ensure correct ownership.
+ std::unique_ptr<IteratorContext> ctx(ctx_ptr);
+ std::unique_ptr<IteratorBase> out_iterator(out_iterator_ptr);
+ auto cleanup = gtl::MakeCleanup([this, thread_index] {
+ mutex_lock l(mu_);
+ num_active_threads_--;
+ cond_var_.notify_all();
+ });
+ while (true) {
+ // Attempt to produce an element.
+ bool end_of_out_itr_input = false;
+ std::vector<Tensor> out_tensors;
+ Status element_status = out_iterator->GetNext(ctx.get(), &out_tensors,
+ &end_of_out_itr_input);
+ // Handle output.
+ {
+ mutex_lock l(mu_);
+ BlockAndUpdateOutputBuffer(&l, thread_index, element_status,
+ end_of_out_itr_input, &out_tensors);
+ if (end_of_out_itr_input) {
+ // We have exhausted our current iterator; get a new iterator;
+ // loop to handle errors.
+ while (!cancelled_) {
+ if (end_of_input_) {
+ // No more iterator inputs; we're done!
+ return;
+ }
+ std::vector<Tensor> args;
+ // BlockAndUpdateOutputBuffer() sequences calls to
+ // input_impl_->GetNext when the out_iterator doesn't cause
+ // slopping.
+ Status input_status =
+ input_impl_->GetNext(ctx.get(), &args, &end_of_input_);
+ if (end_of_input_) {
+ // No more elements to produce, stop the worker thread.
+ return;
+ }
+ if (input_status.ok()) {
+ input_status = dataset::MakeIteratorFromInputElement(
+ ctx.get(), args, thread_index,
+ dataset()->captured_func_.get(), prefix(), &out_iterator);
+ }
+ if (input_status.ok()) {
+ // Successfully have a new out_iterator; restart the outer
+ // loop to produce an element.
+ break;
+ }
+
+ // We encountered an error; push the error to the output buffer.
+ BlockAndUpdateOutputBuffer(&l, thread_index, input_status,
+ /* end_of_sequence = */ false,
+ &out_tensors);
+ }
+ }
+
+ // Check if we should exit.
+ if (cancelled_) {
+ return;
+ }
+ }
+ }
+ }
+
+ // Mutex & condition variable to guard mutable iterator internals and
+ // coordinate among worker threads and client thread[s].
+ mutex mu_;
+ condition_variable cond_var_;
+ // The iterator producing elements which are converted to datasets by
+ // the dataset()->captured_func_ then interleaved together.
+ const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ // Whether the input_impl_ can produce future elements.
+ bool end_of_input_ GUARDED_BY(mu_) = false;
+ // The buffer of elements to be produced. Each worker thread operates
+ // on a single OutputBufferElement.
+ std::vector<OutputBufferElement> output_elements_ GUARDED_BY(mu_);
+ // The index into output_elements_ for next element to produce.
+ size_t next_index_ GUARDED_BY(mu_) = 0;
+ // The number of items produced so far within the block
+ size_t block_count_ GUARDED_BY(mu_) = 0;
+ // Number of active threads.
+ size_t num_active_threads_ GUARDED_BY(mu_) = 0;
+ // Flag to instruct the worker threads to exit.
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ // Pointers to the worker threads. This must be last to ensure the
+ // threads have exited before any other members are deallocated.
+ // TODO(b/65178177): Avoid allocating additional threads.
+ std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ const std::unique_ptr<CapturedFunction> captured_func_;
+ const int64 cycle_length_;
+ const int64 block_length_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ const int graph_def_version_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ const NameAttrList* func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("SloppyInterleaveDataset").Device(DEVICE_CPU),
+ SloppyInterleaveDatasetOp);
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 37d9a737e2..7cc8dccb95 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -233,6 +233,33 @@ f: A function mapping elements of `input_dataset`, concatenated with
`output_types` and `output_shapes`.
)doc");
+REGISTER_OP("SloppyInterleaveDataset")
+ .Input("input_dataset: resource")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Output("handle: resource")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset`.
+
+The resulting dataset is similar to the `InterleaveDataset`, with the exception
+that if retrieving the next value from a dataset would cause the requester to
+block, it will skip that input dataset. This dataset is especially useful
+when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
+allows the training step to proceed so long as some data is available.
+
+!! WARNING !! This dataset is not deterministic!
+
+f: A function mapping elements of `input_dataset`, concatenated with
+ `other_arguments`, to a Dataset resource that contains elements matching
+ `output_types` and `output_shapes`.
+)doc");
+
REGISTER_OP("GroupByWindowDataset")
.Input("input_dataset: resource")
.Input("key_func_other_arguments: Tkey_func_other_arguments")