aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-02-01 11:50:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-01 17:15:09 -0800
commit77e6a452188e83ae4498cc3ae23e20e60061b367 (patch)
treeb0156176803ee95465bd456f65718aa407954361 /tensorflow/python/data
parent997c209f9b8210f4bdc44a0172e0b64f5f7761c0 (diff)
[tf.data] Fix bug where captured resources in shared iterators were invisible.
This change ensures that a shared iterator (which requires a private FunctionLibraryRuntime that outlasts the calling op's runtime, because it can outlive a single session) uses the same Device as a non-shared iterator, and hence capturing resources from the creating graph will work as intended. Fixes #16481. PiperOrigin-RevId: 184172498
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py37
2 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 43cbde69d9..8b8adefa65 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -357,6 +357,9 @@ tf_py_test(
"//tensorflow/python:session",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:lookup_ops",
],
grpc_enabled = True,
tags = [
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
index 45dfa13720..2c65c49ebd 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
@@ -21,6 +21,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
@@ -28,6 +29,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
@@ -103,6 +106,40 @@ class IteratorClusterTest(test.TestCase):
"/job:worker/replica:0/task:1/cpu:0",
workers[0].target)
+ def testCaptureHashTableInSharedIterator(self):
+ worker, _ = test_util.create_local_cluster(1, 1)
+
+ # NOTE(mrry): We must use the V2 variants of `HashTable`
+ # etc. because these produce a `tf.resource`-typed output that is
+ # compatible with the in-graph function implementation.
+ default_val = -1
+ keys = constant_op.constant(["brain", "salad", "surgery"])
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
+ table = lookup_ops.HashTable(
+ lookup_ops.KeyValueTensorInitializer(keys, values),
+ default_val,
+ shared_name="shared_table")
+
+ input_sentences = dataset_ops.Dataset.from_tensor_slices(
+ ["brain brain tank salad surgery", "surgery brain"])
+
+ iterator = (
+ input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
+ table.lookup)
+ .make_initializable_iterator(shared_name="shared_iterator"))
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with session.Session(worker[0].target) as sess:
+ sess.run(table.init)
+ sess.run(init_op)
+ self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))
+
+ with session.Session(worker[0].target) as sess:
+ self.assertAllEqual([2, 0], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()