about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Derek Murray <mrry@google.com> 2018-08-21 11:48:37 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-21 11:59:01 -0700
commit: 9158b1b83a0128fc41bfccd80fe26d8231fe958b (patch)
tree: 4bb719d17c4284a5b07814ff24113fa5d0b446f0
parent: e28f9da84b51acdbf3234688daa4c55647041219 (diff)
[tf.data] Move captured function instantiation to iterator initialization time.
Previously, a function instantiation error (e.g. in `Dataset.map()`) would lead to an error in each GetNext() call that attempted to use the function. Moving this to iterator instantiation time has the benefit that the error will be reported once when the initialization op is executed, which has a more helpful stack trace, since it should not be conflated with other potential op failures.

PiperOrigin-RevId: 209633511
-rw-r--r-- tensorflow/core/common_runtime/function.cc | 6
-rw-r--r-- tensorflow/core/framework/function.h | 5
-rw-r--r-- tensorflow/core/kernels/data/captured_function.cc | 65
-rw-r--r-- tensorflow/core/kernels/data/captured_function.h | 4
-rw-r--r-- tensorflow/core/kernels/data/filter_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/flat_map_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/generator_dataset_op.cc | 20
-rw-r--r-- tensorflow/core/kernels/data/generator_dataset_op.h | 1
-rw-r--r-- tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc | 9
-rw-r--r-- tensorflow/core/kernels/data/group_by_window_dataset_op.cc | 8
-rw-r--r-- tensorflow/core/kernels/data/interleave_dataset_op.cc | 5
-rw-r--r-- tensorflow/core/kernels/data/iterator_ops.cc | 19
-rw-r--r-- tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/map_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/optimize_dataset_op.cc | 11
-rw-r--r-- tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 6
-rw-r--r-- tensorflow/core/kernels/data/parallel_map_iterator.cc | 28
-rw-r--r-- tensorflow/core/kernels/data/parallel_map_iterator.h | 10
-rw-r--r-- tensorflow/core/kernels/data/repeat_dataset_op.cc | 38
-rw-r--r-- tensorflow/core/kernels/data/scan_dataset_op.cc | 4
-rw-r--r-- tensorflow/python/data/kernel_tests/map_dataset_op_test.py | 32
22 files changed, 210 insertions, 81 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 54bbe84b57..fb89bcc0df 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -555,6 +555,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
next_handle_++;
}
}
+
+ if (options.create_kernels_eagerly) {
+ Item* item;
+ TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item));
+ }
+
return Status::OK();
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index edb7ed01e9..a2e69a152a 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -490,6 +490,11 @@ class FunctionLibraryRuntime {
// Instantiates the function using an executor of the given type. If empty,
// the default TensorFlow executor will be used.
string executor_type;
+
+ // If true, the runtime will attempt to create kernels for the function at
+ // instantiation time, rather than on the first run. This can be used to
+ // surface errors earlier.
+ bool create_kernels_eagerly = false;
};
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 82da385405..abdf6ee4e8 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -172,31 +172,17 @@ class BorrowedArgsCallFrame : public CallFrameBase {
} // namespace
-Status CapturedFunction::MaybeInstantiate(
- IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle) {
- mutex_lock l(mu_);
+Status CapturedFunction::GetHandle(IteratorContext* ctx,
+ FunctionLibraryRuntime::Handle* out_handle) {
+ tf_shared_lock l(mu_);
if (lib_ == nullptr) {
- // The context's runtime will be used for all subsequent calls.
- lib_ = ctx->lib();
- DCHECK(f_handle_ == kInvalidHandle);
- FunctionLibraryRuntime::InstantiateOptions inst_opts;
- inst_opts.overlay_lib = ctx->function_library().get();
- inst_opts.state_handle = std::to_string(random::New64());
- TF_RETURN_IF_ERROR(lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
- inst_opts, &f_handle_));
- const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
- if (fbody == nullptr) {
- return errors::Internal("Failed to instantiate function body.");
- }
- ret_types_ = fbody->ret_types;
- } else {
- // TODO(mrry): Consider moving this under a shared lock, as it is
- // the common case.
- if (ctx->lib() != lib_) {
- return errors::Internal(
- "Captured function was called with a different "
- "FunctionLibraryRuntime*, which is not permitted.");
- }
+ return errors::Internal("Captured function \"", func_.name(),
+ "\" was called before it was instantiated.");
+ }
+ if (ctx->lib() != lib_) {
+ return errors::Internal("Captured function \"", func_.name(),
+ "\" was called with a different "
+ "FunctionLibraryRuntime*, which is not permitted.");
}
*out_handle = f_handle_;
return Status::OK();
@@ -205,7 +191,7 @@ Status CapturedFunction::MaybeInstantiate(
Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets) {
FunctionLibraryRuntime::Handle handle;
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+ TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
@@ -242,7 +228,7 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets) {
FunctionLibraryRuntime::Handle handle;
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+ TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
@@ -277,9 +263,30 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
}
Status CapturedFunction::Instantiate(IteratorContext* ctx) {
- FunctionLibraryRuntime::Handle unused_handle;
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &unused_handle));
mutex_lock l(mu_);
+ if (lib_ == nullptr) {
+ // The context's runtime will be used for all subsequent calls.
+ lib_ = ctx->lib();
+ DCHECK(f_handle_ == kInvalidHandle);
+ FunctionLibraryRuntime::InstantiateOptions inst_opts;
+ inst_opts.overlay_lib = ctx->function_library().get();
+ inst_opts.state_handle = std::to_string(random::New64());
+ inst_opts.create_kernels_eagerly = true;
+ Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
+ inst_opts, &f_handle_));
+ TF_RETURN_IF_ERROR(s);
+ const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
+ if (fbody == nullptr) {
+ return errors::Internal("Failed to instantiate function body.");
+ }
+ ret_types_ = fbody->ret_types;
+ } else {
+ if (ctx->lib() != lib_) {
+ return errors::Internal(
+ "Captured function was called with a different "
+ "FunctionLibraryRuntime*, which is not permitted.");
+ }
+ }
if (captured_runner_ == nullptr) {
captured_runner_ = *ctx->runner();
}
@@ -343,7 +350,7 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// be deleted before `done` is called. Take care not to capture `ctx` in any
// code that may execute asynchronously in this function.
FunctionLibraryRuntime::Handle handle;
- Status s = MaybeInstantiate(ctx, &handle);
+ Status s = GetHandle(ctx, &handle);
if (!s.ok()) {
done(s);
return;
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index e9ad3e381d..c95f2b1c01 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -116,8 +116,8 @@ class CapturedFunction {
CapturedFunction(const NameAttrList& func,
std::vector<Tensor> captured_inputs);
- Status MaybeInstantiate(IteratorContext* ctx,
- FunctionLibraryRuntime::Handle* out_handle);
+ Status GetHandle(IteratorContext* ctx,
+ FunctionLibraryRuntime::Handle* out_handle);
mutex mu_;
const NameAttrList func_;
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index a80e102ccf..f5c7d336a6 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -149,7 +149,9 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<FilterDatasetBase>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index 07bcb9d414..21e627a8e8 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -129,7 +129,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index 3c3d78b724..ccee690d7e 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
@@ -80,20 +81,20 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(
+ dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ return Status::OK();
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
- if (!initialized_) {
- TF_RETURN_IF_ERROR(
- dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
- // Explicitly instantiate the finalize function here so that
- // we can invoke it in the destructor.
- TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
- initialized_ = true;
- }
-
if (finalized_) {
*end_of_sequence = true;
return Status::OK();
@@ -121,7 +122,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
private:
mutex mu_;
- bool initialized_ GUARDED_BY(mu_) = false;
bool finalized_ GUARDED_BY(mu_) = false;
std::vector<Tensor> state_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h
index 3f84fa9c2e..8407543136 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.h
+++ b/tensorflow/core/kernels/data/generator_dataset_op.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/kernels/data/captured_function.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index be4132a064..4a388645f2 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -190,7 +190,14 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(
+ dataset()->captured_finalize_func_->Instantiate(ctx));
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 288695f3cd..f993a68934 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -205,7 +205,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(
+ dataset()->captured_window_size_func_->Instantiate(ctx));
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 58b79d6026..6bba667759 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -1,4 +1,3 @@
-
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -156,7 +155,9 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
args_list_(params.dataset->cycle_length_) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 61a6c06135..25beb02f0e 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -104,9 +104,8 @@ class IteratorResource : public ResourceBase {
bool* end_of_sequence) {
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
- if (lib_ != nullptr) {
- ctx->set_lib(lib_);
- }
+ CHECK_NOTNULL(lib_);
+ ctx->set_lib(lib_);
return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence);
} else {
return errors::FailedPrecondition(
@@ -162,8 +161,10 @@ class IteratorResource : public ResourceBase {
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
std::unique_ptr<IteratorBase> iterator;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(lib);
TF_RETURN_IF_ERROR(
- dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
+ dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
@@ -198,6 +199,8 @@ class IteratorResource : public ResourceBase {
return lib_def_;
}
+ FunctionLibraryRuntime* function_library_runtime() { return lib_; }
+
// Transfers ownership of iterator to this. This method is thread-safe.
Status set_iterator(std::unique_ptr<IteratorBase> iterator) {
if (iterator) {
@@ -612,8 +615,10 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) {
core::ScopedUnref unref(iterator_resource);
std::unique_ptr<IteratorBase> iterator;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(iterator_resource->function_library_runtime());
OP_REQUIRES_OK(
- ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
+ ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
}
@@ -837,8 +842,10 @@ class OneShotIteratorOp : public AsyncOpKernel {
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
std::unique_ptr<IteratorBase> iter;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(lib);
TF_RETURN_IF_ERROR(
- dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iter));
+ dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iter));
TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
(*iterator)->Ref();
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 0e17011b05..c4df7f2756 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -204,7 +204,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 294fb1c49a..26ae26a7fd 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -127,7 +127,9 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index b097598cd9..b2d307ba8a 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -142,8 +142,15 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->optimized_input_->MakeIterator(ctx, prefix(),
- &input_impl_);
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ params.lib = ctx->lib();
+ params.function_library = dataset()->flib_def_;
+ params.allocator_getter = ctx->allocator_getter();
+ return dataset()->optimized_input_->MakeIterator(
+ IteratorContext(params), prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index cfa96d910d..bf86361a71 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -251,7 +251,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
// It is implemented so that it matches the deterministic interleave
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index a407abfce4..e03a4e353b 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -88,6 +88,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
+ auto init_func = [this](IteratorContext* ctx) {
+ return captured_func_->Instantiate(ctx);
+ };
+
auto map_func = [this](IteratorContext* ctx,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
@@ -97,7 +101,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return NewParallelMapIterator(
{this, strings::StrCat(prefix, "::ParallelMap")}, input_,
- std::move(map_func), num_parallel_calls_);
+ std::move(init_func), std::move(map_func), num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 4d32b719a4..61f8139b9e 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -26,10 +26,12 @@ class ParallelMapIterator : public DatasetBaseIterator {
public:
explicit ParallelMapIterator(
const typename DatasetBaseIterator::BaseParams& params,
- const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
- int32 num_parallel_calls)
+ const DatasetBase* input_dataset,
+ std::function<Status(IteratorContext*)> init_func,
+ ParallelMapIteratorFunction map_func, int32 num_parallel_calls)
: DatasetBaseIterator(params),
input_dataset_(input_dataset),
+ init_func_(std::move(init_func)),
map_func_(std::move(map_func)),
num_parallel_calls_(num_parallel_calls) {}
@@ -50,7 +52,12 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status Initialize(IteratorContext* ctx) override {
- return input_dataset_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
+ if (init_func_) {
+ TF_RETURN_IF_ERROR(init_func_(ctx));
+ }
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
@@ -285,6 +292,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
const DatasetBase* const input_dataset_; // Not owned.
+ const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
const int32 num_parallel_calls_;
// Used for coordination between the main thread and the runner thread.
@@ -311,8 +319,18 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
int32 num_parallel_calls) {
- return std::unique_ptr<IteratorBase>(new ParallelMapIterator(
- params, input_dataset, std::move(map_func), num_parallel_calls));
+ return NewParallelMapIterator(params, input_dataset, nullptr,
+ std::move(map_func), num_parallel_calls);
+}
+
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+ const DatasetBaseIterator::BaseParams& params,
+ const DatasetBase* input_dataset,
+ std::function<Status(IteratorContext*)> init_func,
+ ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
+ return std::unique_ptr<IteratorBase>(
+ new ParallelMapIterator(params, input_dataset, std::move(init_func),
+ std::move(map_func), num_parallel_calls));
}
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
index 2ce36c3869..7e6cc586f3 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.h
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -33,7 +33,15 @@ using ParallelMapIteratorFunction =
std::vector<Tensor>*, StatusCallback)>;
// Returns a new iterator that applies `map_func` to the elements of
-// `input_dataset` using the given degree of parallelism.
+// `input_dataset` using the given degree of parallelism. `init_func` (if
+// specified) will be executed when the iterator is initialized (see
+// `IteratorBase::Initialize()`) and enables the user to specify error checking
+// logic that can fail early.
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+ const DatasetBaseIterator::BaseParams& params,
+ const DatasetBase* input_dataset,
+ std::function<Status(IteratorContext*)> init_func,
+ ParallelMapIteratorFunction map_func, int32 num_parallel_calls);
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index 5e9ace3486..299949b99f 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -172,32 +172,39 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
class ForeverIterator : public DatasetIterator<Dataset> {
public:
explicit ForeverIterator(const Params& params)
- : DatasetIterator<Dataset>(params), input_impl_(nullptr) {}
+ : DatasetIterator<Dataset>(params),
+ input_impl_(nullptr),
+ first_call_(true) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
do {
- bool first_call = false;
if (!input_impl_) {
- first_call = true;
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
- TF_RETURN_IF_ERROR(
- input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
- if (!*end_of_sequence) {
+ Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ if (first_call_ && *end_of_sequence) {
+ // If the first call to GetNext() fails because the end
+ // of sequence has been reached, we terminate the
+ // iteration immediately. (Otherwise, this iterator
+ // would loop infinitely and never produce a value.)
+ input_impl_.reset();
return Status::OK();
+ }
+ first_call_ = false;
+ if (!*end_of_sequence) {
+ return s;
} else {
input_impl_.reset();
- if (first_call) {
- // If the first call to GetNext() fails because the end
- // of sequence has been reached, we terminate the
- // iteration immediately. (Otherwise, this iterator
- // would loop infinitely and never produce a value.)
- return Status::OK();
- }
+ first_call_ = true;
}
} while (true);
}
@@ -205,7 +212,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- if (input_impl_)
+ if (!first_call_)
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
else
TF_RETURN_IF_ERROR(
@@ -218,10 +225,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
if (reader->Contains(full_name("uninitialized"))) {
input_impl_.reset();
+ first_call_ = true;
} else {
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ first_call_ = false;
}
return Status::OK();
}
@@ -229,6 +238,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ bool first_call_ GUARDED_BY(mu_);
};
const int64 count_;
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index e4cb31e2b2..5d3319b19f 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -153,7 +153,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
state_(params.dataset->initial_state_) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 637bde9ae4..52b4320bf1 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -24,6 +24,7 @@ import warnings
import numpy as np
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -31,6 +32,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
@@ -673,6 +675,36 @@ class MapDatasetTest(test.TestCase):
r"Dataset.map\(\): None."):
_ = dataset.map(lambda x: None)
+ def testBrokenFunctionErrorOnInitialization(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
+
+ def broken_function(_):
+ """A function deliberately designed to fail on instantiation."""
+ value = []
+ tensor_value = attr_value_pb2.AttrValue()
+ tensor_value.tensor.CopyFrom(
+ tensor_util.make_tensor_proto(
+ value, dtype=dtypes.float32, shape=[0], verify_shape=False))
+ dtype_value = attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum)
+
+ # Create a "Const" op with a `tf.float32` value and a `tf.int32` type
+ # attr.
+ const_tensor = ops.get_default_graph().create_op(
+ "Const", [], [dtypes.int32],
+ attrs={
+ "value": tensor_value,
+ "dtype": dtype_value
+ },
+ name="BrokenConst").outputs[0]
+ return const_tensor
+
+ dataset = dataset.map(broken_function)
+ iterator = dataset.make_initializable_iterator()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
+ sess.run(iterator.initializer)
+
class MapDatasetBenchmark(test.Benchmark):