Diffstat (limited to 'tensorflow/core/kernels/data/map_and_batch_dataset_op.cc')
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 180
1 file changed, 72 insertions(+), 108 deletions(-)
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index f9aaa3080e..bf08970560 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -30,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -43,10 +41,6 @@ namespace {
// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
- using MapAndBatchIteratorFunction =
- std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
- std::shared_ptr<std::vector<Tensor>>, StatusCallback)>;
-
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
@@ -97,66 +91,31 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- std::vector<int> indices;
- OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
-
- MapAndBatchIteratorFunction map_func;
- if (indices.empty()) {
- CapturedFunction* raw_captured_func = captured_func.get();
- map_func = [raw_captured_func](
- IteratorContext* ctx, const string& prefix,
- std::vector<Tensor> args,
- std::shared_ptr<std::vector<Tensor>> out_tensors,
- StatusCallback done) {
- raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(),
- std::move(done), prefix);
- };
- } else {
- std::vector<bool> can_move = ComputeMoveVector(indices);
- map_func = [indices, can_move](
- IteratorContext* ctx, const string& prefix,
- std::vector<Tensor> args,
- std::shared_ptr<std::vector<Tensor>> out_tensors,
- StatusCallback done) {
- for (size_t i = 0; i < indices.size(); ++i) {
- if (can_move[i]) {
- out_tensors->push_back(std::move(args[indices[i]]));
- } else {
- out_tensors->push_back(args[indices[i]]);
- }
- }
- done(Status::OK());
- };
- }
-
- *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls,
- drop_remainder, output_types_, output_shapes_,
- std::move(captured_func), &ctx->eigen_cpu_device(),
- std::move(map_func));
+ *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
+ drop_remainder, output_types_, output_shapes_, func_,
+ std::move(captured_func), &ctx->eigen_cpu_device());
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func, int64 batch_size,
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
int64 num_parallel_calls, bool drop_remainder,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
+ const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
- const Eigen::ThreadPoolDevice* device,
- MapAndBatchIteratorFunction map_func)
+ const Eigen::ThreadPoolDevice* device)
: DatasetBase(DatasetContext(ctx)),
input_(input),
- func_(func),
batch_size_(batch_size),
num_parallel_calls_(num_parallel_calls),
drop_remainder_(drop_remainder),
output_types_(output_types),
output_shapes_(output_shapes),
+ map_fn_(func),
captured_func_(std::move(captured_func)),
- device_(device),
- map_func_(std::move(map_func)) {
+ device_(device) {
input_->Ref();
}
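
For context, the deleted block above implemented a short-circuit fast path: when ComputeShortCircuitIndices reported that the map function merely selects from its inputs, the iterator skipped the function dispatch entirely and copied (or, when ComputeMoveVector allowed it, moved) the chosen input tensors straight into the output. A minimal standalone sketch of that idea, with a templated Tensor stand-in for the real type:

    // Sketch of the removed short-circuit path (simplified; the real code
    // was driven by ComputeShortCircuitIndices/ComputeMoveVector from
    // dataset_utils.h, whose include is dropped above).
    #include <cstddef>
    #include <utility>
    #include <vector>

    template <typename Tensor>
    void ShortCircuitMap(std::vector<Tensor> args,
                         const std::vector<int>& indices,    // output i = args[indices[i]]
                         const std::vector<bool>& can_move,  // true if this is that arg's last use
                         std::vector<Tensor>* out_tensors) {
      for (size_t i = 0; i < indices.size(); ++i) {
        if (can_move[i]) {
          out_tensors->push_back(std::move(args[indices[i]]));  // last use: steal the buffer
        } else {
          out_tensors->push_back(args[indices[i]]);  // still aliased elsewhere: copy
        }
      }
    }

With the fast path gone, every element goes through CapturedFunction::RunAsync, which is what the simplified Dataset constructor call above reflects.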
@@ -164,9 +123,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return MakeUnique<Iterator>(
- Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")},
- map_func_);
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")}));
}
const DataTypeVector& output_dtypes() const override {
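
MakeUnique was TensorFlow's pre-C++14 stand-in for std::make_unique, supplied by the ptr_util.h header this commit stops including; spelling out the std::unique_ptr constructor is behaviorally equivalent. A small sketch of the two spellings (hypothetical Widget type, not TensorFlow code):

    #include <memory>

    struct Widget { explicit Widget(int size) {} };

    std::unique_ptr<Widget> a(new Widget(42));  // explicit form used after this commit
    auto b = std::make_unique<Widget>(42);      // C++14 equivalent of MakeUnique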
@@ -185,7 +143,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
@@ -207,7 +165,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
- b->BuildAttrValue(func_, &f);
+ b->BuildAttrValue(map_fn_, &f);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
@@ -227,14 +185,12 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params,
- MapAndBatchIteratorFunction map_func)
+ explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
- params.dataset->num_parallel_calls_, mu_, cond_var_)),
- map_func_(std::move(map_func)) {}
+ params.dataset->num_parallel_calls_, mu_, cond_var_)) {}
~Iterator() override {
mutex_lock l(*mu_);
@@ -341,6 +297,44 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
int64 num_calls; // access guarded by owner's mutex
};
+ void Callback(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<BatchResult>& result,
+ const std::shared_ptr<std::vector<Tensor>>& return_values,
+ int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) {
+ result->UpdateStatus(status);
+ if (status.ok()) {
+ EnsureOutputAllocated(ctx, result, return_values);
+ for (size_t i = 0; i < return_values->size(); ++i) {
+ const Tensor& tensor = return_values->at(i);
+ Tensor* batch = &(result->output)[i];
+ if (tensor.NumElements() !=
+ (batch->NumElements() / batch->dim_size(0))) {
+ TensorShape batch_shape = batch->shape();
+ batch_shape.RemoveDim(0);
+ result->UpdateStatus(errors::InvalidArgument(
+ "Cannot add tensor to the batch: number of elements does not "
+ "match. Shapes are: [tensor]: ",
+ tensor.shape().DebugString(),
+ ", [batch]: ", batch_shape.DebugString()));
+ break;
+ }
+ // TODO(mrry): Add a version of DoParallelConcat that allows us to
+ // move `tensor` where possible, to speed up string tensor batching.
+ Status copy_status = ::tensorflow::functor::DoParallelConcat(
+ *dataset()->device_, tensor, offset, batch);
+ if (!copy_status.ok()) {
+ result->UpdateStatus(copy_status);
+ break;
+ }
+ }
+ {
+ mutex_lock l(result->mu);
+ result->num_elements++;
+ }
+ }
+ CallCompleted(result);
+ }
+
void CallCompleted(const std::shared_ptr<BatchResult>& result)
LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
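
The shape check in the new Callback compares element counts rather than full shapes: batch->NumElements() / batch->dim_size(0) is the size of one batch row, and each return value must match it exactly before DoParallelConcat writes the tensor into row `offset` of the preallocated batch. A standalone sketch of that invariant (hypothetical helper, not the TensorFlow API):

    #include <cstdint>

    // A component tensor fits one row of a batch of shape
    // [batch_size, d1, ..., dn] iff it has exactly d1 * ... * dn elements.
    bool FitsBatchRow(int64_t tensor_num_elements,
                      int64_t batch_num_elements, int64_t batch_size) {
      const int64_t row_elements = batch_num_elements / batch_size;
      return tensor_num_elements == row_elements;
    }

Note that two different shapes with the same element count would pass this check; the guard keeps DoParallelConcat from writing out of bounds rather than enforcing exact shape equality.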
@@ -369,48 +363,21 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return;
}
- std::shared_ptr<std::vector<Tensor>> return_values =
- std::make_shared<std::vector<Tensor>>();
- auto done = [this, ctx, result, return_values, offset](Status status) {
- result->UpdateStatus(status);
- if (status.ok()) {
- EnsureOutputAllocated(ctx, result, return_values);
- for (size_t i = 0; i < return_values->size(); ++i) {
- const Tensor& tensor = return_values->at(i);
- Tensor* batch = &(result->output)[i];
- if (tensor.NumElements() !=
- (batch->NumElements() / batch->dim_size(0))) {
- TensorShape batch_shape = batch->shape();
- batch_shape.RemoveDim(0);
- result->UpdateStatus(errors::InvalidArgument(
- "Cannot add tensor to the batch: number of elements does "
- "not match. Shapes are: [tensor]: ",
- tensor.shape().DebugString(),
- ", [batch]: ", batch_shape.DebugString()));
- break;
- }
- // TODO(mrry): Add a version of DoParallelConcat that allows us to
- // move `tensor` where possible, to speed up string tensor
- // batching.
- Status copy_status = ::tensorflow::functor::DoParallelConcat(
- *dataset()->device_, tensor, offset, batch);
- if (!copy_status.ok()) {
- result->UpdateStatus(copy_status);
- break;
- }
- }
- {
- mutex_lock l(result->mu);
- result->num_elements++;
- }
- }
- CallCompleted(result);
- };
-
- // Apply the map function on `input_element`, storing the result in
- // `return_values`, and invoking `done` when finished.
- map_func_(ctx.get(), prefix(), std::move(input_element),
- std::move(return_values), std::move(done));
+ // Call `captured_func_(input_element)`, using `Callback` to store the
+ // result in `result`.
+ (*ctx->runner())(std::bind(
+ [this, result, offset](std::shared_ptr<IteratorContext> ctx,
+ std::vector<Tensor> input_element) {
+ std::shared_ptr<std::vector<Tensor>> return_values(
+ new std::vector<Tensor>());
+ dataset()->captured_func_->RunAsync(
+ ctx.get(), std::move(input_element), return_values.get(),
+ [this, ctx, result, return_values, offset](Status status) {
+ Callback(ctx, result, return_values, offset, status);
+ },
+ prefix());
+ },
+ ctx, std::move(input_element)));
}
Status CopyPartialBatch(Tensor* output, const Tensor& value,
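
The replacement makes the scheduling explicit: instead of an injected map_func_, CallFunction bounces the work onto the IteratorContext's runner (a std::function that executes closures, typically on a thread pool), and RunAsync's completion callback forwards everything to the new Callback method. The shared_ptr captures keep `ctx`, `result`, and `return_values` alive until the asynchronous call finishes. A self-contained sketch of that dispatch pattern, with hypothetical names in place of the TensorFlow types:

    #include <functional>
    #include <memory>

    using Runner = std::function<void(std::function<void()>)>;
    using StatusCallback = std::function<void(int /*status*/)>;

    // Schedule an async call on `runner`; the shared_ptr capture keeps
    // `state` alive until `on_done` fires, mirroring how the iterator holds
    // `ctx`, `result`, and `return_values` across the RunAsync boundary.
    void Dispatch(const Runner& runner, std::shared_ptr<int> state,
                  StatusCallback on_done) {
      runner([state, on_done]() {
        // ... run the user-defined function here ...
        on_done(0);  // completion callback, analogous to Callback(...)
      });
    }

    int main() {
      Runner inline_runner = [](std::function<void()> fn) { fn(); };  // synchronous
      Dispatch(inline_runner, std::make_shared<int>(0), [](int) {});
    }

An inline runner, as in the sketch's main(), makes the dispatch synchronous, which is one way to exercise the same code path deterministically.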
@@ -437,7 +404,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+ std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
std::bind(&Iterator::RunnerThread, this, ctx_copy)));
@@ -542,8 +509,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
- batch_results_.push_back(
- std::make_shared<BatchResult>(dataset()->batch_size_));
+ batch_results_.emplace_back(
+ new BatchResult(dataset()->batch_size_));
}
int64 offset = call_counter_++ % dataset()->batch_size_;
new_calls.emplace_back(batch_results_.back(), offset);
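
Several hunks in this commit (EnsureRunnerThreadStarted above, this one, and ReadBatchResult below) trade std::make_shared for a shared_ptr constructed from a raw `new`. The two are not quite equivalent: make_shared allocates the object and its control block together, while the raw-pointer form allocates twice and can leak if an exception fires between the `new` and the point where the shared_ptr takes ownership. A minimal illustration (sketch, with a simplified BatchResult):

    #include <memory>
    #include <vector>

    struct BatchResult { explicit BatchResult(long batch_size) {} };

    int main() {
      std::vector<std::shared_ptr<BatchResult>> batch_results;
      // One allocation: object and control block together.
      batch_results.push_back(std::make_shared<BatchResult>(64));
      // Two allocations; if emplace_back throws (e.g. on reallocation)
      // before the shared_ptr takes ownership, the raw pointer leaks.
      batch_results.emplace_back(new BatchResult(64));
    }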
@@ -560,8 +527,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
- batch_results_.push_back(
- std::make_shared<BatchResult>(dataset()->batch_size_));
+ batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -687,8 +653,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
- const MapAndBatchIteratorFunction map_func_;
-
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(*mu_) = 0;
// Counts the total number of calls.
@@ -707,9 +671,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const bool drop_remainder_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const NameAttrList map_fn_;
const std::unique_ptr<CapturedFunction> captured_func_;
const Eigen::ThreadPoolDevice* device_; // not owned
- const MapAndBatchIteratorFunction map_func_;
};
const int op_version_;