diff options
author | Derek Murray <mrry@google.com> | 2018-02-01 11:50:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-01 17:15:09 -0800 |
commit | 77e6a452188e83ae4498cc3ae23e20e60061b367 (patch) | |
tree | b0156176803ee95465bd456f65718aa407954361 | |
parent | 997c209f9b8210f4bdc44a0172e0b64f5f7761c0 (diff) |
[tf.data] Fix bug where captured resources in shared iterators were invisible.
This change ensures that a shared iterator (which requires a private
FunctionLibraryRuntime that outlasts the calling op's runtime, because
it can outlive a single session) uses the same Device as a non-shared
iterator, and hence capturing resources from the creating graph will
work as intended.
Fixes #16481.
PiperOrigin-RevId: 184172498
-rw-r--r-- | tensorflow/core/kernels/data/iterator_ops.cc | 30 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py | 37 |
3 files changed, 64 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index b37bd672ad..dd5f4a4554 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/common_runtime/threadpool_device.h" #include "tensorflow/core/framework/iterator.pb.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -516,15 +517,32 @@ class IteratorHandleOp : public OpKernel { return Status::OK(); } + template <typename To, typename From> // use like this: down_cast<T*>(foo); + static inline To down_cast(From* f) { // so we only accept pointers + static_assert( + (std::is_base_of<From, typename std::remove_pointer<To>::type>::value), + "target type not derived from source type"); + + // We skip the assert and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. + assert(f == nullptr || dynamic_cast<To>(f) != nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + return static_cast<To>(f); + } + FunctionLibraryRuntime* CreatePrivateFLR( OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr, std::unique_ptr<FunctionLibraryDefinition>* flib_def, std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) { - Device* device = new ThreadPoolDevice( - SessionOptions(), ctx->device()->attributes().name(), Bytes(256 << 20), - DeviceLocality(), cpu_allocator()); - - device_mgr->reset(new DeviceMgr({device})); + // Wrap the existing device in order to see any captured resources + // in its resource manager. The existing device will outlive the + // IteratorResource, because we are storing the IteratorResource + // in that device's resourc manager. + Device* wrapped_device = RenamedDevice::NewRenamedDevice( + ctx->device()->name(), down_cast<Device*>(ctx->device()), + false /* owns_underlying */, false /* isolate_session_state */); + device_mgr->reset(new DeviceMgr({wrapped_device})); flib_def->reset(new FunctionLibraryDefinition( *ctx->function_library()->GetFunctionLibraryDefinition())); pflr->reset(new ProcessFunctionLibraryRuntime( @@ -532,7 +550,7 @@ class IteratorHandleOp : public OpKernel { {} /* TODO(mrry): OptimizerOptions? */, nullptr /* TODO(mrry): ClusterFLR */)); - return (*pflr)->GetFLR(device->name()); + return (*pflr)->GetFLR(ctx->device()->name()); } mutex mu_; diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 43cbde69d9..8b8adefa65 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -357,6 +357,9 @@ tf_py_test( "//tensorflow/python:session", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:string_ops", + "//tensorflow/python:lookup_ops", ], grpc_enabled = True, tags = [ diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py index 45dfa13720..2c65c49ebd 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py @@ -21,6 +21,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function @@ -28,6 +29,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import test @@ -103,6 +106,40 @@ class IteratorClusterTest(test.TestCase): "/job:worker/replica:0/task:1/cpu:0", workers[0].target) + def testCaptureHashTableInSharedIterator(self): + worker, _ = test_util.create_local_cluster(1, 1) + + # NOTE(mrry): We must use the V2 variants of `HashTable` + # etc. because these produce a `tf.resource`-typed output that is + # compatible with the in-graph function implementation. + default_val = -1 + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup_ops.HashTable( + lookup_ops.KeyValueTensorInitializer(keys, values), + default_val, + shared_name="shared_table") + + input_sentences = dataset_ops.Dataset.from_tensor_slices( + ["brain brain tank salad surgery", "surgery brain"]) + + iterator = ( + input_sentences.map(lambda x: string_ops.string_split([x]).values).map( + table.lookup) + .make_initializable_iterator(shared_name="shared_iterator")) + init_op = iterator.initializer + get_next = iterator.get_next() + + with session.Session(worker[0].target) as sess: + sess.run(table.init) + sess.run(init_op) + self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next)) + + with session.Session(worker[0].target) as sess: + self.assertAllEqual([2, 0], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() |