From c221f04b7efff5929f3a6d090983b52f3aa16166 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 5 Oct 2018 14:44:47 -0700 Subject: Automated rollback of commit ae0bc6f006497cc04a2ee75166d4ec71c7154fd8 PiperOrigin-RevId: 215969360 --- tensorflow/core/kernels/data/BUILD | 14 -- tensorflow/core/kernels/data/dataset_utils.cc | 47 ------ tensorflow/core/kernels/data/dataset_utils.h | 20 --- tensorflow/core/kernels/data/dataset_utils_test.cc | 46 ------ tensorflow/core/kernels/data/filter_dataset_op.cc | 162 +++++++++++-------- .../core/kernels/data/map_and_batch_dataset_op.cc | 180 +++++++++------------ tensorflow/core/kernels/data/map_dataset_op.cc | 56 ++----- .../core/kernels/data/parallel_map_dataset_op.cc | 73 +++------ .../core/kernels/data/parallel_map_iterator.cc | 17 +- .../core/kernels/data/parallel_map_iterator.h | 2 +- .../core/kernels/data/parse_example_dataset_op.cc | 2 +- 11 files changed, 214 insertions(+), 405 deletions(-) delete mode 100644 tensorflow/core/kernels/data/dataset_utils_test.cc (limited to 'tensorflow/core') diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 37c1c54786..451f8c1a6c 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -45,16 +45,6 @@ cc_library( ], ) -tf_cc_test( - name = "dataset_utils_test", - srcs = ["dataset_utils_test.cc"], - deps = [ - ":dataset_utils", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - cc_library( name = "captured_function", srcs = ["captured_function.cc"], @@ -215,7 +205,6 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", - ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -243,7 +232,6 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", - ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -257,7 +245,6 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", - ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -298,7 +285,6 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", - ":dataset_utils", ":parallel_map_iterator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index a40f7f2146..e10833f525 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -15,57 +15,10 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { namespace data { -Status ComputeShortCircuitIndices(OpKernelContext* ctx, - const NameAttrList& func, - std::vector* indices) { - FunctionLibraryRuntime::Handle fn_handle; - TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( - func.name(), AttrSlice(&func.attr()), &fn_handle)); - auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() { - Status s = ctx->function_library()->ReleaseHandle(fn_handle); - if (!s.ok()) { - LOG(WARNING) << "Failed to release handle: " << s.error_message(); - } - }); - - const FunctionBody* fn_body = - ctx->function_library()->GetFunctionBody(fn_handle); - indices->resize(fn_body->ret_nodes.size()); - for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) { - Node* ret_node = fn_body->ret_nodes[i]; - Node* ret_input_node; - TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node)); - if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) { - TF_RETURN_IF_ERROR( - GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i]))); - } else { - indices->clear(); - break; - } - } - return Status::OK(); -} - -std::vector ComputeMoveVector(const std::vector& indices) { - std::map last_use; - for (size_t i = 0; i < indices.size(); ++i) { - last_use[indices[i]] = i; - } - std::vector can_move; - can_move.resize(indices.size()); - for (size_t i = 0; i < indices.size(); ++i) { - can_move[i] = last_use[indices[i]] == i; - } - return can_move; -} - Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index d777062293..6ec1350cd4 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -22,26 +22,6 @@ limitations under the License. namespace tensorflow { namespace data { -// This method is used to determine whether we can short-circuit the evaluation -// of the user-defined function `func`. Short-circuting is possible if every -// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) = -// (y,x)`, or `f(x) = (x,x)`). -// -// If short-circuiting is possible, the method stores the mapping from output -// indices to input indices in `indices`. Otherwise, `indices` will be empty. -// -// Returns non-ok status if analysis of the function fails. -// -// TODO(jsimsa): Extend this to support constants as well. -Status ComputeShortCircuitIndices(OpKernelContext* ctx, - const NameAttrList& func, - std::vector* indices); - -// Given a vector that maps output indices to input indices, return a vector -// that identifies for which output indices can we move the input (assuming -// output indices are processed left to right). -std::vector ComputeMoveVector(const std::vector& indices); - Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc deleted file mode 100644 index 43295b8ebb..0000000000 --- a/tensorflow/core/kernels/data/dataset_utils_test.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/kernels/data/dataset_utils.h" - -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace data { -namespace { - -TEST(DatasetUtils, ComputeMoveVector) { - struct TestCase { - std::vector indices; - std::vector expected; - }; - - TestCase test_cases[] = { - TestCase{{}, {}}, - TestCase{{1}, {true}}, - TestCase{{1, 1}, {false, true}}, - TestCase{{1, 2}, {true, true}}, - TestCase{{1, 1, 2}, {false, true, true}}, - TestCase{{1, 2, 2}, {true, false, true}}, - }; - - for (auto& test_case : test_cases) { - EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices)); - } -} - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index be7d182a1f..00884314a9 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -18,11 +18,9 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -33,84 +31,67 @@ namespace { class FilterDatasetOp : public UnaryDatasetOpKernel { public: - using FilterIteratorPredicate = - std::function, bool*)>; - explicit FilterDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) { + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { + FunctionLibraryRuntime::Handle pred_handle; + OP_REQUIRES_OK(ctx, + ctx->function_library()->Instantiate( + func_.name(), AttrSlice(&func_.attr()), &pred_handle)); + auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() { + OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle)); + }); + + const FunctionBody* pred_body = + ctx->function_library()->GetFunctionBody(pred_handle); + OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1, + errors::InvalidArgument( + "predicate function must have a single return value.")); + Node* ret_node = pred_body->ret_nodes[0]; + Node* ret_input_node; + OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node)); + std::unique_ptr captured_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - std::vector indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - OP_REQUIRES(ctx, indices.size() <= 1, - errors::InvalidArgument( - "predicate function has more than one return value.")); - - FilterIteratorPredicate filter_pred; - if (indices.empty()) { - CapturedFunction* raw_captured_func = captured_func.get(); - filter_pred = [raw_captured_func](IteratorContext* ctx, - const std::vector& args, - bool* out_matched) { - std::vector result; - TF_RETURN_IF_ERROR( - raw_captured_func->RunWithBorrowedArgs(ctx, args, &result)); - - if (result.size() != 1 || result[0].dtype() != DT_BOOL || - result[0].NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = result[0].scalar()(); - return Status::OK(); - }; + if (ret_input_node->def().op() == "_Arg") { + int32 index = -1; + OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index)); + *output = new FilterTensorDataset(ctx, input, func_, + std::move(captured_func), index); } else { - filter_pred = [indices](IteratorContext* ctx, - const std::vector& args, - bool* out_matched) { - const Tensor& predicate = args[indices[0]]; - if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = predicate.scalar()(); - return Status::OK(); - }; + *output = new FilterFunctionDataset(ctx, input, func_, + std::move(captured_func)); } - - *output = new Dataset(ctx, input, func_, std::move(captured_func), - std::move(filter_pred)); } private: - class Dataset : public DatasetBase { + const int graph_def_version_; + + class FilterDatasetBase : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr captured_func, - FilterIteratorPredicate filter_pred) + FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr captured_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), - captured_func_(std::move(captured_func)), - filter_pred_(std::move(filter_pred)) { + captured_func_(std::move(captured_func)) { input_->Ref(); } - ~Dataset() override { input_->Unref(); } + ~FilterDatasetBase() override { input_->Unref(); } std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return MakeUnique( - Iterator::Params{this, strings::StrCat(prefix, "::Filter")}, - filter_pred_); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Filter")})); } const DataTypeVector& output_dtypes() const override { @@ -152,15 +133,17 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } + virtual Status EvaluatePredicate(IteratorContext* ctx, + const std::vector& element, + bool* out_matched) const = 0; + private: - class Iterator : public DatasetIterator { + class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params, - FilterIteratorPredicate filter_pred) - : DatasetIterator(params), + explicit Iterator(const Params& params) + : DatasetIterator(params), filtered_elements_(0), - dropped_elements_(0), - filter_pred_(std::move(filter_pred)) { + dropped_elements_(0) { std::vector components = str_util::Split(params.prefix, "::", str_util::SkipEmpty()); prefix_end_ = components.back(); @@ -197,7 +180,8 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched)); + TF_RETURN_IF_ERROR( + dataset()->EvaluatePredicate(ctx, *out_tensors, &matched)); if (!matched) { // Clear the output tensor list since it didn't match. out_tensors->clear(); @@ -267,14 +251,64 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr input_impl_ GUARDED_BY(mu_); int64 filtered_elements_ GUARDED_BY(mu_); int64 dropped_elements_ GUARDED_BY(mu_); - const FilterIteratorPredicate filter_pred_; string prefix_end_; }; const DatasetBase* const input_; const NameAttrList func_; + + protected: const std::unique_ptr captured_func_; - const FilterIteratorPredicate filter_pred_; + }; + + class FilterFunctionDataset : public FilterDatasetBase { + public: + using FilterDatasetBase::FilterDatasetBase; + + protected: + Status EvaluatePredicate(IteratorContext* ctx, + const std::vector& element, + bool* out_matched) const override { + // TODO(mrry): Avoid blocking a threadpool thread. We will need to + // stack-rip the iterators and use async kernels. + std::vector result; + TF_RETURN_IF_ERROR( + captured_func_->RunWithBorrowedArgs(ctx, element, &result)); + + if (result.size() != 1 || result[0].dtype() != DT_BOOL || + result[0].NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = result[0].scalar()(); + return Status::OK(); + } + }; + + class FilterTensorDataset : public FilterDatasetBase { + public: + FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr captured_func, + int32 index) + : FilterDatasetBase(ctx, input, func, std::move(captured_func)), + index_(index) {} + + protected: + Status EvaluatePredicate(IteratorContext* ctx, + const std::vector& element, + bool* out_matched) const override { + const Tensor& predicate = element[index_]; + if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = predicate.scalar()(); + return Status::OK(); + } + + private: + const int32 index_; }; private: diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index f9aaa3080e..bf08970560 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/inplace_ops_functor.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -30,7 +29,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/tracing.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -43,10 +41,6 @@ namespace { // transformation more robust. class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: - using MapAndBatchIteratorFunction = - std::function, - std::shared_ptr>, StatusCallback)>; - explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) { @@ -97,66 +91,31 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - std::vector indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - - MapAndBatchIteratorFunction map_func; - if (indices.empty()) { - CapturedFunction* raw_captured_func = captured_func.get(); - map_func = [raw_captured_func]( - IteratorContext* ctx, const string& prefix, - std::vector args, - std::shared_ptr> out_tensors, - StatusCallback done) { - raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(), - std::move(done), prefix); - }; - } else { - std::vector can_move = ComputeMoveVector(indices); - map_func = [indices, can_move]( - IteratorContext* ctx, const string& prefix, - std::vector args, - std::shared_ptr> out_tensors, - StatusCallback done) { - for (size_t i = 0; i < indices.size(); ++i) { - if (can_move[i]) { - out_tensors->push_back(std::move(args[indices[i]])); - } else { - out_tensors->push_back(args[indices[i]]); - } - } - done(Status::OK()); - }; - } - - *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls, - drop_remainder, output_types_, output_shapes_, - std::move(captured_func), &ctx->eigen_cpu_device(), - std::move(map_func)); + *output = new Dataset(ctx, input, batch_size, num_parallel_calls, + drop_remainder, output_types_, output_shapes_, func_, + std::move(captured_func), &ctx->eigen_cpu_device()); } private: class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, int64 batch_size, + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, int64 num_parallel_calls, bool drop_remainder, const DataTypeVector& output_types, const std::vector& output_shapes, + const NameAttrList& func, std::unique_ptr captured_func, - const Eigen::ThreadPoolDevice* device, - MapAndBatchIteratorFunction map_func) + const Eigen::ThreadPoolDevice* device) : DatasetBase(DatasetContext(ctx)), input_(input), - func_(func), batch_size_(batch_size), num_parallel_calls_(num_parallel_calls), drop_remainder_(drop_remainder), output_types_(output_types), output_shapes_(output_shapes), + map_fn_(func), captured_func_(std::move(captured_func)), - device_(device), - map_func_(std::move(map_func)) { + device_(device) { input_->Ref(); } @@ -164,9 +123,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return MakeUnique( - Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")}, - map_func_); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")})); } const DataTypeVector& output_dtypes() const override { @@ -185,7 +143,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; @@ -207,7 +165,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { other_arguments_types.emplace_back(t.dtype()); } AttrValue f; - b->BuildAttrValue(func_, &f); + b->BuildAttrValue(map_fn_, &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -227,14 +185,12 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params, - MapAndBatchIteratorFunction map_func) + explicit Iterator(const Params& params) : DatasetIterator(params), mu_(std::make_shared()), cond_var_(std::make_shared()), num_parallel_calls_(std::make_shared( - params.dataset->num_parallel_calls_, mu_, cond_var_)), - map_func_(std::move(map_func)) {} + params.dataset->num_parallel_calls_, mu_, cond_var_)) {} ~Iterator() override { mutex_lock l(*mu_); @@ -341,6 +297,44 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { int64 num_calls; // access guarded by owner's mutex }; + void Callback(const std::shared_ptr& ctx, + const std::shared_ptr& result, + const std::shared_ptr>& return_values, + int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) { + result->UpdateStatus(status); + if (status.ok()) { + EnsureOutputAllocated(ctx, result, return_values); + for (size_t i = 0; i < return_values->size(); ++i) { + const Tensor& tensor = return_values->at(i); + Tensor* batch = &(result->output)[i]; + if (tensor.NumElements() != + (batch->NumElements() / batch->dim_size(0))) { + TensorShape batch_shape = batch->shape(); + batch_shape.RemoveDim(0); + result->UpdateStatus(errors::InvalidArgument( + "Cannot add tensor to the batch: number of elements does not " + "match. Shapes are: [tensor]: ", + tensor.shape().DebugString(), + ", [batch]: ", batch_shape.DebugString())); + break; + } + // TODO(mrry): Add a version of DoParallelConcat that allows us to + // move `tensor` where possible, to speed up string tensor batching. + Status copy_status = ::tensorflow::functor::DoParallelConcat( + *dataset()->device_, tensor, offset, batch); + if (!copy_status.ok()) { + result->UpdateStatus(copy_status); + break; + } + } + { + mutex_lock l(result->mu); + result->num_elements++; + } + } + CallCompleted(result); + } + void CallCompleted(const std::shared_ptr& result) LOCKS_EXCLUDED(*mu_) { mutex_lock l(*mu_); @@ -369,48 +363,21 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return; } - std::shared_ptr> return_values = - std::make_shared>(); - auto done = [this, ctx, result, return_values, offset](Status status) { - result->UpdateStatus(status); - if (status.ok()) { - EnsureOutputAllocated(ctx, result, return_values); - for (size_t i = 0; i < return_values->size(); ++i) { - const Tensor& tensor = return_values->at(i); - Tensor* batch = &(result->output)[i]; - if (tensor.NumElements() != - (batch->NumElements() / batch->dim_size(0))) { - TensorShape batch_shape = batch->shape(); - batch_shape.RemoveDim(0); - result->UpdateStatus(errors::InvalidArgument( - "Cannot add tensor to the batch: number of elements does " - "not match. Shapes are: [tensor]: ", - tensor.shape().DebugString(), - ", [batch]: ", batch_shape.DebugString())); - break; - } - // TODO(mrry): Add a version of DoParallelConcat that allows us to - // move `tensor` where possible, to speed up string tensor - // batching. - Status copy_status = ::tensorflow::functor::DoParallelConcat( - *dataset()->device_, tensor, offset, batch); - if (!copy_status.ok()) { - result->UpdateStatus(copy_status); - break; - } - } - { - mutex_lock l(result->mu); - result->num_elements++; - } - } - CallCompleted(result); - }; - - // Apply the map function on `input_element`, storing the result in - // `return_values`, and invoking `done` when finished. - map_func_(ctx.get(), prefix(), std::move(input_element), - std::move(return_values), std::move(done)); + // Call `captured_func_(input_element)`, using `Callback` to store the + // result in `result`. + (*ctx->runner())(std::bind( + [this, result, offset](std::shared_ptr ctx, + std::vector input_element) { + std::shared_ptr> return_values( + new std::vector()); + dataset()->captured_func_->RunAsync( + ctx.get(), std::move(input_element), return_values.get(), + [this, ctx, result, return_values, offset](Status status) { + Callback(ctx, result, return_values, offset, status); + }, + prefix()); + }, + ctx, std::move(input_element))); } Status CopyPartialBatch(Tensor* output, const Tensor& value, @@ -437,7 +404,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - auto ctx_copy = std::make_shared(*ctx); + std::shared_ptr ctx_copy(new IteratorContext(*ctx)); runner_thread_.reset(ctx->env()->StartThread( {}, "runner_thread", std::bind(&Iterator::RunnerThread, this, ctx_copy))); @@ -542,8 +509,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { while (!busy()) { if (call_counter_ % dataset()->batch_size_ == 0) { - batch_results_.push_back( - std::make_shared(dataset()->batch_size_)); + batch_results_.emplace_back( + new BatchResult(dataset()->batch_size_)); } int64 offset = call_counter_++ % dataset()->batch_size_; new_calls.emplace_back(batch_results_.back(), offset); @@ -560,8 +527,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - batch_results_.push_back( - std::make_shared(dataset()->batch_size_)); + batch_results_.emplace_back(new BatchResult(dataset()->batch_size_)); std::shared_ptr result = batch_results_.back(); string prefix = strings::StrCat("batch_results_", index); mutex_lock l(result->mu); @@ -687,8 +653,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const std::shared_ptr cond_var_; // Identifies the maximum number of parallel calls. const std::shared_ptr num_parallel_calls_; - const MapAndBatchIteratorFunction map_func_; - // Counts the number of outstanding calls for this batch. int64 num_calls_ GUARDED_BY(*mu_) = 0; // Counts the total number of calls. @@ -707,9 +671,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const bool drop_remainder_; const DataTypeVector output_types_; const std::vector output_shapes_; + const NameAttrList map_fn_; const std::unique_ptr captured_func_; const Eigen::ThreadPoolDevice* device_; // not owned - const MapAndBatchIteratorFunction map_func_; }; const int op_version_; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 0abb2eb4f3..f112e1dc43 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -17,9 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -30,9 +28,6 @@ namespace { class MapDatasetOp : public UnaryDatasetOpKernel { public: - using MapIteratorFunction = std::function, std::vector*)>; - explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -48,36 +43,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); - std::vector indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - - MapIteratorFunction map_func; - if (indices.empty()) { - CapturedFunction* raw_captured_func = captured_func.get(); - map_func = [raw_captured_func](IteratorContext* ctx, - std::vector args, - std::vector* out_tensors) { - return raw_captured_func->Run(ctx, std::move(args), out_tensors); - }; - } else { - std::vector can_move = ComputeMoveVector(indices); - map_func = [indices, can_move](IteratorContext* ctx, - std::vector args, - std::vector* out_tensors) { - std::map counts; - for (size_t i = 0; i < indices.size(); ++i) { - if (can_move[i]) { - out_tensors->push_back(std::move(args[indices[i]])); - } else { - out_tensors->push_back(args[indices[i]]); - } - } - return Status::OK(); - }; - } - *output = new Dataset(ctx, input, func_, std::move(captured_func), - output_types_, output_shapes_, std::move(map_func)); + output_types_, output_shapes_); } private: @@ -87,15 +54,13 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const NameAttrList& func, std::unique_ptr captured_func, const DataTypeVector& output_types, - const std::vector& output_shapes, - MapIteratorFunction map_func) + const std::vector& output_shapes) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), captured_func_(std::move(captured_func)), output_types_(output_types), - output_shapes_(output_shapes), - map_func_(std::move(map_func)) { + output_shapes_(output_shapes) { input_->Ref(); } @@ -103,8 +68,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return MakeUnique( - Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Map")})); } const DataTypeVector& output_dtypes() const override { @@ -151,8 +116,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params, MapIteratorFunction map_func) - : DatasetIterator(params), map_func_(std::move(map_func)) {} + explicit Iterator(const Params& params) + : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( @@ -174,7 +139,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - Status s = map_func_(ctx, args, out_tensors); + // TODO(mrry): Avoid blocking a threadpool thread. We will need to + // stack-rip the iterators and use async kernels. + Status s = + dataset()->captured_func_->Run(ctx, std::move(args), out_tensors); if (errors::IsOutOfRange(s)) { // `f` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. @@ -199,7 +167,6 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: std::unique_ptr input_impl_; - const MapIteratorFunction map_func_; }; const DatasetBase* const input_; @@ -207,7 +174,6 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const std::unique_ptr captured_func_; const DataTypeVector output_types_; const std::vector output_shapes_; - const MapIteratorFunction map_func_; }; DataTypeVector output_types_; diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index a34bb172d4..6abe6c8338 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/parallel_map_iterator.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/random/random.h" @@ -57,49 +56,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); - std::vector indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - - ParallelMapIteratorFunction map_func; - if (indices.empty()) { - CapturedFunction* raw_captured_func = captured_func.get(); - map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix, - std::vector args, - std::vector* out_tensors, - StatusCallback done) { - raw_captured_func->RunAsync(ctx, std::move(args), out_tensors, - std::move(done), prefix); - }; - if (!use_inter_op_parallelism_) { - map_func = [map_func](IteratorContext* ctx, const string& prefix, - std::vector args, - std::vector* out_tensors, - StatusCallback done) { - (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args), - out_tensors, std::move(done))); - }; - } - } else { - std::vector can_move = ComputeMoveVector(indices); - map_func = [indices, can_move](IteratorContext* ctx, const string& prefix, - std::vector args, - std::vector* out_tensors, - StatusCallback done) { - std::map counts; - for (size_t i = 0; i < indices.size(); ++i) { - if (can_move[i]) { - out_tensors->push_back(std::move(args[indices[i]])); - } else { - out_tensors->push_back(args[indices[i]]); - } - } - done(Status::OK()); - }; - } - *output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_, output_shapes_, use_inter_op_parallelism_, - std::move(captured_func), std::move(map_func)); + std::move(captured_func)); } private: @@ -110,8 +69,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const DataTypeVector& output_types, const std::vector& output_shapes, bool use_inter_op_parallelism, - std::unique_ptr captured_func, - ParallelMapIteratorFunction map_func) + std::unique_ptr captured_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), @@ -119,8 +77,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { output_types_(output_types), output_shapes_(output_shapes), use_inter_op_parallelism_(use_inter_op_parallelism), - captured_func_(std::move(captured_func)), - map_func_(std::move(map_func)) { + captured_func_(std::move(captured_func)) { input_->Ref(); } @@ -132,9 +89,26 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return captured_func_->Instantiate(ctx); }; - return NewParallelMapIterator( - {this, strings::StrCat(prefix, "::ParallelMap")}, input_, - std::move(init_func), map_func_, num_parallel_calls_); + const string& new_prefix = strings::StrCat(prefix, "::ParallelMap"); + ParallelMapIteratorFunction map_func = + [this, new_prefix](IteratorContext* ctx, + std::vector input_element, + std::vector* result, StatusCallback done) { + captured_func_->RunAsync(ctx, std::move(input_element), result, + std::move(done), new_prefix); + }; + if (!use_inter_op_parallelism_) { + map_func = [map_func]( + IteratorContext* ctx, std::vector input_element, + std::vector* result, StatusCallback done) { + (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element), + result, std::move(done))); + }; + } + + return NewParallelMapIterator({this, new_prefix}, input_, + std::move(init_func), std::move(map_func), + num_parallel_calls_); } const DataTypeVector& output_dtypes() const override { @@ -202,7 +176,6 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const std::vector output_shapes_; const bool use_inter_op_parallelism_; const std::unique_ptr captured_func_; - const ParallelMapIteratorFunction map_func_; }; DataTypeVector output_types_; diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index ebf41925c9..13bd4b6036 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/cpu_info.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -180,7 +179,7 @@ class ParallelMapIterator : public DatasetBaseIterator { void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - auto ctx_copy = std::make_shared(*ctx); + std::shared_ptr ctx_copy(new IteratorContext(*ctx)); runner_thread_.reset(ctx->env()->StartThread( {}, "runner_thread", std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy))); @@ -209,15 +208,15 @@ class ParallelMapIterator : public DatasetBaseIterator { return; } + // Call `func_(input_element)`, store the result in `result->return_values`, + // and notify `result->notification` to unblock a consumer. auto done = [this, result](Status status) { result->status.Update(status); CallCompleted(result); }; - // Apply the map function on `input_element`, storing the result in - // `result->return_values`, and invoking `done` when finished. - map_func_(ctx.get(), prefix(), std::move(input_element), - &result->return_values, std::move(done)); + map_func_(ctx.get(), std::move(input_element), &result->return_values, + std::move(done)); } Status ProcessResult(const std::shared_ptr& result, @@ -350,9 +349,9 @@ std::unique_ptr NewParallelMapIterator( const DatasetBase* input_dataset, std::function init_func, ParallelMapIteratorFunction map_func, int32 num_parallel_calls) { - return MakeUnique( - params, input_dataset, std::move(init_func), std::move(map_func), - num_parallel_calls); + return std::unique_ptr( + new ParallelMapIterator(params, input_dataset, std::move(init_func), + std::move(map_func), num_parallel_calls)); } } // namespace data diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h index 813f13c9e4..dc26c5cf25 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.h +++ b/tensorflow/core/kernels/data/parallel_map_iterator.h @@ -30,7 +30,7 @@ namespace data { // 3. A `std::vector*` to which the function will write the result. // 4. A `StatusCallback` that should be invoked when the function is complete. using ParallelMapIteratorFunction = - std::function, + std::function, std::vector*, StatusCallback)>; // Returns a new iterator that applies `map_func` to the elements of diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index 7de5ea8860..1d1a717062 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -182,7 +182,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - auto map_fn = [this](IteratorContext* ctx, const string& prefix, + auto map_fn = [this](IteratorContext* ctx, std::vector input_element, std::vector* result, StatusCallback done) { (*ctx->runner())([this, ctx, input_element, result, done]() { -- cgit v1.2.3