aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
blob: bae56828dc56efe2e872f27333b0d778d0d45f1c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include <atomic>
#include <limits>
#include <utility>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.

// TODO(b/116852688): Make coordination between the performance model and this
// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
 public:
  // Per-element function run by the iterator. Arguments, in order: the
  // iterator context, a name prefix, the input element's tensors, the vector
  // that receives the element's outputs, and a callback invoked with the
  // call's status once those outputs are ready.
  using MapAndBatchIteratorFunction =
      std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
                         std::shared_ptr<std::vector<Tensor>>, StatusCallback)>;

  // Reads the `f`, `output_types` and `output_shapes` attrs. The op version is
  // derived from the op name: 1 for "MapAndBatchDataset", 2 for anything else.
  explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx),
        op_version_(ctx->def().op() != "MapAndBatchDataset" ? 2 : 1) {
    // Each OP_REQUIRES_OK bails out on the first attr that fails to parse, so
    // the attrs are read in this fixed order.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

 protected:
  // Parses and validates the op inputs, builds the per-element map function
  // and creates the fused map-and-batch dataset in `*output`.
  //
  // For the V2 op `num_parallel_calls` is read directly (kAutoTune is also
  // accepted). For the V1 op it is derived as
  // `num_parallel_batches * batch_size`.
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    int64 batch_size;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size));
    OP_REQUIRES(
        ctx, batch_size > 0,
        errors::InvalidArgument("batch_size must be greater than zero."));

    int64 num_parallel_calls;
    switch (op_version_) {
      case 1: {
        int64 num_parallel_batches;
        OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_batches",
                                                &num_parallel_batches));
        // Validate before multiplying: the factor must be positive, and the
        // product must fit in int64 (signed overflow is undefined behavior).
        OP_REQUIRES(ctx, num_parallel_batches > 0,
                    errors::InvalidArgument(
                        "num_parallel_batches must be greater than zero."));
        OP_REQUIRES(ctx,
                    num_parallel_batches <=
                        std::numeric_limits<int64>::max() / batch_size,
                    errors::InvalidArgument(
                        "num_parallel_batches * batch_size must not overflow "
                        "int64."));
        num_parallel_calls = num_parallel_batches * batch_size;
        break;
      }
      case 2:
        OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
                                                &num_parallel_calls));
        OP_REQUIRES(ctx,
                    num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
                    errors::InvalidArgument(
                        "num_parallel_calls must be greater than zero."));
        break;
      default:
        // The errors::* helpers concatenate their arguments; they do not
        // interpret printf-style format specifiers such as "%d".
        OP_REQUIRES(ctx, false,
                    errors::Unimplemented("Unsupported operation version ",
                                          op_version_, "."));
    }

    bool drop_remainder;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument(ctx, "drop_remainder", &drop_remainder));

    std::unique_ptr<CapturedFunction> captured_func;
    OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
                                                 &captured_func));

    std::vector<int> indices;
    OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));

    MapAndBatchIteratorFunction map_func;
    // The dataset takes ownership of `captured_func` below, so the lambdas
    // keep only a raw pointer to it.
    CapturedFunction* raw_captured_func = captured_func.get();
    if (indices.empty()) {
      // General case: actually run `f` asynchronously.
      map_func = [raw_captured_func](
                     IteratorContext* ctx, const string& prefix,
                     std::vector<Tensor> args,
                     std::shared_ptr<std::vector<Tensor>> out_tensors,
                     StatusCallback done) {
        raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(),
                                    std::move(done), prefix);
      };
    } else {
      // Short-circuit case: output i is taken directly from argument
      // `indices[i]` when it is below the number of arguments, and otherwise
      // from the captured inputs. Presumably this is only chosen when `f`
      // merely forwards its inputs (TODO confirm against
      // ComputeShortCircuitIndices).
      std::vector<bool> can_move = ComputeMoveVector(indices);
      map_func = [raw_captured_func, indices, can_move](
                     IteratorContext* ctx, const string& prefix,
                     std::vector<Tensor> args,
                     std::shared_ptr<std::vector<Tensor>> out_tensors,
                     StatusCallback done) {
        const std::vector<Tensor>& captured_inputs =
            raw_captured_func->captured_inputs();
        size_t num_args = args.size();
        for (size_t i = 0; i < indices.size(); ++i) {
          if (indices[i] < num_args) {
            if (can_move[i]) {
              out_tensors->push_back(std::move(args[indices[i]]));
            } else {
              out_tensors->push_back(args[indices[i]]);
            }
          } else {
            out_tensors->push_back(captured_inputs[indices[i] - num_args]);
          }
        }
        done(Status::OK());
      };
    }

    *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls,
                          drop_remainder, output_types_, output_shapes_,
                          std::move(captured_func), &ctx->eigen_cpu_device(),
                          std::move(map_func));
  }

 private:
  class Dataset : public DatasetBase {
   public:
    // Stores the validated op parameters. The dataset holds a reference on
    // `input` for its whole lifetime and takes ownership of `captured_func`.
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            const NameAttrList& func, int64 batch_size,
            int64 num_parallel_calls, bool drop_remainder,
            const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes,
            std::unique_ptr<CapturedFunction> captured_func,
            const Eigen::ThreadPoolDevice* device,
            MapAndBatchIteratorFunction map_func)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          func_(func),
          batch_size_(batch_size),
          num_parallel_calls_(num_parallel_calls),
          drop_remainder_(drop_remainder),
          output_types_(output_types),
          output_shapes_(output_shapes),
          captured_func_(std::move(captured_func)),
          device_(device),
          map_func_(std::move(map_func)) {
      // Keep the input dataset alive; released in the destructor.
      input_->Ref();
    }

    ~Dataset() override {
      // Balances the Ref() taken in the constructor.
      input_->Unref();
    }

    // Creates a new iterator. Each iterator receives its own copy of the
    // per-element map function.
    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      const string iterator_prefix = strings::StrCat(prefix, "::MapAndBatch");
      return MakeUnique<Iterator>(Iterator::Params{this, iterator_prefix},
                                  map_func_);
    }

    const DataTypeVector& output_dtypes() const override {
      const DataTypeVector& dtypes = output_types_;
      return dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      const std::vector<PartialTensorShape>& shapes = output_shapes_;
      return shapes;
    }

    string DebugString() const override {
      const char* const kDebugName = "MapAndBatchDatasetOp::Dataset";
      return kDebugName;
    }

   protected:
    // Serializes this dataset as a graph node. Positional inputs:
    // 0 = input dataset, 1 = captured inputs of `f` (list), 2 = batch_size,
    // 3 = num_parallel_calls, 4 = drop_remainder. Attrs: `f` and
    // `Targuments` (dtypes of the captured inputs).
    //
    // NOTE(review): AddDataset is handed `this`, so the node is presumably
    // emitted under this dataset's own op type. For the V1 op
    // ("MapAndBatchDataset") input 3 is parsed back as
    // `num_parallel_batches`, while `num_parallel_calls_` already holds
    // num_parallel_batches * batch_size. Confirm that a V1 round trip does
    // not inflate the parallelism.
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));

      Node* input_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
      Node* batch_size_node;
      TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
      Node* parallelism_node;
      TF_RETURN_IF_ERROR(b->AddScalar(num_parallel_calls_, &parallelism_node));
      Node* drop_remainder_node;
      TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));

      // Each captured tensor becomes its own node feeding the list input.
      const std::vector<Tensor>& captured = captured_func_->captured_inputs();
      std::vector<Node*> captured_nodes;
      DataTypeVector captured_types;
      captured_nodes.reserve(captured.size());
      captured_types.reserve(captured.size());
      for (const Tensor& tensor : captured) {
        Node* tensor_node;
        TF_RETURN_IF_ERROR(b->AddTensor(tensor, &tensor_node));
        captured_nodes.emplace_back(tensor_node);
        captured_types.emplace_back(tensor.dtype());
      }

      AttrValue f_attr;
      b->BuildAttrValue(func_, &f_attr);
      AttrValue targuments_attr;
      b->BuildAttrValue(captured_types, &targuments_attr);

      return b->AddDataset(
          this,
          {std::make_pair(0, input_node), std::make_pair(2, batch_size_node),
           std::make_pair(3, parallelism_node),
           std::make_pair(4, drop_remainder_node)},  // Single tensor inputs.
          {std::make_pair(1, captured_nodes)},        // Tensor list inputs.
          {std::make_pair("f", f_attr),
           std::make_pair("Targuments", targuments_attr)},  // Attrs
          output);
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      // `num_parallel_calls_` shares `mu_` and `cond_var_`, which must already
      // be constructed. That holds because they are declared before it.
      explicit Iterator(const Params& params,
                        MapAndBatchIteratorFunction map_func)
          : DatasetIterator<Dataset>(params),
            mu_(std::make_shared<mutex>()),
            cond_var_(std::make_shared<condition_variable>()),
            num_parallel_calls_(std::make_shared<model::SharedState>(
                params.dataset->num_parallel_calls_, mu_, cond_var_)),
            map_func_(std::move(map_func)) {}

      // Stops the runner thread and blocks until every outstanding call has
      // finished. Those calls' completion callbacks capture `this`.
      ~Iterator() override {
        mutex_lock l(*mu_);
        cancelled_ = true;
        // Wake the runner thread so it can observe `cancelled_`.
        cond_var_->notify_all();
        for (;;) {
          if (num_calls_ <= 0) break;
          cond_var_->wait(l);
        }
      }

      // Registers the model parameters, creates the input iterator and
      // instantiates the captured function.
      //
      // A kAutoTune parallelism starts at 1 and is registered as tunable within
      // [1, port::NumSchedulableCPUs()]. Any other value is registered as a
      // constant.
      Status Initialize(IteratorContext* ctx) override {
        mutex_lock l(*mu_);
        AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
        const bool autotune = num_parallel_calls_->value == kAutoTune;
        if (!autotune) {
          AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
        } else {
          num_parallel_calls_->value = 1;
          AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
                              port::NumSchedulableCPUs());
        }
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
        return dataset()->captured_func_->Instantiate(ctx);
      }

      // Returns the next batch. Starts the runner thread on first use, waits
      // until the oldest buffered batch has no outstanding calls, removes it
      // from the buffer and converts it via ProcessResult().
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        std::shared_ptr<BatchResult> batch;
        {
          mutex_lock l(*mu_);
          EnsureRunnerThreadStarted(ctx);
          for (;;) {
            const bool ready = !batch_results_.empty() &&
                               batch_results_.front()->num_calls <= 0;
            if (ready) break;
            // The blocking wait is bracketed by RecordStop()/RecordStart().
            RecordStop(ctx);
            cond_var_->wait(l);
            RecordStart(ctx);
          }
          batch = std::move(batch_results_.front());
          batch_results_.pop_front();
          // A buffer slot was freed; wake the runner thread.
          cond_var_->notify_all();
        }
        return ProcessResult(ctx, batch, out_tensors, end_of_sequence);
      }

     protected:
      // Checkpoints the input iterator, the call counter and every buffered
      // batch. Waits first for all in-flight calls to finish, so that no batch
      // is being written while it is saved.
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(*mu_);
        for (;;) {
          if (num_calls_ <= 0) break;
          cond_var_->wait(l);
        }
        CHECK_EQ(num_calls_, 0);
        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("call_counter"), call_counter_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("batch_results_size"),
                                               batch_results_.size()));
        for (size_t index = 0; index < batch_results_.size(); ++index) {
          TF_RETURN_IF_ERROR(WriteBatchResult(writer, index));
        }
        return Status::OK();
      }

      // Inverse of SaveInternal(): restores the input iterator, the call
      // counter and then each saved batch, appending it to `batch_results_`.
      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(*mu_);
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("call_counter"), &call_counter_));
        int64 num_saved_batches;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("batch_results_size"),
                                              &num_saved_batches));
        int index = 0;
        while (index < num_saved_batches) {
          TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, index));
          ++index;
        }
        return Status::OK();
      }

     private:
      // BatchResult encapsulates the output batch, as well as anciliary
      // metadata required to execute the fused map-and-batch operation.
      struct BatchResult {
        explicit BatchResult(int64 batch_size) {
          end_of_input = false;
          num_calls = batch_size;
          num_elements = 0;
          output_allocated = false;
          status = Status::OK();
          status_offset = -1;
        }

        // UpdateStatus updates the batch's aggregate Status.
        //
        // In order to ensure that exactly the first non-OK status is returned
        // (required to make the behavior is observably identical to a
        // sequential execution of map followed by batch), we must also keep
        // track of the offset into the batch that produced `s`.
        void UpdateStatus(const Status& s, int64 offset) {
          if (TF_PREDICT_FALSE(!s.ok())) {
            mutex_lock l(mu);
            if (status.ok() || offset < status_offset) {
              status = s;
              status_offset = offset;
            }
          }
        }

        mutex mu;
        bool end_of_input GUARDED_BY(mu);
        int64 num_elements GUARDED_BY(mu);
        std::vector<Tensor> output;
        bool output_allocated GUARDED_BY(mu);
        Status status GUARDED_BY(mu);
        int64 status_offset GUARDED_BY(mu);
        // Counts the number of outstanding calls for this batch.
        int64 num_calls;  // access guarded by owner's mutex
      };

      // Marks one call for `result` as finished. Decrements both the
      // iterator-wide and the per-batch outstanding-call counts, then wakes
      // every waiter on `cond_var_`.
      void CallCompleted(const std::shared_ptr<BatchResult>& result)
          LOCKS_EXCLUDED(*mu_) {
        mutex_lock l(*mu_);
        --num_calls_;
        --result->num_calls;
        cond_var_->notify_all();
      }

      // Fills slot `offset` of `result` with one mapped input element.
      //
      // Pulls the next element from the input iterator. If the input is
      // exhausted, or the batch already carries a non-OK status, the call is
      // completed immediately without running the map function. Otherwise
      // `map_func_` is invoked, and its `done` callback copies each output
      // component into row `offset` of the corresponding batch tensor.
      // CallCompleted() runs on every path.
      //
      // NOTE(review): input-iterator errors are folded in with
      // `Status::Update` (no offset), not `BatchResult::UpdateStatus`, so the
      // "earliest offset wins" ordering does not apply to them. Confirm this
      // is intended.
      void CallFunction(std::shared_ptr<IteratorContext> ctx,
                        const std::shared_ptr<BatchResult>& result,
                        int64 offset) LOCKS_EXCLUDED(*mu_) {
        // Get the next input element.
        std::vector<Tensor> input_element;
        bool end_of_input;
        Status status =
            input_impl_->GetNext(ctx.get(), &input_element, &end_of_input);
        bool return_early;
        {
          mutex_lock l(result->mu);
          // Once any call in this batch sees end of input, the flag sticks.
          result->end_of_input = result->end_of_input || end_of_input;
          result->status.Update(status);
          return_early = result->end_of_input || !result->status.ok();
        }
        if (return_early) {
          CallCompleted(result);
          return;
        }

        // Shared between `done` and the map function, which writes the
        // element's outputs into it.
        std::shared_ptr<std::vector<Tensor>> return_values =
            std::make_shared<std::vector<Tensor>>();
        auto done = [this, ctx, result, return_values, offset](Status status) {
          result->UpdateStatus(status, offset);
          if (status.ok()) {
            // The first successful call sizes the batch tensors from its own
            // output shapes.
            EnsureOutputAllocated(ctx, result, return_values);
            for (size_t i = 0; i < return_values->size(); ++i) {
              const Tensor& tensor = return_values->at(i);
              Tensor* batch = &(result->output)[i];
              // Each element must have as many values as one row of the batch.
              if (tensor.NumElements() !=
                  (batch->NumElements() / batch->dim_size(0))) {
                TensorShape batch_shape = batch->shape();
                batch_shape.RemoveDim(0);
                result->UpdateStatus(
                    errors::InvalidArgument(
                        "Cannot add tensor to the batch: number of elements "
                        "does "
                        "not match. Shapes are: [tensor]: ",
                        tensor.shape().DebugString(),
                        ", [batch]: ", batch_shape.DebugString()),
                    offset);
                break;
              }
              // TODO(mrry): Add a version of DoParallelConcat that allows us to
              // move `tensor` where possible, to speed up string tensor
              // batching.
              Status copy_status = ::tensorflow::functor::DoParallelConcat(
                  *dataset()->device_, tensor, offset, batch);
              if (!copy_status.ok()) {
                result->UpdateStatus(copy_status, offset);
                break;
              }
            }
            // Counted even if a component failed above; the error recorded via
            // UpdateStatus is what ProcessResult() reports in that case.
            {
              mutex_lock l(result->mu);
              result->num_elements++;
            }
          }
          CallCompleted(result);
        };

        // Apply the map function on `input_element`, storing the result in
        // `return_values`, and invoking `done` when finished.
        map_func_(ctx.get(), prefix(), std::move(input_element),
                  std::move(return_values), std::move(done));
      }

      // Copies the first `num_elements` slices along dimension 0 of `value`
      // into the same positions of `output`.
      //
      // The copy dispatches on `value`'s dtype, and `output` is accessed with
      // that same dtype. `output` presumably has at least `num_elements` rows
      // (TODO confirm at call sites). Returns InvalidArgument for dtypes not
      // covered by TF_CALL_DATASET_TYPES.
      Status CopyPartialBatch(Tensor* output, const Tensor& value,
                              int64 num_elements) {
        switch (value.dtype()) {
// The loop index is int64 to match `num_elements` and avoid a
// signed/unsigned comparison.
#define HANDLE_TYPE(type)                                         \
  case DataTypeToEnum<type>::value: {                             \
    auto output_t = output->flat_outer_dims<type>();              \
    auto value_t = value.flat_outer_dims<type>();                 \
    for (int64 i = 0; i < num_elements; ++i) {                    \
      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    }                                                             \
    return Status::OK();                                          \
  }
          TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
          default:
            return errors::InvalidArgument("Unsupported data type: ",
                                           DataTypeString(value.dtype()));
        }
        return Status::OK();
      }

      // Lazily starts the background runner thread on the first call. The
      // thread outlives this call, so it receives its own shared copy of `ctx`.
      void EnsureRunnerThreadStarted(IteratorContext* ctx)
          EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
        if (runner_thread_) return;
        std::shared_ptr<IteratorContext> runner_ctx =
            std::make_shared<IteratorContext>(*ctx);
        runner_thread_.reset(ctx->env()->StartThread(
            {}, "runner_thread",
            [this, runner_ctx]() { RunnerThread(runner_ctx); }));
      }

      void EnsureOutputAllocated(
          const std::shared_ptr<IteratorContext>& ctx,
          const std::shared_ptr<BatchResult>& result,
          const std::shared_ptr<std::vector<Tensor>>& return_values) {
        mutex_lock l(result->mu);
        if (result->output_allocated) {
          return;
        }
        const size_t num_components = return_values->size();
        for (size_t i = 0; i < num_components; ++i) {
          TensorShape component_shape({dataset()->batch_size_});
          component_shape.AppendShape(return_values->at(i).shape());
          AllocatorAttributes attr;
          attr.set_gpu_compatible(true);
          Tensor component(ctx->allocator(attr), return_values->at(i).dtype(),
                           component_shape);
          result->output.emplace_back(std::move(component));
        }
        result->output_allocated = true;
      }

      // Converts a completed BatchResult into this iterator's output.
      //
      // If any call failed with a status other than OutOfRange, that status
      // is returned. `f` may raise OutOfRange deliberately to stop iteration
      // early. Otherwise:
      // - An empty batch signals end of sequence.
      // - A partial batch is dropped when `drop_remainder_` is set, which also
      //   signals end of sequence. If it is not dropped, it is copied into
      //   tensors whose leading dimension is `num_elements`.
      // - A full batch is moved out as-is.
      Status ProcessResult(IteratorContext* ctx,
                           const std::shared_ptr<BatchResult>& result,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) {
        mutex_lock l(result->mu);
        // Check the status before `num_elements`. A batch in which no call
        // succeeded (for example, the input iterator failed, or `f` failed for
        // every element) has zero elements, and its error must not be
        // mistaken for the end of the input.
        if (!result->status.ok() && !errors::IsOutOfRange(result->status)) {
          // Deallocate tensors allocated for the output.
          result->output.clear();
          *end_of_sequence = false;
          return result->status;
        }
        if (result->num_elements == 0) {
          *end_of_sequence = true;
          return Status::OK();
        }
        if (result->num_elements < dataset()->batch_size_) {
          if (dataset()->drop_remainder_) {
            // Deallocate tensors allocated for the output.
            result->output.clear();
            *end_of_sequence = true;
            return Status::OK();
          }
          AllocatorAttributes attr;
          attr.set_gpu_compatible(true);
          for (const Tensor& full_component : result->output) {
            TensorShape component_shape(full_component.shape());
            component_shape.set_dim(0, result->num_elements);
            Tensor component(ctx->allocator(attr), full_component.dtype(),
                             component_shape);
            TF_RETURN_IF_ERROR(CopyPartialBatch(&component, full_component,
                                                result->num_elements));
            out_tensors->emplace_back(std::move(component));
          }
          // Deallocate tensors allocated for the output.
          result->output.clear();
        } else {
          *out_tensors = std::move(result->output);
        }
        // At least one element was produced here.
        *end_of_sequence = false;
        return Status::OK();
      }

      // Body of the background thread started by EnsureRunnerThreadStarted().
      //
      // Loops until `cancelled_` is set. Under `*mu_` it waits until `busy()`
      // is false, then reserves as many (batch, offset) slots as capacity
      // allows. It then releases the lock and issues one CallFunction() per
      // reserved slot.
      void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
          LOCKS_EXCLUDED(*mu_) {
        // Calls reserved under the lock in one pass and dispatched after it is
        // released. Cleared and reused on every iteration of the outer loop.
        std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
        RecordStart(ctx.get());
        // Pairs the RecordStart() above with a RecordStop() on every exit path.
        auto stop_cleanup =
            gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
        new_calls.reserve(num_parallel_calls_->value);
        // True when no further call may be scheduled right now. That is the
        // case when `num_parallel_calls` calls are already in flight, when the
        // buffer holds more than `max_batch_results` batches, or when it holds
        // exactly that many and the newest one already has all of its offsets
        // assigned (`call_counter_` sits on a batch boundary).
        auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
          int64 num_parallel_calls = num_parallel_calls_->value;
          // Number of batches needed to hold `num_parallel_calls` calls
          // (ceiling division).
          int64 max_batch_results =
              (num_parallel_calls + dataset()->batch_size_ - 1) /
              dataset()->batch_size_;
          return num_calls_ >= num_parallel_calls ||
                 (batch_results_.size() > max_batch_results ||
                  (batch_results_.size() == max_batch_results &&
                   call_counter_ % dataset()->batch_size_ == 0));
        };
        while (true) {
          {
            mutex_lock l(*mu_);
            // The blocking wait is bracketed by RecordStop()/RecordStart().
            while (!cancelled_ && busy()) {
              RecordStop(ctx.get());
              cond_var_->wait(l);
              RecordStart(ctx.get());
            }

            if (cancelled_) {
              return;
            }

            // Reserve slots until capacity runs out. A new BatchResult is
            // appended whenever `call_counter_` starts a new batch. Each call's
            // position inside its batch is `call_counter_ % batch_size_`.
            while (!busy()) {
              if (call_counter_ % dataset()->batch_size_ == 0) {
                batch_results_.push_back(
                    std::make_shared<BatchResult>(dataset()->batch_size_));
              }
              int64 offset = call_counter_++ % dataset()->batch_size_;
              new_calls.emplace_back(batch_results_.back(), offset);
              num_calls_++;
            }
          }

          // Issue the reserved calls without holding `*mu_`. CallFunction()
          // reads from the input iterator before starting the map function.
          for (const auto& call : new_calls) {
            CallFunction(ctx, call.first, call.second);
          }
          new_calls.clear();
        }
      }

      // Restores the `index`-th buffered batch from a checkpoint and appends it
      // to `batch_results_`. A tensor saved as a leading slice (partial batch)
      // is copied into a freshly allocated tensor whose leading dimension is
      // `batch_size_`, matching the in-memory layout of BatchResult::output.
      Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
                             size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
        batch_results_.push_back(
            std::make_shared<BatchResult>(dataset()->batch_size_));
        std::shared_ptr<BatchResult> result = batch_results_.back();
        const string prefix = strings::StrCat("batch_results_", index);
        // Builds the fully-qualified checkpoint key for a per-batch field.
        auto key = [this, &prefix](const char* suffix) {
          return full_name(strings::StrCat(prefix, suffix));
        };

        mutex_lock l(result->mu);
        result->end_of_input = reader->Contains(key("_end_of_input"));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(key("_num_calls"), &result->num_calls));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(key("_num_elements"), &result->num_elements));
        result->output_allocated = reader->Contains(key("_output_allocated"));
        int64 output_size;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(key("_output_size"), &output_size));
        result->output.reserve(output_size);
        for (int i = 0; i < output_size; i++) {
          Tensor saved;
          TF_RETURN_IF_ERROR(reader->ReadTensor(
              full_name(strings::StrCat(prefix, "_output_", i)), &saved));
          if (saved.dim_size(0) >= dataset()->batch_size_) {
            result->output.emplace_back(std::move(saved));
            continue;
          }
          // Only the filled rows were saved. Widen back to `batch_size_` rows.
          TensorShape padded_shape(saved.shape());
          padded_shape.set_dim(0, dataset()->batch_size_);
          AllocatorAttributes attr;
          attr.set_gpu_compatible(true);
          Tensor padded(ctx->allocator(attr), saved.dtype(), padded_shape);
          TF_RETURN_IF_ERROR(
              CopyPartialBatch(&padded, saved, saved.dim_size(0)));
          result->output.emplace_back(std::move(padded));
        }
        return ReadStatus(reader, strings::StrCat(prefix, "_status"),
                          &result->status);
      }

      // Reads a Status written by WriteStatus() under `prefix` into `*status`.
      // The message key is read only when the stored code is not OK.
      Status ReadStatus(IteratorStateReader* reader, const string& prefix,
                        Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
        int64 code_int;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(prefix, "_code")), &code_int));
        const error::Code code = static_cast<error::Code>(code_int);

        if (code == error::Code::OK) {
          *status = Status::OK();
          return Status::OK();
        }
        string error_message;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(prefix, "_msg")), &error_message));
        *status = Status(code, error_message);
        return Status::OK();
      }

      // Writes the `index`-th buffered batch under the key prefix
      // "batch_results_<index>". Boolean fields are stored as key presence.
      // For a partial batch only the first `num_elements` rows of each tensor
      // are written. The remaining rows are uninitialized, and reading them
      // would raise msan errors.
      Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
          EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
        std::shared_ptr<BatchResult> result = batch_results_[index];
        const string prefix = strings::StrCat("batch_results_", index);
        // Builds the fully-qualified checkpoint key for a per-batch field.
        auto key = [this, &prefix](const char* suffix) {
          return full_name(strings::StrCat(prefix, suffix));
        };

        mutex_lock l(result->mu);
        if (result->end_of_input) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(key("_end_of_input"), ""));
        }
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(key("_num_calls"), result->num_calls));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(key("_num_elements"), result->num_elements));
        if (result->output_allocated) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(key("_output_allocated"), ""));
        }
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(key("_output_size"), result->output.size()));
        const bool partial = result->num_elements < dataset()->batch_size_;
        for (int i = 0; i < result->output.size(); i++) {
          const string tensor_key =
              full_name(strings::StrCat(prefix, "_output_", i));
          if (!partial) {
            TF_RETURN_IF_ERROR(
                writer->WriteTensor(tensor_key, result->output[i]));
            continue;
          }
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              tensor_key, result->output[i].Slice(0, result->num_elements)));
        }
        return WriteStatus(writer, strings::StrCat(prefix, "_status"),
                           result->status);
      }

      // Writes `status` under `prefix`. The error code is always written as an
      // int64. The message is written only for non-OK statuses.
      Status WriteStatus(IteratorStateWriter* writer, const string& prefix,
                         const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
        const string code_key = full_name(strings::StrCat(prefix, "_code"));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(code_key, static_cast<int64>(status.code())));
        if (status.ok()) {
          return Status::OK();
        }
        const string msg_key = full_name(strings::StrCat(prefix, "_msg"));
        return writer->WriteScalar(msg_key, status.error_message());
      }

      // Used for coordination between the main thread, the runner thread, and
      // the callback threads. Also shared with `num_parallel_calls_` (see the
      // constructor).
      const std::shared_ptr<mutex> mu_;
      // Used for coordination between the main thread, the runner thread, and
      // the callback threads. In particular, the runner thread should only
      // schedule new calls when the number of in-flight calls is less than
      // `num_parallel_calls_->value` and there are slots available in the
      // `batch_results_` buffer.
      const std::shared_ptr<condition_variable> cond_var_;
      // Identifies the maximum number of parallel calls. If the dataset was
      // created with kAutoTune, Initialize() sets the value to 1 and registers
      // it as a tunable parameter.
      const std::shared_ptr<model::SharedState> num_parallel_calls_;
      // Per-element function invoked by CallFunction().
      const MapAndBatchIteratorFunction map_func_;

      // Number of calls scheduled by the runner thread that have not yet
      // completed, summed across all batches. Incremented when a call is
      // scheduled and decremented in CallCompleted().
      int64 num_calls_ GUARDED_BY(*mu_) = 0;
      // Counts the total number of calls scheduled so far. The value modulo
      // `batch_size_` is the offset of the next call within its batch.
      int64 call_counter_ GUARDED_BY(*mu_) = 0;
      // Iterator over the input dataset, created in Initialize().
      std::unique_ptr<IteratorBase> input_impl_;
      // Buffer for storing the (intermediate) batch results, oldest first.
      std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
      // Started lazily by EnsureRunnerThreadStarted().
      std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
      // Set by the destructor to make the runner thread exit.
      bool cancelled_ GUARDED_BY(*mu_) = false;
    };

    // Input dataset; a reference is held for the lifetime of this dataset.
    const DatasetBase* const input_;
    // The map function `f`, as given by the op attr.
    const NameAttrList func_;
    const int64 batch_size_;
    // Maximum number of concurrent map calls. It may be kAutoTune for the V2
    // op. For the V1 op it is num_parallel_batches * batch_size.
    const int64 num_parallel_calls_;
    // If true, a final partial batch is discarded instead of emitted.
    const bool drop_remainder_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
    // Owns the function referenced (via a raw pointer) by `map_func_`.
    const std::unique_ptr<CapturedFunction> captured_func_;
    // Device used to copy element outputs into batch tensors.
    const Eigen::ThreadPoolDevice* device_;  // not owned
    // Per-element function; each iterator receives a copy.
    const MapAndBatchIteratorFunction map_func_;
  };

  // 1 for the "MapAndBatchDataset" op, 2 otherwise (MapAndBatchDatasetV2).
  const int op_version_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  // The map function attr `f`.
  NameAttrList func_;
};

// Both op versions share one CPU kernel. The constructor tells them apart by
// op name and stores the result in `op_version_`.
REGISTER_KERNEL_BUILDER(Name("MapAndBatchDataset").Device(DEVICE_CPU),
                        MapAndBatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU),
                        MapAndBatchDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow