aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/skip_dataset_op.cc
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-05-31 13:43:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-31 13:45:55 -0700
commit89a55fef3316e0e270e0f87f71bd8c2d32443cc8 (patch)
tree20825ff4e2a98e83b3781c86751bc5d385d4f9fe /tensorflow/core/kernels/data/skip_dataset_op.cc
parentb3adb58d84ebb91d893b647ab4081530460fb8ed (diff)
[tf.data] Changing signature of `MakeIterator` to enable propagating error status.
PiperOrigin-RevId: 198772254
Diffstat (limited to 'tensorflow/core/kernels/data/skip_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/skip_dataset_op.cc13
1 files changed, 6 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc
index d636c37afe..0177839707 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op.cc
@@ -47,14 +47,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
if (count_ < 0) {
return std::unique_ptr<IteratorBase>(
new EmptyIterator({this, strings::StrCat(prefix, "::EmptySkip")}));
- } else if (count_ == 0) {
- // Pass through.
- return input_->MakeIterator(prefix);
} else {
return std::unique_ptr<IteratorBase>(new FiniteIterator(
{this, strings::StrCat(prefix, "::FiniteSkip")}));
@@ -108,9 +105,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
class FiniteIterator : public DatasetIterator<Dataset> {
public:
explicit FiniteIterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- i_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,