From 0e42fd6d0a88b30ab57959f38c79bea19d745ec3 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Mon, 8 Oct 2018 10:14:58 -0700 Subject: [tf.data] Adding specialization for `MapDataset`, `ParallelMapDataset`, and `MapAndBatchDataset` whose user-provided functions have the property that each output argument take its value directly from an input argument (e.g. `lambda x, y: y, x`). This specialization can produce the result without having to schedule the function using the executor. PiperOrigin-RevId: 216206232 --- tensorflow/core/kernels/data/BUILD | 14 ++ tensorflow/core/kernels/data/dataset_utils.cc | 47 ++++++ tensorflow/core/kernels/data/dataset_utils.h | 20 +++ tensorflow/core/kernels/data/dataset_utils_test.cc | 46 +++++ tensorflow/core/kernels/data/filter_dataset_op.cc | 162 +++++++----------- .../core/kernels/data/map_and_batch_dataset_op.cc | 187 +++++++++++++-------- tensorflow/core/kernels/data/map_dataset_op.cc | 62 +++++-- .../core/kernels/data/parallel_map_dataset_op.cc | 79 ++++++--- .../core/kernels/data/parallel_map_iterator.cc | 17 +- .../core/kernels/data/parallel_map_iterator.h | 2 +- .../core/kernels/data/parse_example_dataset_op.cc | 2 +- .../kernel_tests/map_and_batch_test.py | 31 ++++ .../data/kernel_tests/filter_dataset_op_test.py | 2 +- .../data/kernel_tests/map_dataset_op_test.py | 95 +++++++++-- tensorflow/python/data/kernel_tests/test_base.py | 29 ++++ 15 files changed, 565 insertions(+), 230 deletions(-) create mode 100644 tensorflow/core/kernels/data/dataset_utils_test.cc diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 451f8c1a6c..37c1c54786 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -45,6 +45,16 @@ cc_library( ], ) +tf_cc_test( + name = "dataset_utils_test", + srcs = ["dataset_utils_test.cc"], + deps = [ + ":dataset_utils", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "captured_function", srcs = 
["captured_function.cc"], @@ -205,6 +215,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -232,6 +243,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -245,6 +257,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -285,6 +298,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", ":parallel_map_iterator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index e10833f525..a40f7f2146 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -15,10 +15,57 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { namespace data { +Status ComputeShortCircuitIndices(OpKernelContext* ctx, + const NameAttrList& func, + std::vector* indices) { + FunctionLibraryRuntime::Handle fn_handle; + TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( + func.name(), AttrSlice(&func.attr()), &fn_handle)); + auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() { + Status s = ctx->function_library()->ReleaseHandle(fn_handle); + if (!s.ok()) { + LOG(WARNING) << "Failed to release handle: " << s.error_message(); + } + }); + + const FunctionBody* fn_body = + ctx->function_library()->GetFunctionBody(fn_handle); + indices->resize(fn_body->ret_nodes.size()); + for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) { + Node* ret_node = fn_body->ret_nodes[i]; + Node* ret_input_node; + TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node)); + if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) { + TF_RETURN_IF_ERROR( + GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i]))); + } else { + indices->clear(); + break; + } + } + return Status::OK(); +} + +std::vector ComputeMoveVector(const std::vector& indices) { + std::map last_use; + for (size_t i = 0; i < indices.size(); ++i) { + last_use[indices[i]] = i; + } + std::vector can_move; + can_move.resize(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + can_move[i] = last_use[indices[i]] == i; + } + return can_move; +} + Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 6ec1350cd4..d777062293 100644 --- 
a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -22,6 +22,26 @@ limitations under the License. namespace tensorflow { namespace data { +// This method is used to determine whether we can short-circuit the evaluation +// of the user-defined function `func`. Short-circuiting is possible if every +// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) = +// (y,x)`, or `f(x) = (x,x)`). +// +// If short-circuiting is possible, the method stores the mapping from output +// indices to input indices in `indices`. Otherwise, `indices` will be empty. +// +// Returns non-ok status if analysis of the function fails. +// +// TODO(jsimsa): Extend this to support constants as well. +Status ComputeShortCircuitIndices(OpKernelContext* ctx, + const NameAttrList& func, + std::vector* indices); + +// Given a vector that maps output indices to input indices, return a vector +// that identifies the output indices for which we can move the input (assuming +// output indices are processed left to right). +std::vector ComputeMoveVector(const std::vector& indices); + Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc new file mode 100644 index 0000000000..43295b8ebb --- /dev/null +++ b/tensorflow/core/kernels/data/dataset_utils_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_utils.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace { + +TEST(DatasetUtils, ComputeMoveVector) { + struct TestCase { + std::vector indices; + std::vector expected; + }; + + TestCase test_cases[] = { + TestCase{{}, {}}, + TestCase{{1}, {true}}, + TestCase{{1, 1}, {false, true}}, + TestCase{{1, 2}, {true, true}}, + TestCase{{1, 1, 2}, {false, true, true}}, + TestCase{{1, 2, 2}, {true, false, true}}, + }; + + for (auto& test_case : test_cases) { + EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices)); + } +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index 00884314a9..be7d182a1f 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -18,9 +18,11 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -31,67 +33,84 @@ namespace { class FilterDatasetOp : public UnaryDatasetOpKernel { public: + using FilterIteratorPredicate = + std::function, bool*)>; + explicit FilterDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - FunctionLibraryRuntime::Handle pred_handle; - OP_REQUIRES_OK(ctx, - ctx->function_library()->Instantiate( - func_.name(), AttrSlice(&func_.attr()), &pred_handle)); - auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() { - OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle)); - }); - - const FunctionBody* pred_body = - ctx->function_library()->GetFunctionBody(pred_handle); - OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1, - errors::InvalidArgument( - "predicate function must have a single return value.")); - Node* ret_node = pred_body->ret_nodes[0]; - Node* ret_input_node; - OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node)); - std::unique_ptr captured_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - if (ret_input_node->def().op() == "_Arg") { - int32 index = -1; - OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index)); - *output = new FilterTensorDataset(ctx, input, func_, - std::move(captured_func), index); + std::vector indices; + OP_REQUIRES_OK(ctx, 
ComputeShortCircuitIndices(ctx, func_, &indices)); + OP_REQUIRES(ctx, indices.size() <= 1, + errors::InvalidArgument( + "predicate function has more than one return value.")); + + FilterIteratorPredicate filter_pred; + if (indices.empty()) { + CapturedFunction* raw_captured_func = captured_func.get(); + filter_pred = [raw_captured_func](IteratorContext* ctx, + const std::vector& args, + bool* out_matched) { + std::vector result; + TF_RETURN_IF_ERROR( + raw_captured_func->RunWithBorrowedArgs(ctx, args, &result)); + + if (result.size() != 1 || result[0].dtype() != DT_BOOL || + result[0].NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = result[0].scalar()(); + return Status::OK(); + }; } else { - *output = new FilterFunctionDataset(ctx, input, func_, - std::move(captured_func)); + filter_pred = [indices](IteratorContext* ctx, + const std::vector& args, + bool* out_matched) { + const Tensor& predicate = args[indices[0]]; + if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = predicate.scalar()(); + return Status::OK(); + }; } + + *output = new Dataset(ctx, input, func_, std::move(captured_func), + std::move(filter_pred)); } private: - const int graph_def_version_; - - class FilterDatasetBase : public DatasetBase { + class Dataset : public DatasetBase { public: - FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr captured_func) + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr captured_func, + FilterIteratorPredicate filter_pred) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), - captured_func_(std::move(captured_func)) { + captured_func_(std::move(captured_func)), + filter_pred_(std::move(filter_pred)) { input_->Ref(); } - 
~FilterDatasetBase() override { input_->Unref(); } + ~Dataset() override { input_->Unref(); } std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr( - new Iterator({this, strings::StrCat(prefix, "::Filter")})); + return MakeUnique( + Iterator::Params{this, strings::StrCat(prefix, "::Filter")}, + filter_pred_); } const DataTypeVector& output_dtypes() const override { @@ -133,17 +152,15 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - virtual Status EvaluatePredicate(IteratorContext* ctx, - const std::vector& element, - bool* out_matched) const = 0; - private: - class Iterator : public DatasetIterator { + class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params) - : DatasetIterator(params), + explicit Iterator(const Params& params, + FilterIteratorPredicate filter_pred) + : DatasetIterator(params), filtered_elements_(0), - dropped_elements_(0) { + dropped_elements_(0), + filter_pred_(std::move(filter_pred)) { std::vector components = str_util::Split(params.prefix, "::", str_util::SkipEmpty()); prefix_end_ = components.back(); @@ -180,8 +197,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - TF_RETURN_IF_ERROR( - dataset()->EvaluatePredicate(ctx, *out_tensors, &matched)); + TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched)); if (!matched) { // Clear the output tensor list since it didn't match. 
out_tensors->clear(); @@ -251,64 +267,14 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr input_impl_ GUARDED_BY(mu_); int64 filtered_elements_ GUARDED_BY(mu_); int64 dropped_elements_ GUARDED_BY(mu_); + const FilterIteratorPredicate filter_pred_; string prefix_end_; }; const DatasetBase* const input_; const NameAttrList func_; - - protected: const std::unique_ptr captured_func_; - }; - - class FilterFunctionDataset : public FilterDatasetBase { - public: - using FilterDatasetBase::FilterDatasetBase; - - protected: - Status EvaluatePredicate(IteratorContext* ctx, - const std::vector& element, - bool* out_matched) const override { - // TODO(mrry): Avoid blocking a threadpool thread. We will need to - // stack-rip the iterators and use async kernels. - std::vector result; - TF_RETURN_IF_ERROR( - captured_func_->RunWithBorrowedArgs(ctx, element, &result)); - - if (result.size() != 1 || result[0].dtype() != DT_BOOL || - result[0].NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = result[0].scalar()(); - return Status::OK(); - } - }; - - class FilterTensorDataset : public FilterDatasetBase { - public: - FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr captured_func, - int32 index) - : FilterDatasetBase(ctx, input, func, std::move(captured_func)), - index_(index) {} - - protected: - Status EvaluatePredicate(IteratorContext* ctx, - const std::vector& element, - bool* out_matched) const override { - const Tensor& predicate = element[index_]; - if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = predicate.scalar()(); - return Status::OK(); - } - - private: - const int32 index_; + const FilterIteratorPredicate filter_pred_; }; private: diff --git 
a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index bf08970560..f45a239793 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/inplace_ops_functor.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -41,6 +43,10 @@ namespace { // transformation more robust. class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: + using MapAndBatchIteratorFunction = + std::function, + std::shared_ptr>, StatusCallback)>; + explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), op_version_(ctx->def().op() == "MapAndBatchDataset" ? 
1 : 2) { @@ -91,31 +97,73 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - *output = new Dataset(ctx, input, batch_size, num_parallel_calls, - drop_remainder, output_types_, output_shapes_, func_, - std::move(captured_func), &ctx->eigen_cpu_device()); + std::vector indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + MapAndBatchIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = [raw_captured_func]( + IteratorContext* ctx, const string& prefix, + std::vector args, + std::shared_ptr> out_tensors, + StatusCallback done) { + raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(), + std::move(done), prefix); + }; + } else { + std::vector can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, const string& prefix, + std::vector args, + std::shared_ptr> out_tensors, + StatusCallback done) { + const std::vector& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + done(Status::OK()); + }; + } + + *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls, + drop_remainder, output_types_, output_shapes_, + std::move(captured_func), &ctx->eigen_cpu_device(), + std::move(map_func)); } private: class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, int64 batch_size, int64 num_parallel_calls, bool 
drop_remainder, const DataTypeVector& output_types, const std::vector& output_shapes, - const NameAttrList& func, std::unique_ptr captured_func, - const Eigen::ThreadPoolDevice* device) + const Eigen::ThreadPoolDevice* device, + MapAndBatchIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), + func_(func), batch_size_(batch_size), num_parallel_calls_(num_parallel_calls), drop_remainder_(drop_remainder), output_types_(output_types), output_shapes_(output_shapes), - map_fn_(func), captured_func_(std::move(captured_func)), - device_(device) { + device_(device), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -123,8 +171,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr( - new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")})); + return MakeUnique( + Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")}, + map_func_); } const DataTypeVector& output_dtypes() const override { @@ -143,7 +192,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; @@ -165,7 +214,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { other_arguments_types.emplace_back(t.dtype()); } AttrValue f; - b->BuildAttrValue(map_fn_, &f); + b->BuildAttrValue(func_, &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -185,12 +234,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params) + explicit 
Iterator(const Params& params, + MapAndBatchIteratorFunction map_func) : DatasetIterator(params), mu_(std::make_shared()), cond_var_(std::make_shared()), num_parallel_calls_(std::make_shared( - params.dataset->num_parallel_calls_, mu_, cond_var_)) {} + params.dataset->num_parallel_calls_, mu_, cond_var_)), + map_func_(std::move(map_func)) {} ~Iterator() override { mutex_lock l(*mu_); @@ -297,44 +348,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { int64 num_calls; // access guarded by owner's mutex }; - void Callback(const std::shared_ptr& ctx, - const std::shared_ptr& result, - const std::shared_ptr>& return_values, - int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) { - result->UpdateStatus(status); - if (status.ok()) { - EnsureOutputAllocated(ctx, result, return_values); - for (size_t i = 0; i < return_values->size(); ++i) { - const Tensor& tensor = return_values->at(i); - Tensor* batch = &(result->output)[i]; - if (tensor.NumElements() != - (batch->NumElements() / batch->dim_size(0))) { - TensorShape batch_shape = batch->shape(); - batch_shape.RemoveDim(0); - result->UpdateStatus(errors::InvalidArgument( - "Cannot add tensor to the batch: number of elements does not " - "match. Shapes are: [tensor]: ", - tensor.shape().DebugString(), - ", [batch]: ", batch_shape.DebugString())); - break; - } - // TODO(mrry): Add a version of DoParallelConcat that allows us to - // move `tensor` where possible, to speed up string tensor batching. 
- Status copy_status = ::tensorflow::functor::DoParallelConcat( - *dataset()->device_, tensor, offset, batch); - if (!copy_status.ok()) { - result->UpdateStatus(copy_status); - break; - } - } - { - mutex_lock l(result->mu); - result->num_elements++; - } - } - CallCompleted(result); - } - void CallCompleted(const std::shared_ptr& result) LOCKS_EXCLUDED(*mu_) { mutex_lock l(*mu_); @@ -363,21 +376,48 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return; } - // Call `captured_func_(input_element)`, using `Callback` to store the - // result in `result`. - (*ctx->runner())(std::bind( - [this, result, offset](std::shared_ptr ctx, - std::vector input_element) { - std::shared_ptr> return_values( - new std::vector()); - dataset()->captured_func_->RunAsync( - ctx.get(), std::move(input_element), return_values.get(), - [this, ctx, result, return_values, offset](Status status) { - Callback(ctx, result, return_values, offset, status); - }, - prefix()); - }, - ctx, std::move(input_element))); + std::shared_ptr> return_values = + std::make_shared>(); + auto done = [this, ctx, result, return_values, offset](Status status) { + result->UpdateStatus(status); + if (status.ok()) { + EnsureOutputAllocated(ctx, result, return_values); + for (size_t i = 0; i < return_values->size(); ++i) { + const Tensor& tensor = return_values->at(i); + Tensor* batch = &(result->output)[i]; + if (tensor.NumElements() != + (batch->NumElements() / batch->dim_size(0))) { + TensorShape batch_shape = batch->shape(); + batch_shape.RemoveDim(0); + result->UpdateStatus(errors::InvalidArgument( + "Cannot add tensor to the batch: number of elements does " + "not match. Shapes are: [tensor]: ", + tensor.shape().DebugString(), + ", [batch]: ", batch_shape.DebugString())); + break; + } + // TODO(mrry): Add a version of DoParallelConcat that allows us to + // move `tensor` where possible, to speed up string tensor + // batching. 
+ Status copy_status = ::tensorflow::functor::DoParallelConcat( + *dataset()->device_, tensor, offset, batch); + if (!copy_status.ok()) { + result->UpdateStatus(copy_status); + break; + } + } + { + mutex_lock l(result->mu); + result->num_elements++; + } + } + CallCompleted(result); + }; + + // Apply the map function on `input_element`, storing the result in + // `return_values`, and invoking `done` when finished. + map_func_(ctx.get(), prefix(), std::move(input_element), + std::move(return_values), std::move(done)); } Status CopyPartialBatch(Tensor* output, const Tensor& value, @@ -404,7 +444,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - std::shared_ptr ctx_copy(new IteratorContext(*ctx)); + auto ctx_copy = std::make_shared(*ctx); runner_thread_.reset(ctx->env()->StartThread( {}, "runner_thread", std::bind(&Iterator::RunnerThread, this, ctx_copy))); @@ -509,8 +549,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { while (!busy()) { if (call_counter_ % dataset()->batch_size_ == 0) { - batch_results_.emplace_back( - new BatchResult(dataset()->batch_size_)); + batch_results_.push_back( + std::make_shared(dataset()->batch_size_)); } int64 offset = call_counter_++ % dataset()->batch_size_; new_calls.emplace_back(batch_results_.back(), offset); @@ -527,7 +567,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - batch_results_.emplace_back(new BatchResult(dataset()->batch_size_)); + batch_results_.push_back( + std::make_shared(dataset()->batch_size_)); std::shared_ptr result = batch_results_.back(); string prefix = strings::StrCat("batch_results_", index); mutex_lock l(result->mu); @@ -653,6 +694,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const std::shared_ptr cond_var_; // 
Identifies the maximum number of parallel calls. const std::shared_ptr num_parallel_calls_; + const MapAndBatchIteratorFunction map_func_; + // Counts the number of outstanding calls for this batch. int64 num_calls_ GUARDED_BY(*mu_) = 0; // Counts the total number of calls. @@ -671,9 +714,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const bool drop_remainder_; const DataTypeVector output_types_; const std::vector output_shapes_; - const NameAttrList map_fn_; const std::unique_ptr captured_func_; const Eigen::ThreadPoolDevice* device_; // not owned + const MapAndBatchIteratorFunction map_func_; }; const int op_version_; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index f112e1dc43..6b6ffabf4f 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -28,6 +30,9 @@ namespace { class MapDatasetOp : public UnaryDatasetOpKernel { public: + using MapIteratorFunction = std::function, std::vector*)>; + explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -43,8 +48,42 @@ class MapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); + std::vector indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + MapIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = 
[raw_captured_func](IteratorContext* ctx, + std::vector args, + std::vector* out_tensors) { + return raw_captured_func->Run(ctx, std::move(args), out_tensors); + }; + } else { + std::vector can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, std::vector args, + std::vector* out_tensors) { + const std::vector& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + return Status::OK(); + }; + } + *output = new Dataset(ctx, input, func_, std::move(captured_func), - output_types_, output_shapes_); + output_types_, output_shapes_, std::move(map_func)); } private: @@ -54,13 +93,15 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const NameAttrList& func, std::unique_ptr captured_func, const DataTypeVector& output_types, - const std::vector& output_shapes) + const std::vector& output_shapes, + MapIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), captured_func_(std::move(captured_func)), output_types_(output_types), - output_shapes_(output_shapes) { + output_shapes_(output_shapes), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -68,8 +109,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr( - new Iterator({this, strings::StrCat(prefix, "::Map")})); + return MakeUnique( + Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_); } const DataTypeVector& output_dtypes() const override { @@ -116,8 +157,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator { public: 
- explicit Iterator(const Params& params) - : DatasetIterator(params) {} + explicit Iterator(const Params& params, MapIteratorFunction map_func) + : DatasetIterator(params), map_func_(std::move(map_func)) {} Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( @@ -139,10 +180,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - // TODO(mrry): Avoid blocking a threadpool thread. We will need to - // stack-rip the iterators and use async kernels. - Status s = - dataset()->captured_func_->Run(ctx, std::move(args), out_tensors); + Status s = map_func_(ctx, args, out_tensors); if (errors::IsOutOfRange(s)) { // `f` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. @@ -167,6 +205,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: std::unique_ptr input_impl_; + const MapIteratorFunction map_func_; }; const DatasetBase* const input_; @@ -174,6 +213,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const std::unique_ptr captured_func_; const DataTypeVector output_types_; const std::vector output_shapes_; + const MapIteratorFunction map_func_; }; DataTypeVector output_types_; diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 6abe6c8338..3a14924fba 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/parallel_map_iterator.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/random/random.h" @@ -56,9 +57,55 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); + std::vector indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + ParallelMapIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix, + std::vector args, + std::vector* out_tensors, + StatusCallback done) { + raw_captured_func->RunAsync(ctx, std::move(args), out_tensors, + std::move(done), prefix); + }; + if (!use_inter_op_parallelism_) { + map_func = [map_func](IteratorContext* ctx, const string& prefix, + std::vector args, + std::vector* out_tensors, + StatusCallback done) { + (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args), + out_tensors, std::move(done))); + }; + } + } else { + std::vector can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, const string& prefix, + std::vector args, std::vector* out_tensors, + StatusCallback done) { + const std::vector& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + done(Status::OK()); + }; + } + *output = new Dataset(ctx, input, func_, num_parallel_calls, 
output_types_, output_shapes_, use_inter_op_parallelism_, - std::move(captured_func)); + std::move(captured_func), std::move(map_func)); } private: @@ -69,7 +116,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const DataTypeVector& output_types, const std::vector& output_shapes, bool use_inter_op_parallelism, - std::unique_ptr captured_func) + std::unique_ptr captured_func, + ParallelMapIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), @@ -77,7 +125,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { output_types_(output_types), output_shapes_(output_shapes), use_inter_op_parallelism_(use_inter_op_parallelism), - captured_func_(std::move(captured_func)) { + captured_func_(std::move(captured_func)), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -89,26 +138,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return captured_func_->Instantiate(ctx); }; - const string& new_prefix = strings::StrCat(prefix, "::ParallelMap"); - ParallelMapIteratorFunction map_func = - [this, new_prefix](IteratorContext* ctx, - std::vector input_element, - std::vector* result, StatusCallback done) { - captured_func_->RunAsync(ctx, std::move(input_element), result, - std::move(done), new_prefix); - }; - if (!use_inter_op_parallelism_) { - map_func = [map_func]( - IteratorContext* ctx, std::vector input_element, - std::vector* result, StatusCallback done) { - (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element), - result, std::move(done))); - }; - } - - return NewParallelMapIterator({this, new_prefix}, input_, - std::move(init_func), std::move(map_func), - num_parallel_calls_); + return NewParallelMapIterator( + {this, strings::StrCat(prefix, "::ParallelMap")}, input_, + std::move(init_func), map_func_, num_parallel_calls_); } const DataTypeVector& output_dtypes() const override { @@ -176,6 +208,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const std::vector 
output_shapes_; const bool use_inter_op_parallelism_; const std::unique_ptr captured_func_; + const ParallelMapIteratorFunction map_func_; }; DataTypeVector output_types_; diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 13bd4b6036..ebf41925c9 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -179,7 +180,7 @@ class ParallelMapIterator : public DatasetBaseIterator { void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - std::shared_ptr ctx_copy(new IteratorContext(*ctx)); + auto ctx_copy = std::make_shared(*ctx); runner_thread_.reset(ctx->env()->StartThread( {}, "runner_thread", std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy))); @@ -208,15 +209,15 @@ class ParallelMapIterator : public DatasetBaseIterator { return; } - // Call `func_(input_element)`, store the result in `result->return_values`, - // and notify `result->notification` to unblock a consumer. auto done = [this, result](Status status) { result->status.Update(status); CallCompleted(result); }; - map_func_(ctx.get(), std::move(input_element), &result->return_values, - std::move(done)); + // Apply the map function on `input_element`, storing the result in + // `result->return_values`, and invoking `done` when finished. 
+ map_func_(ctx.get(), prefix(), std::move(input_element), + &result->return_values, std::move(done)); } Status ProcessResult(const std::shared_ptr& result, @@ -349,9 +350,9 @@ std::unique_ptr NewParallelMapIterator( const DatasetBase* input_dataset, std::function init_func, ParallelMapIteratorFunction map_func, int32 num_parallel_calls) { - return std::unique_ptr( - new ParallelMapIterator(params, input_dataset, std::move(init_func), - std::move(map_func), num_parallel_calls)); + return MakeUnique( + params, input_dataset, std::move(init_func), std::move(map_func), + num_parallel_calls); } } // namespace data diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h index dc26c5cf25..813f13c9e4 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.h +++ b/tensorflow/core/kernels/data/parallel_map_iterator.h @@ -30,7 +30,7 @@ namespace data { // 3. A `std::vector*` to which the function will write the result. // 4. A `StatusCallback` that should be invoked when the function is complete. 
using ParallelMapIteratorFunction = - std::function, + std::function, std::vector*, StatusCallback)>; // Returns a new iterator that applies `map_func` to the elements of diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index 1d1a717062..7de5ea8860 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -182,7 +182,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - auto map_fn = [this](IteratorContext* ctx, + auto map_fn = [this](IteratorContext* ctx, const string& prefix, std::vector input_element, std::vector* result, StatusCallback done) { (*ctx->runner())([this, ctx, input_element, result, done]() { diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py index afd0fc3abf..d444c4082e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py @@ -332,6 +332,37 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): for _ in range(10): self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) + @parameterized.named_parameters( + ("Identity", None, lambda x: x, None), + ("Replicate", None, lambda x: (x, x), None), + ("Swap", (None, None), lambda x, y: (y, x), None), + ("Project", (None, None), lambda x, y: x, None), + ) + def testShortCircuit(self, structure, map_fn, num_parallel_calls): + dataset = self.structuredDataset(structure).repeat().apply( + batching.map_and_batch(map_fn, batch_size=10)) + get_next = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + if isinstance(structure, tuple): + expected = map_fn( + 
*sess.run(self.structuredElement(structure, shape=[10]))) + else: + expected = map_fn( + sess.run(self.structuredElement(structure, shape=[10]))) + self.assertAllEqual(expected, sess.run(get_next)) + + def testShortCircuitCapturedInput(self): + captured_t = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = self.structuredDataset(None).repeat().apply( + batching.map_and_batch(lambda x: captured_t, batch_size=10)) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer, feed_dict={captured_t: 42}) + self.assertAllEqual([42] * 10, sess.run(get_next)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py index 6b7afafa5d..a0c6b37a6d 100644 --- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py @@ -156,7 +156,7 @@ class FilterDatasetTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testReturnComponent(self): + def testShortCircuit(self): iterator = ( dataset_ops.Dataset.zip( (dataset_ops.Dataset.range(10), diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 0c372ebb10..4683b1db91 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -622,7 +622,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): sess.run(init_op) for i in range(10): actual = sess.run(get_next) - self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertIsInstance(actual, sparse_tensor.SparseTensorValue) self.assertSparseValuesEqual(actual, _sparse(i)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ 
-649,7 +649,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): sess.run(init_op) for i in range(10): actual = sess.run(get_next) - self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertIsInstance(actual, sparse_tensor.SparseTensorValue) self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval()) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -783,19 +783,72 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertTrue(all(tids[0] == tid for tid in tids)) # pylint: enable=g-long-lambda + @parameterized.named_parameters( + ("SequentialIdentity", None, lambda x: x, None), + ("SequentialReplicate", None, lambda x: (x, x), None), + ("SequentialSwap", (None, None), lambda x, y: (y, x), None), + ("SequentialProject", (None, None), lambda x, y: x, None), + ("ParallelIdentity", None, lambda x: x, 10), + ("ParallelReplicate", None, lambda x: (x, x), 10), + ("ParallelSwap", (None, None), lambda x, y: (y, x), 10), + ("ParallelProject", (None, None), lambda x, y: x, 10), + ) + def testShortCircuit(self, structure, map_fn, num_parallel_calls): + dataset = self.structuredDataset(structure).repeat().map( + map_fn, num_parallel_calls=num_parallel_calls) + get_next = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + if isinstance(structure, tuple): + expected = map_fn(*sess.run(self.structuredElement(structure))) + else: + expected = map_fn(sess.run(self.structuredElement(structure))) + self.assertEqual(expected, sess.run(get_next)) + + @parameterized.named_parameters( + ("Sequential", None), + ("Parallel", 10), + ) + def testShortCircuitCapturedInput(self, num_parallel_calls): + captured_t = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = self.structuredDataset(None).repeat().map( + lambda x: captured_t, num_parallel_calls=num_parallel_calls) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + 
+ with self.cached_session() as sess: + sess.run(iterator.initializer, feed_dict={captured_t: 42}) + self.assertEqual(42, sess.run(get_next)) + class MapDatasetBenchmark(test.Benchmark): def benchmarkChainOfMaps(self): chain_lengths = [0, 1, 2, 5, 10, 20, 50] for chain_length in chain_lengths: - for use_inter_op_parallelism in [False, True]: + for mode in ["general", "single-threaded", "short-circuit"]: + if mode == "general": + map_fn = lambda x: x + 1 + use_inter_op_parallelism = True + print_label = "" + benchmark_label = "" + if mode == "single-threaded": + map_fn = lambda x: x + 1 + use_inter_op_parallelism = False + print_label = " (single threaded mode)" + benchmark_label = "_single_threaded" + if mode == "short-circuit": + map_fn = lambda x: x + use_inter_op_parallelism = True # should not have any significance + print_label = " (short circuit mode)" + benchmark_label = "_short_circuit" + with ops.Graph().as_default(): dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) for _ in range(chain_length): dataset = dataset_ops.MapDataset( dataset, - lambda x: x, + map_fn, use_inter_op_parallelism=use_inter_op_parallelism) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() @@ -813,25 +866,39 @@ class MapDatasetBenchmark(test.Benchmark): median_wall_time = np.median(deltas) / 100 print("Map dataset chain length%s: %d Median wall time: %f" % - (" (single threaded mode)" if not use_inter_op_parallelism - else "", chain_length, median_wall_time)) + (print_label, chain_length, median_wall_time)) self.report_benchmark( iters=1000, wall_time=median_wall_time, name="benchmark_map_dataset_chain_latency_%d%s" % - (chain_length, "_single_threaded" - if not use_inter_op_parallelism else "")) + (chain_length, benchmark_label)) def benchmarkMapFanOut(self): fan_outs = [1, 2, 5, 10, 20, 50, 100] for fan_out in fan_outs: - for use_inter_op_parallelism in [False, True]: + for mode in ["general", "single-threaded", "short-circuit"]: + if mode == 
"general": + map_fn = lambda *xs: [x + 1 for x in xs] + use_inter_op_parallelism = True + print_label = "" + benchmark_label = "" + if mode == "single-threaded": + map_fn = lambda *xs: [x + 1 for x in xs] + use_inter_op_parallelism = False + print_label = " (single threaded mode)" + benchmark_label = "_single_threaded" + if mode == "short-circuit": + map_fn = lambda *xs: xs + use_inter_op_parallelism = True # should not have any significance + print_label = " (short circuit mode)" + benchmark_label = "_short_circuit" + with ops.Graph().as_default(): dataset = dataset_ops.Dataset.from_tensors( tuple(0 for _ in range(fan_out))).repeat(None) dataset = dataset_ops.MapDataset( dataset, - lambda *xs: xs, + map_fn, use_inter_op_parallelism=use_inter_op_parallelism) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() @@ -849,14 +916,12 @@ class MapDatasetBenchmark(test.Benchmark): median_wall_time = np.median(deltas) / 100 print("Map dataset fan out%s: %d Median wall time: %f" % - (" (single threaded mode)" if not use_inter_op_parallelism - else "", fan_out, median_wall_time)) + (print_label, fan_out, median_wall_time)) self.report_benchmark( iters=1000, wall_time=median_wall_time, - name="benchmark_map_dataset_fan_out_%d%s" % - (fan_out, "_single_threaded" - if not use_inter_op_parallelism else "")) + name="benchmark_map_dataset_fan_out_%d%s" % (fan_out, + benchmark_label)) if __name__ == "__main__": diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index b730e10949..b73a94e683 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -19,10 +19,13 @@ from __future__ import print_function import re +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes from tensorflow.python.framework import 
errors from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -107,3 +110,29 @@ class DatasetTestBase(test.TestCase): with self.assertRaisesRegexp(exception_class, re.escape(expected_message)): self.evaluate(next2()) + + def structuredDataset(self, structure, shape=None, dtype=dtypes.int64): + """Returns a singleton dataset with the given structure.""" + if shape is None: + shape = [] + if structure is None: + return dataset_ops.Dataset.from_tensors( + array_ops.zeros(shape, dtype=dtype)) + else: + return dataset_ops.Dataset.zip( + tuple([ + self.structuredDataset(substructure, shape, dtype) + for substructure in structure + ])) + + def structuredElement(self, structure, shape=None, dtype=dtypes.int64): + """Returns an element with the given structure.""" + if shape is None: + shape = [] + if structure is None: + return array_ops.zeros(shape, dtype=dtype) + else: + return tuple([ + self.structuredElement(substructure, shape, dtype) + for substructure in structure + ]) -- cgit v1.2.3