aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-10-05 14:44:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 14:50:10 -0700
commitc221f04b7efff5929f3a6d090983b52f3aa16166 (patch)
treefeae7e4662a1712b6afbf84c75e5a9c5a2bbca99 /tensorflow/core
parent6123677f264c615042a816e713f7f1204685e544 (diff)
Automated rollback of commit ae0bc6f006497cc04a2ee75166d4ec71c7154fd8
PiperOrigin-RevId: 215969360
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/kernels/data/BUILD14
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc47
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h20
-rw-r--r--tensorflow/core/kernels/data/dataset_utils_test.cc46
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc162
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc180
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc56
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc73
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc17
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.h2
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc2
11 files changed, 214 insertions, 405 deletions
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 37c1c54786..451f8c1a6c 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -45,16 +45,6 @@ cc_library(
],
)
-tf_cc_test(
- name = "dataset_utils_test",
- srcs = ["dataset_utils_test.cc"],
- deps = [
- ":dataset_utils",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
cc_library(
name = "captured_function",
srcs = ["captured_function.cc"],
@@ -215,7 +205,6 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
- ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -243,7 +232,6 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
- ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -257,7 +245,6 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
- ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -298,7 +285,6 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
- ":dataset_utils",
":parallel_map_iterator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index a40f7f2146..e10833f525 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -15,57 +15,10 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace data {
-Status ComputeShortCircuitIndices(OpKernelContext* ctx,
- const NameAttrList& func,
- std::vector<int>* indices) {
- FunctionLibraryRuntime::Handle fn_handle;
- TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
- func.name(), AttrSlice(&func.attr()), &fn_handle));
- auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
- Status s = ctx->function_library()->ReleaseHandle(fn_handle);
- if (!s.ok()) {
- LOG(WARNING) << "Failed to release handle: " << s.error_message();
- }
- });
-
- const FunctionBody* fn_body =
- ctx->function_library()->GetFunctionBody(fn_handle);
- indices->resize(fn_body->ret_nodes.size());
- for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
- Node* ret_node = fn_body->ret_nodes[i];
- Node* ret_input_node;
- TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));
- if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
- TF_RETURN_IF_ERROR(
- GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i])));
- } else {
- indices->clear();
- break;
- }
- }
- return Status::OK();
-}
-
-std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) {
- std::map<int, int> last_use;
- for (size_t i = 0; i < indices.size(); ++i) {
- last_use[indices[i]] = i;
- }
- std::vector<bool> can_move;
- can_move.resize(indices.size());
- for (size_t i = 0; i < indices.size(); ++i) {
- can_move[i] = last_use[indices[i]] == i;
- }
- return can_move;
-}
-
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index d777062293..6ec1350cd4 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -22,26 +22,6 @@ limitations under the License.
namespace tensorflow {
namespace data {
-// This method is used to determine whether we can short-circuit the evaluation
-// of the user-defined function `func`. Short-circuting is possible if every
-// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) =
-// (y,x)`, or `f(x) = (x,x)`).
-//
-// If short-circuiting is possible, the method stores the mapping from output
-// indices to input indices in `indices`. Otherwise, `indices` will be empty.
-//
-// Returns non-ok status if analysis of the function fails.
-//
-// TODO(jsimsa): Extend this to support constants as well.
-Status ComputeShortCircuitIndices(OpKernelContext* ctx,
- const NameAttrList& func,
- std::vector<int>* indices);
-
-// Given a vector that maps output indices to input indices, return a vector
-// that identifies for which output indices can we move the input (assuming
-// output indices are processed left to right).
-std::vector<bool> ComputeMoveVector(const std::vector<int>& indices);
-
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc
deleted file mode 100644
index 43295b8ebb..0000000000
--- a/tensorflow/core/kernels/data/dataset_utils_test.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/kernels/data/dataset_utils.h"
-
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-TEST(DatasetUtils, ComputeMoveVector) {
- struct TestCase {
- std::vector<int> indices;
- std::vector<bool> expected;
- };
-
- TestCase test_cases[] = {
- TestCase{{}, {}},
- TestCase{{1}, {true}},
- TestCase{{1, 1}, {false, true}},
- TestCase{{1, 2}, {true, true}},
- TestCase{{1, 1, 2}, {false, true, true}},
- TestCase{{1, 2, 2}, {true, false, true}},
- };
-
- for (auto& test_case : test_cases) {
- EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices));
- }
-}
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index be7d182a1f..00884314a9 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -18,11 +18,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -33,84 +31,67 @@ namespace {
class FilterDatasetOp : public UnaryDatasetOpKernel {
public:
- using FilterIteratorPredicate =
- std::function<Status(IteratorContext*, std::vector<Tensor>, bool*)>;
-
explicit FilterDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
+ FunctionLibraryRuntime::Handle pred_handle;
+ OP_REQUIRES_OK(ctx,
+ ctx->function_library()->Instantiate(
+ func_.name(), AttrSlice(&func_.attr()), &pred_handle));
+ auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() {
+ OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle));
+ });
+
+ const FunctionBody* pred_body =
+ ctx->function_library()->GetFunctionBody(pred_handle);
+ OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1,
+ errors::InvalidArgument(
+ "predicate function must have a single return value."));
+ Node* ret_node = pred_body->ret_nodes[0];
+ Node* ret_input_node;
+ OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
+
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- std::vector<int> indices;
- OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
- OP_REQUIRES(ctx, indices.size() <= 1,
- errors::InvalidArgument(
- "predicate function has more than one return value."));
-
- FilterIteratorPredicate filter_pred;
- if (indices.empty()) {
- CapturedFunction* raw_captured_func = captured_func.get();
- filter_pred = [raw_captured_func](IteratorContext* ctx,
- const std::vector<Tensor>& args,
- bool* out_matched) {
- std::vector<Tensor> result;
- TF_RETURN_IF_ERROR(
- raw_captured_func->RunWithBorrowedArgs(ctx, args, &result));
-
- if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
- result[0].NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = result[0].scalar<bool>()();
- return Status::OK();
- };
+ if (ret_input_node->def().op() == "_Arg") {
+ int32 index = -1;
+ OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index));
+ *output = new FilterTensorDataset(ctx, input, func_,
+ std::move(captured_func), index);
} else {
- filter_pred = [indices](IteratorContext* ctx,
- const std::vector<Tensor>& args,
- bool* out_matched) {
- const Tensor& predicate = args[indices[0]];
- if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = predicate.scalar<bool>()();
- return Status::OK();
- };
+ *output = new FilterFunctionDataset(ctx, input, func_,
+ std::move(captured_func));
}
-
- *output = new Dataset(ctx, input, func_, std::move(captured_func),
- std::move(filter_pred));
}
private:
- class Dataset : public DatasetBase {
+ const int graph_def_version_;
+
+ class FilterDatasetBase : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func,
- FilterIteratorPredicate filter_pred)
+ FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
- captured_func_(std::move(captured_func)),
- filter_pred_(std::move(filter_pred)) {
+ captured_func_(std::move(captured_func)) {
input_->Ref();
}
- ~Dataset() override { input_->Unref(); }
+ ~FilterDatasetBase() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return MakeUnique<Iterator>(
- Iterator::Params{this, strings::StrCat(prefix, "::Filter")},
- filter_pred_);
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Filter")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -152,15 +133,17 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
+ virtual Status EvaluatePredicate(IteratorContext* ctx,
+ const std::vector<Tensor>& element,
+ bool* out_matched) const = 0;
+
private:
- class Iterator : public DatasetIterator<Dataset> {
+ class Iterator : public DatasetIterator<FilterDatasetBase> {
public:
- explicit Iterator(const Params& params,
- FilterIteratorPredicate filter_pred)
- : DatasetIterator<Dataset>(params),
+ explicit Iterator(const Params& params)
+ : DatasetIterator<FilterDatasetBase>(params),
filtered_elements_(0),
- dropped_elements_(0),
- filter_pred_(std::move(filter_pred)) {
+ dropped_elements_(0) {
std::vector<string> components =
str_util::Split(params.prefix, "::", str_util::SkipEmpty());
prefix_end_ = components.back();
@@ -197,7 +180,8 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched));
+ TF_RETURN_IF_ERROR(
+ dataset()->EvaluatePredicate(ctx, *out_tensors, &matched));
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
@@ -267,14 +251,64 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
int64 filtered_elements_ GUARDED_BY(mu_);
int64 dropped_elements_ GUARDED_BY(mu_);
- const FilterIteratorPredicate filter_pred_;
string prefix_end_;
};
const DatasetBase* const input_;
const NameAttrList func_;
+
+ protected:
const std::unique_ptr<CapturedFunction> captured_func_;
- const FilterIteratorPredicate filter_pred_;
+ };
+
+ class FilterFunctionDataset : public FilterDatasetBase {
+ public:
+ using FilterDatasetBase::FilterDatasetBase;
+
+ protected:
+ Status EvaluatePredicate(IteratorContext* ctx,
+ const std::vector<Tensor>& element,
+ bool* out_matched) const override {
+ // TODO(mrry): Avoid blocking a threadpool thread. We will need to
+ // stack-rip the iterators and use async kernels.
+ std::vector<Tensor> result;
+ TF_RETURN_IF_ERROR(
+ captured_func_->RunWithBorrowedArgs(ctx, element, &result));
+
+ if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
+ result[0].NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = result[0].scalar<bool>()();
+ return Status::OK();
+ }
+ };
+
+ class FilterTensorDataset : public FilterDatasetBase {
+ public:
+ FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func,
+ int32 index)
+ : FilterDatasetBase(ctx, input, func, std::move(captured_func)),
+ index_(index) {}
+
+ protected:
+ Status EvaluatePredicate(IteratorContext* ctx,
+ const std::vector<Tensor>& element,
+ bool* out_matched) const override {
+ const Tensor& predicate = element[index_];
+ if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = predicate.scalar<bool>()();
+ return Status::OK();
+ }
+
+ private:
+ const int32 index_;
};
private:
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index f9aaa3080e..bf08970560 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -30,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -43,10 +41,6 @@ namespace {
// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
- using MapAndBatchIteratorFunction =
- std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
- std::shared_ptr<std::vector<Tensor>>, StatusCallback)>;
-
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
@@ -97,66 +91,31 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- std::vector<int> indices;
- OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
-
- MapAndBatchIteratorFunction map_func;
- if (indices.empty()) {
- CapturedFunction* raw_captured_func = captured_func.get();
- map_func = [raw_captured_func](
- IteratorContext* ctx, const string& prefix,
- std::vector<Tensor> args,
- std::shared_ptr<std::vector<Tensor>> out_tensors,
- StatusCallback done) {
- raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(),
- std::move(done), prefix);
- };
- } else {
- std::vector<bool> can_move = ComputeMoveVector(indices);
- map_func = [indices, can_move](
- IteratorContext* ctx, const string& prefix,
- std::vector<Tensor> args,
- std::shared_ptr<std::vector<Tensor>> out_tensors,
- StatusCallback done) {
- for (size_t i = 0; i < indices.size(); ++i) {
- if (can_move[i]) {
- out_tensors->push_back(std::move(args[indices[i]]));
- } else {
- out_tensors->push_back(args[indices[i]]);
- }
- }
- done(Status::OK());
- };
- }
-
- *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls,
- drop_remainder, output_types_, output_shapes_,
- std::move(captured_func), &ctx->eigen_cpu_device(),
- std::move(map_func));
+ *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
+ drop_remainder, output_types_, output_shapes_, func_,
+ std::move(captured_func), &ctx->eigen_cpu_device());
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func, int64 batch_size,
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
int64 num_parallel_calls, bool drop_remainder,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
+ const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
- const Eigen::ThreadPoolDevice* device,
- MapAndBatchIteratorFunction map_func)
+ const Eigen::ThreadPoolDevice* device)
: DatasetBase(DatasetContext(ctx)),
input_(input),
- func_(func),
batch_size_(batch_size),
num_parallel_calls_(num_parallel_calls),
drop_remainder_(drop_remainder),
output_types_(output_types),
output_shapes_(output_shapes),
+ map_fn_(func),
captured_func_(std::move(captured_func)),
- device_(device),
- map_func_(std::move(map_func)) {
+ device_(device) {
input_->Ref();
}
@@ -164,9 +123,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return MakeUnique<Iterator>(
- Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")},
- map_func_);
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -185,7 +143,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
@@ -207,7 +165,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
- b->BuildAttrValue(func_, &f);
+ b->BuildAttrValue(map_fn_, &f);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
@@ -227,14 +185,12 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params,
- MapAndBatchIteratorFunction map_func)
+ explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
- params.dataset->num_parallel_calls_, mu_, cond_var_)),
- map_func_(std::move(map_func)) {}
+ params.dataset->num_parallel_calls_, mu_, cond_var_)) {}
~Iterator() override {
mutex_lock l(*mu_);
@@ -341,6 +297,44 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
int64 num_calls; // access guarded by owner's mutex
};
+ void Callback(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<BatchResult>& result,
+ const std::shared_ptr<std::vector<Tensor>>& return_values,
+ int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) {
+ result->UpdateStatus(status);
+ if (status.ok()) {
+ EnsureOutputAllocated(ctx, result, return_values);
+ for (size_t i = 0; i < return_values->size(); ++i) {
+ const Tensor& tensor = return_values->at(i);
+ Tensor* batch = &(result->output)[i];
+ if (tensor.NumElements() !=
+ (batch->NumElements() / batch->dim_size(0))) {
+ TensorShape batch_shape = batch->shape();
+ batch_shape.RemoveDim(0);
+ result->UpdateStatus(errors::InvalidArgument(
+ "Cannot add tensor to the batch: number of elements does not "
+ "match. Shapes are: [tensor]: ",
+ tensor.shape().DebugString(),
+ ", [batch]: ", batch_shape.DebugString()));
+ break;
+ }
+ // TODO(mrry): Add a version of DoParallelConcat that allows us to
+ // move `tensor` where possible, to speed up string tensor batching.
+ Status copy_status = ::tensorflow::functor::DoParallelConcat(
+ *dataset()->device_, tensor, offset, batch);
+ if (!copy_status.ok()) {
+ result->UpdateStatus(copy_status);
+ break;
+ }
+ }
+ {
+ mutex_lock l(result->mu);
+ result->num_elements++;
+ }
+ }
+ CallCompleted(result);
+ }
+
void CallCompleted(const std::shared_ptr<BatchResult>& result)
LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
@@ -369,48 +363,21 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return;
}
- std::shared_ptr<std::vector<Tensor>> return_values =
- std::make_shared<std::vector<Tensor>>();
- auto done = [this, ctx, result, return_values, offset](Status status) {
- result->UpdateStatus(status);
- if (status.ok()) {
- EnsureOutputAllocated(ctx, result, return_values);
- for (size_t i = 0; i < return_values->size(); ++i) {
- const Tensor& tensor = return_values->at(i);
- Tensor* batch = &(result->output)[i];
- if (tensor.NumElements() !=
- (batch->NumElements() / batch->dim_size(0))) {
- TensorShape batch_shape = batch->shape();
- batch_shape.RemoveDim(0);
- result->UpdateStatus(errors::InvalidArgument(
- "Cannot add tensor to the batch: number of elements does "
- "not match. Shapes are: [tensor]: ",
- tensor.shape().DebugString(),
- ", [batch]: ", batch_shape.DebugString()));
- break;
- }
- // TODO(mrry): Add a version of DoParallelConcat that allows us to
- // move `tensor` where possible, to speed up string tensor
- // batching.
- Status copy_status = ::tensorflow::functor::DoParallelConcat(
- *dataset()->device_, tensor, offset, batch);
- if (!copy_status.ok()) {
- result->UpdateStatus(copy_status);
- break;
- }
- }
- {
- mutex_lock l(result->mu);
- result->num_elements++;
- }
- }
- CallCompleted(result);
- };
-
- // Apply the map function on `input_element`, storing the result in
- // `return_values`, and invoking `done` when finished.
- map_func_(ctx.get(), prefix(), std::move(input_element),
- std::move(return_values), std::move(done));
+ // Call `captured_func_(input_element)`, using `Callback` to store the
+ // result in `result`.
+ (*ctx->runner())(std::bind(
+ [this, result, offset](std::shared_ptr<IteratorContext> ctx,
+ std::vector<Tensor> input_element) {
+ std::shared_ptr<std::vector<Tensor>> return_values(
+ new std::vector<Tensor>());
+ dataset()->captured_func_->RunAsync(
+ ctx.get(), std::move(input_element), return_values.get(),
+ [this, ctx, result, return_values, offset](Status status) {
+ Callback(ctx, result, return_values, offset, status);
+ },
+ prefix());
+ },
+ ctx, std::move(input_element)));
}
Status CopyPartialBatch(Tensor* output, const Tensor& value,
@@ -437,7 +404,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+ std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
std::bind(&Iterator::RunnerThread, this, ctx_copy)));
@@ -542,8 +509,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
- batch_results_.push_back(
- std::make_shared<BatchResult>(dataset()->batch_size_));
+ batch_results_.emplace_back(
+ new BatchResult(dataset()->batch_size_));
}
int64 offset = call_counter_++ % dataset()->batch_size_;
new_calls.emplace_back(batch_results_.back(), offset);
@@ -560,8 +527,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
- batch_results_.push_back(
- std::make_shared<BatchResult>(dataset()->batch_size_));
+ batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -687,8 +653,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
- const MapAndBatchIteratorFunction map_func_;
-
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(*mu_) = 0;
// Counts the total number of calls.
@@ -707,9 +671,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const bool drop_remainder_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const NameAttrList map_fn_;
const std::unique_ptr<CapturedFunction> captured_func_;
const Eigen::ThreadPoolDevice* device_; // not owned
- const MapAndBatchIteratorFunction map_func_;
};
const int op_version_;
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 0abb2eb4f3..f112e1dc43 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -17,9 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -30,9 +28,6 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
- using MapIteratorFunction = std::function<Status(
- IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>;
-
explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -48,36 +43,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
- std::vector<int> indices;
- OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
-
- MapIteratorFunction map_func;
- if (indices.empty()) {
- CapturedFunction* raw_captured_func = captured_func.get();
- map_func = [raw_captured_func](IteratorContext* ctx,
- std::vector<Tensor> args,
- std::vector<Tensor>* out_tensors) {
- return raw_captured_func->Run(ctx, std::move(args), out_tensors);
- };
- } else {
- std::vector<bool> can_move = ComputeMoveVector(indices);
- map_func = [indices, can_move](IteratorContext* ctx,
- std::vector<Tensor> args,
- std::vector<Tensor>* out_tensors) {
- std::map<int, int> counts;
- for (size_t i = 0; i < indices.size(); ++i) {
- if (can_move[i]) {
- out_tensors->push_back(std::move(args[indices[i]]));
- } else {
- out_tensors->push_back(args[indices[i]]);
- }
- }
- return Status::OK();
- };
- }
-
*output = new Dataset(ctx, input, func_, std::move(captured_func),
- output_types_, output_shapes_, std::move(map_func));
+ output_types_, output_shapes_);
}
private:
@@ -87,15 +54,13 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes,
- MapIteratorFunction map_func)
+ const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
output_types_(output_types),
- output_shapes_(output_shapes),
- map_func_(std::move(map_func)) {
+ output_shapes_(output_shapes) {
input_->Ref();
}
@@ -103,8 +68,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return MakeUnique<Iterator>(
- Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_);
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Map")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -151,8 +116,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params, MapIteratorFunction map_func)
- : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {}
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -174,7 +139,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- Status s = map_func_(ctx, args, out_tensors);
+ // TODO(mrry): Avoid blocking a threadpool thread. We will need to
+ // stack-rip the iterators and use async kernels.
+ Status s =
+ dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
if (errors::IsOutOfRange(s)) {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
@@ -199,7 +167,6 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
std::unique_ptr<IteratorBase> input_impl_;
- const MapIteratorFunction map_func_;
};
const DatasetBase* const input_;
@@ -207,7 +174,6 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
- const MapIteratorFunction map_func_;
};
DataTypeVector output_types_;
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index a34bb172d4..6abe6c8338 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/random/random.h"
@@ -57,49 +56,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
- std::vector<int> indices;
- OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
-
- ParallelMapIteratorFunction map_func;
- if (indices.empty()) {
- CapturedFunction* raw_captured_func = captured_func.get();
- map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix,
- std::vector<Tensor> args,
- std::vector<Tensor>* out_tensors,
- StatusCallback done) {
- raw_captured_func->RunAsync(ctx, std::move(args), out_tensors,
- std::move(done), prefix);
- };
- if (!use_inter_op_parallelism_) {
- map_func = [map_func](IteratorContext* ctx, const string& prefix,
- std::vector<Tensor> args,
- std::vector<Tensor>* out_tensors,
- StatusCallback done) {
- (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args),
- out_tensors, std::move(done)));
- };
- }
- } else {
- std::vector<bool> can_move = ComputeMoveVector(indices);
- map_func = [indices, can_move](IteratorContext* ctx, const string& prefix,
- std::vector<Tensor> args,
- std::vector<Tensor>* out_tensors,
- StatusCallback done) {
- std::map<int, int> counts;
- for (size_t i = 0; i < indices.size(); ++i) {
- if (can_move[i]) {
- out_tensors->push_back(std::move(args[indices[i]]));
- } else {
- out_tensors->push_back(args[indices[i]]);
- }
- }
- done(Status::OK());
- };
- }
-
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
output_shapes_, use_inter_op_parallelism_,
- std::move(captured_func), std::move(map_func));
+ std::move(captured_func));
}
private:
@@ -110,8 +69,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
bool use_inter_op_parallelism,
- std::unique_ptr<CapturedFunction> captured_func,
- ParallelMapIteratorFunction map_func)
+ std::unique_ptr<CapturedFunction> captured_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
@@ -119,8 +77,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
output_types_(output_types),
output_shapes_(output_shapes),
use_inter_op_parallelism_(use_inter_op_parallelism),
- captured_func_(std::move(captured_func)),
- map_func_(std::move(map_func)) {
+ captured_func_(std::move(captured_func)) {
input_->Ref();
}
@@ -132,9 +89,26 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- return NewParallelMapIterator(
- {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
- std::move(init_func), map_func_, num_parallel_calls_);
+ const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
+ ParallelMapIteratorFunction map_func =
+ [this, new_prefix](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done), new_prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](
+ IteratorContext* ctx, std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
+ result, std::move(done)));
+ };
+ }
+
+ return NewParallelMapIterator({this, new_prefix}, input_,
+ std::move(init_func), std::move(map_func),
+ num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
@@ -202,7 +176,6 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
const bool use_inter_op_parallelism_;
const std::unique_ptr<CapturedFunction> captured_func_;
- const ParallelMapIteratorFunction map_func_;
};
DataTypeVector output_types_;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index ebf41925c9..13bd4b6036 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -180,7 +179,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+ std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
@@ -209,15 +208,15 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
+ // Call `func_(input_element)`, store the result in `result->return_values`,
+ // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
};
- // Apply the map function on `input_element`, storing the result in
- // `result->return_values`, and invoking `done` when finished.
- map_func_(ctx.get(), prefix(), std::move(input_element),
- &result->return_values, std::move(done));
+ map_func_(ctx.get(), std::move(input_element), &result->return_values,
+ std::move(done));
}
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
@@ -350,9 +349,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
- return MakeUnique<ParallelMapIterator>(
- params, input_dataset, std::move(init_func), std::move(map_func),
- num_parallel_calls);
+ return std::unique_ptr<IteratorBase>(
+ new ParallelMapIterator(params, input_dataset, std::move(init_func),
+ std::move(map_func), num_parallel_calls));
}
} // namespace data
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
index 813f13c9e4..dc26c5cf25 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.h
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -30,7 +30,7 @@ namespace data {
// 3. A `std::vector<Tensor>*` to which the function will write the result.
// 4. A `StatusCallback` that should be invoked when the function is complete.
using ParallelMapIteratorFunction =
- std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
+ std::function<void(IteratorContext*, std::vector<Tensor>,
std::vector<Tensor>*, StatusCallback)>;
// Returns a new iterator that applies `map_func` to the elements of
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 7de5ea8860..1d1a717062 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -182,7 +182,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- auto map_fn = [this](IteratorContext* ctx, const string& prefix,
+ auto map_fn = [this](IteratorContext* ctx,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
(*ctx->runner())([this, ctx, input_element, result, done]() {