about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-02-01 11:50:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-01 17:15:09 -0800
commit77e6a452188e83ae4498cc3ae23e20e60061b367 (patch)
treeb0156176803ee95465bd456f65718aa407954361
parent997c209f9b8210f4bdc44a0172e0b64f5f7761c0 (diff)
[tf.data] Fix bug where captured resources in shared iterators were invisible.
This change ensures that a shared iterator (which requires a private FunctionLibraryRuntime that outlasts the calling op's runtime, because it can outlive a single session) uses the same Device as a non-shared iterator, and hence capturing resources from the creating graph will work as intended. Fixes #16481. PiperOrigin-RevId: 184172498
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc30
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py37
3 files changed, 64 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index b37bd672ad..dd5f4a4554 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -516,15 +517,32 @@ class IteratorHandleOp : public OpKernel {
return Status::OK();
}
+ template <typename To, typename From> // use like this: down_cast<T*>(foo);
+ static inline To down_cast(From* f) { // so we only accept pointers
+ static_assert(
+ (std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
+ "target type not derived from source type");
+
+ // We skip the assert and hence the dynamic_cast if RTTI is disabled.
+#if !defined(__GNUC__) || defined(__GXX_RTTI)
+ // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
+ assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
+#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
+ return static_cast<To>(f);
+ }
+
FunctionLibraryRuntime* CreatePrivateFLR(
OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
- Device* device = new ThreadPoolDevice(
- SessionOptions(), ctx->device()->attributes().name(), Bytes(256 << 20),
- DeviceLocality(), cpu_allocator());
-
- device_mgr->reset(new DeviceMgr({device}));
+ // Wrap the existing device in order to see any captured resources
+ // in its resource manager. The existing device will outlive the
+ // IteratorResource, because we are storing the IteratorResource
+ // in that device's resource manager.
+ Device* wrapped_device = RenamedDevice::NewRenamedDevice(
+ ctx->device()->name(), down_cast<Device*>(ctx->device()),
+ false /* owns_underlying */, false /* isolate_session_state */);
+ device_mgr->reset(new DeviceMgr({wrapped_device}));
flib_def->reset(new FunctionLibraryDefinition(
*ctx->function_library()->GetFunctionLibraryDefinition()));
pflr->reset(new ProcessFunctionLibraryRuntime(
@@ -532,7 +550,7 @@ class IteratorHandleOp : public OpKernel {
{} /* TODO(mrry): OptimizerOptions? */,
nullptr /* TODO(mrry): ClusterFLR */));
- return (*pflr)->GetFLR(device->name());
+ return (*pflr)->GetFLR(ctx->device()->name());
}
mutex mu_;
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 43cbde69d9..8b8adefa65 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -357,6 +357,9 @@ tf_py_test(
"//tensorflow/python:session",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:lookup_ops",
],
grpc_enabled = True,
tags = [
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
index 45dfa13720..2c65c49ebd 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
@@ -21,6 +21,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
@@ -28,6 +29,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
@@ -103,6 +106,40 @@ class IteratorClusterTest(test.TestCase):
"/job:worker/replica:0/task:1/cpu:0",
workers[0].target)
+ def testCaptureHashTableInSharedIterator(self):
+ worker, _ = test_util.create_local_cluster(1, 1)
+
+ # NOTE(mrry): We must use the V2 variants of `HashTable`
+ # etc. because these produce a `tf.resource`-typed output that is
+ # compatible with the in-graph function implementation.
+ default_val = -1
+ keys = constant_op.constant(["brain", "salad", "surgery"])
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
+ table = lookup_ops.HashTable(
+ lookup_ops.KeyValueTensorInitializer(keys, values),
+ default_val,
+ shared_name="shared_table")
+
+ input_sentences = dataset_ops.Dataset.from_tensor_slices(
+ ["brain brain tank salad surgery", "surgery brain"])
+
+ iterator = (
+ input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
+ table.lookup)
+ .make_initializable_iterator(shared_name="shared_iterator"))
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with session.Session(worker[0].target) as sess:
+ sess.run(table.init)
+ sess.run(init_op)
+ self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))
+
+ with session.Session(worker[0].target) as sess:
+ self.assertAllEqual([2, 0], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()