diff options
author | Derek Murray <mrry@google.com> | 2018-08-22 15:12:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 15:26:26 -0700 |
commit | 937c279fd9758ce75c43ad8d2d828475b496334a (patch) | |
tree | f8ee996baa5e0ee22d84ed72aa8d63e235bb62e4 | |
parent | d97e525c1fa601b4b1b0a3cffd9e579201d167ef (diff) |
[tf.data] Fix MultiDeviceIterator initialization to use correct FLR.
In the recent change to instantiate functions at iterator
initialization time, I forgot to make the corresponding change was to
MultiDeviceIterator. This change fixes that problem and unbreaks
MultiDeviceIterator.
PiperOrigin-RevId: 209838406
-rw-r--r-- | tensorflow/contrib/data/kernels/prefetching_kernels.cc | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 74df1e42a8..725f8933c9 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -548,7 +548,9 @@ class MultiDeviceIterator : public ResourceBase { devices_(devices), flib_def_(std::move(flib_def)), pflr_(std::move(pflr)), - lib_(lib) {} + lib_(lib) { + CHECK_NOTNULL(lib_); + } string DebugString() override { return strings::StrCat("MultiDeviceIterator for ", devices_.size(), @@ -600,6 +602,11 @@ class MultiDeviceIterator : public ResourceBase { return lib_def_; } + FunctionLibraryRuntime* const lib() { + tf_shared_lock l(mu_); + return lib_; + } + private: // A private class that uses a background thread to keep a per device buffer // full. @@ -930,8 +937,10 @@ class MultiDeviceIteratorInitOp : public OpKernel { core::ScopedUnref unref(resource); std::unique_ptr<IteratorBase> iterator; - OP_REQUIRES_OK(ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", - &iterator)); + IteratorContext iter_ctx(ctx); + iter_ctx.set_lib(resource->lib()); + OP_REQUIRES_OK( + ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator)); int64 incarnation_id; OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size, &incarnation_id)); |