diff options
author | 2018-02-01 11:50:23 -0800
---|---
committer | 2018-02-01 17:15:09 -0800
commit | 77e6a452188e83ae4498cc3ae23e20e60061b367 (patch) | |
tree | b0156176803ee95465bd456f65718aa407954361 /tensorflow/python/data | |
parent | 997c209f9b8210f4bdc44a0172e0b64f5f7761c0 (diff) |
[tf.data] Fix bug where captured resources in shared iterators were invisible.
This change ensures that a shared iterator (which requires a private
FunctionLibraryRuntime that outlasts the calling op's runtime, because
it can outlive a single session) uses the same Device as a non-shared
iterator, and hence capturing resources from the creating graph will
work as intended.
Fixes #16481.
PiperOrigin-RevId: 184172498
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/kernel_tests/BUILD | 3
-rw-r--r-- | tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py | 37
2 files changed, 40 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 43cbde69d9..8b8adefa65 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -357,6 +357,9 @@ tf_py_test( "//tensorflow/python:session", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:string_ops", + "//tensorflow/python:lookup_ops", ], grpc_enabled = True, tags = [ diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py index 45dfa13720..2c65c49ebd 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py @@ -21,6 +21,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function @@ -28,6 +29,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import test @@ -103,6 +106,40 @@ class IteratorClusterTest(test.TestCase): "/job:worker/replica:0/task:1/cpu:0", workers[0].target) + def testCaptureHashTableInSharedIterator(self): + worker, _ = test_util.create_local_cluster(1, 1) + + # NOTE(mrry): We must use the V2 variants of `HashTable` + # etc. 
because these produce a `tf.resource`-typed output that is + # compatible with the in-graph function implementation. + default_val = -1 + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup_ops.HashTable( + lookup_ops.KeyValueTensorInitializer(keys, values), + default_val, + shared_name="shared_table") + + input_sentences = dataset_ops.Dataset.from_tensor_slices( + ["brain brain tank salad surgery", "surgery brain"]) + + iterator = ( + input_sentences.map(lambda x: string_ops.string_split([x]).values).map( + table.lookup) + .make_initializable_iterator(shared_name="shared_iterator")) + init_op = iterator.initializer + get_next = iterator.get_next() + + with session.Session(worker[0].target) as sess: + sess.run(table.init) + sess.run(init_op) + self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next)) + + with session.Session(worker[0].target) as sess: + self.assertAllEqual([2, 0], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() |