author    | Brennan Saeta <saeta@google.com>                | 2018-10-09 11:54:20 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 11:58:43 -0700
commit    | 072fcb995a3fd658ee2461b59b159498c710513d (patch)
tree      | f3def3d3ac6e270ad32e428889a79d662c8bc9cf
parent    | 12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (diff)
[tf.data] NUMA-aware MapAndBatch dataset.
PiperOrigin-RevId: 216395709
22 files changed, 1909 insertions, 81 deletions
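The patch has three parts: a hidden experimental op (`ExperimentalNumaMapAndBatchDataset`) with its api_def, a grappler rewrite (`map_and_batch_numa_aware_replacement`) that swaps `MapAndBatchDatasetV2` nodes for the new op, and a kernel that runs one block of worker threads per NUMA node, pins each block's threads to its node, and hands successive batches to the blocks in round-robin order. The standalone C++ sketch below (illustrative names only, not code from the patch) shows just that round-robin hand-off and the matching in-order consumption, which lets the kernel spread work across NUMA nodes while still returning batches in their original order.

#include <cstdio>
#include <cstddef>
#include <deque>
#include <vector>

// One entry per NUMA node. `pending` stands in for the block's bounded window
// of in-flight batches (the real kernel uses a kWindowSize ring buffer plus
// condition variables rather than an unbounded deque).
struct Block {
  std::deque<int> pending;
};

int main() {
  const std::size_t num_blocks = 2;  // e.g. a host with two NUMA nodes
  std::vector<Block> blocks(num_blocks);

  // Runner side: batch i is handed to block i % num_blocks.
  for (int i = 0; i < 7; ++i) {
    blocks[i % num_blocks].pending.push_back(i);
  }

  // Client side: GetNext-style consumption cycles the blocks in the same
  // order, so the batches come back as 0, 1, 2, ... even though they were
  // produced on different nodes.
  std::size_t cur_block = 0;
  for (int i = 0; i < 7; ++i) {
    Block& b = blocks[cur_block];
    std::printf("batch %d from block %zu\n", b.pending.front(), cur_block);
    b.pending.pop_front();
    cur_block = (cur_block + 1) % num_blocks;
  }
  return 0;
}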
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalNumaMapAndBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalNumaMapAndBatchDataset.pbtxt new file mode 100644 index 0000000000..243922d969 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalNumaMapAndBatchDataset.pbtxt @@ -0,0 +1,58 @@ +op { + graph_op_name: "ExperimentalNumaMapAndBatchDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <<END +A variant tensor representing the input dataset. +END + } + in_arg { + name: "other_arguments" + description: <<END +A list of tensors, typically values that were captured when building a closure +for `f`. +END + } + in_arg { + name: "batch_size" + description: <<END +A scalar representing the number of elements to accumulate in a +batch. It determines the number of concurrent invocations of `f` that process +elements from `input_dataset` in parallel. +END + } + in_arg { + name: "num_parallel_calls" + description: <<END +A scalar representing the maximum number of parallel invocations of the `map_fn` +function. Applying the `map_fn` on consecutive input elements in parallel has +the potential to improve input pipeline throughput. +END + } + in_arg { + name: "drop_remainder" + description: <<END +A scalar representing whether the last batch should be dropped in case its size +is smaller than desired. +END + } + attr { + name: "f" + description: <<END +A function to apply to the outputs of `input_dataset`. +END + } + summary: "Creates a dataset that fuses mapping with batching." + description: <<END +Creates a dataset that applies `f` to the outputs of `input_dataset` and then +batches `batch_size` of them. + +Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up +to `batch_size * num_parallel_batches` copies of `f` in parallel. + +Unlike "MapAndBatchDatasetV2", this dataset uses a NUMA-aware thread scheduling +policy. Because it uses the single-threaded executor, it only supports the +function-based control flow ops. 
+END +} diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index eae0fa70e8..9596252664 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -335,7 +335,7 @@ class Model { if (name_ == "Map") { return Type::MAP; } - if (name_ == "MapAndBatch") { + if (name_ == "MapAndBatch" || name_ == "NumaMapAndBatch") { return Type::MAP_AND_BATCH; } if (name_ == "PaddedBatch") { diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index ee7c14e3ab..1c553044a8 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -415,6 +415,40 @@ tf_cc_test( ) cc_library( + name = "map_and_batch_numa_aware_replacement", + srcs = ["map_and_batch_numa_aware_replacement.cc"], + hdrs = ["map_and_batch_numa_aware_replacement.h"], + visibility = ["//visibility:public"], + deps = [ + ":graph_utils", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "map_and_batch_numa_aware_replacement_test", + srcs = ["map_and_batch_numa_aware_replacement_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":graph_test_utils", + ":graph_utils", + ":map_and_batch_numa_aware_replacement", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + ], +) + +cc_library( name = "noop_elimination", srcs = ["noop_elimination.cc"], hdrs = [ @@ -490,6 +524,7 @@ cc_library( ":hoist_random_uniform", ":latency_all_edges", ":map_and_batch_fusion", + ":map_and_batch_numa_aware_replacement", ":map_and_filter_fusion", ":map_fusion", ":map_parallelization", diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc index b2eec7220e..1f03c6515c 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { namespace grappler { @@ -44,6 +45,21 @@ NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, {"output_types", gtl::ArraySlice<TensorShape>{}}}); } +NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name, + StringPiece batch_size_node_name, + StringPiece num_parallel_calls_node_name, + StringPiece drop_remainder_node_name, + StringPiece function_name) { + return test::function::NDef( + name, "MapAndBatchDatasetV2", + {string(input_node_name), "", string(batch_size_node_name), + string(num_parallel_calls_node_name), string(drop_remainder_node_name)}, + {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<TensorShape>{}}}); +} + } // end namespace graph_tests_utils } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h index ca0fde997d..f7891d5e1f 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h @@ -29,6 +29,12 @@ NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, StringPiece function_name = "IsZero"); +NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name, + StringPiece batch_size_node_name, + StringPiece num_parallel_calls_node_name, + StringPiece drop_remainder_node_name, + StringPiece function_name = "XTimesTwo"); + } // end namespace graph_tests_utils } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc new file mode 100644 index 0000000000..452089eb67 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h" + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" + +namespace tensorflow { +namespace grappler { +namespace { + +NodeDef MakeNumaAware(const NodeDef& node, MutableGraphView* graph) { + NodeDef numa_aware_node = node; + graph_utils::SetUniqueGraphNodeName("map_and_batch_numa_aware", + graph->GetGraph(), &numa_aware_node); + numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset"); + return numa_aware_node; +} + +} // namespace + +Status MapAndBatchNumaAwareReplacement::Optimize(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output) { + *output = item.graph; + MutableGraphView graph(output); + std::set<string> nodes_to_delete; + + for (const NodeDef& node : item.graph.node()) { + if (node.op() != "MapAndBatchDatasetV2") continue; + + auto* numa_node = graph.AddNode(MakeNumaAware(node, &graph)); + graph.ReplaceInput(node, *numa_node); + nodes_to_delete.insert(node.name()); + } + graph.DeleteNodes(nodes_to_delete); + return Status::OK(); +} + +REGISTER_GRAPH_OPTIMIZER_AS(MapAndBatchNumaAwareReplacement, + "map_and_batch_numa_aware_replacement"); + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h new file mode 100644 index 0000000000..3b2acd288b --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +class MapAndBatchNumaAwareReplacement : public CustomGraphOptimizer { + public: + MapAndBatchNumaAwareReplacement() = default; + ~MapAndBatchNumaAwareReplacement() override = default; + + string name() const override { + return "map_and_batch_numa_aware_replacement"; + } + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override {} +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_ diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc new file mode 100644 index 0000000000..3c5c61d1c2 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h" + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +TEST(MapAndBatchNumaAwareReplacementTest, ReplaceSimple) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + { + NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + NDef("batch_size", "Const", {}, {{"value", 3}, {"dtype", DT_INT32}}), + NDef("num_parallel_calls", "Const", {}, + {{"value", 5}, {"dtype", DT_INT32}}), + NDef("drop_remainder", "Const", {}, + {{"value", 0}, {"dtype", DT_BOOL}}), + graph_tests_utils::MakeMapAndBatchNode( + "map_and_batch", "range", "batch_size", "num_parallel_calls", + "drop_remainder"), + }, + // FunctionLib + { + test::function::XTimesTwo(), + }); + + MapAndBatchNumaAwareReplacement optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map_and_batch", output)); + EXPECT_FALSE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp( + "ExperimentalNumaMapAndBatchDataset", output)); +} + +TEST(MapAndBatchNumaAawareReplacementTest, ReplaceWithExtraChild) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + { + NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + NDef("batch_size", "Const", {}, {{"value", 3}, {"dtype", DT_INT32}}), + NDef("num_parallel_calls", "Const", {}, + {{"value", 5}, {"dtype", DT_INT32}}), + NDef("drop_remainder", "Const", {}, + {{"value", 0}, {"dtype", DT_BOOL}}), + graph_tests_utils::MakeMapAndBatchNode( + "map_and_batch", "range", "batch_size", "num_parallel_calls", + "drop_remainder"), + NDef("cache", "CacheDataset", {"map_and_batch"}, {}), + }, + // FunctionLib + { + test::function::XTimesTwo(), + }); + + MapAndBatchNumaAwareReplacement optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map_and_batch", output)); + EXPECT_FALSE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp( + "ExperimentalNumaMapAndBatchDataset", output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("CacheDataset", output)); + + int numa_map_and_batch_component_id = graph_utils::FindGraphNodeWithOp( + "ExperimentalNumaMapAndBatchDataset", output); + auto& numa_map_and_batch_component = + output.node(numa_map_and_batch_component_id); + 
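// The optimizer copies the original node (so its inputs, e.g. the "range"
// dataset, carry over), gives the copy a unique name, swaps the op type to
// ExperimentalNumaMapAndBatchDataset, and lets ReplaceInput() re-point
// downstream consumers such as the CacheDataset at the new node; that is
// what the two EXPECT_EQ checks below verify.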
EXPECT_EQ(numa_map_and_batch_component.input(0), "range"); + + int cache_id = graph_utils::FindGraphNodeWithOp("CacheDataset", output); + auto& cache_node = output.node(cache_id); + EXPECT_EQ(cache_node.input(0), numa_map_and_batch_component.name()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 43406db3ed..4cf5643bc0 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -103,6 +103,22 @@ tf_kernel_library( ) tf_kernel_library( + name = "numa_map_and_batch_dataset_op", + srcs = ["numa_map_and_batch_dataset_op.cc"], + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:inplace_ops", + "//tensorflow/core/kernels/data:captured_function", + "//tensorflow/core/kernels/data:dataset", + "@com_google_absl//absl/memory", + ], +) + +tf_kernel_library( name = "unique_dataset_op", srcs = ["unique_dataset_op.cc"], deps = [ @@ -132,6 +148,7 @@ tf_kernel_library( ":ignore_errors_dataset_op", ":indexed_dataset", ":lmdb_dataset_op", + ":numa_map_and_batch_dataset_op", ":prefetching_kernels", ":threadpool_dataset_op", ":unique_dataset_op", diff --git a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc new file mode 100644 index 0000000000..d83edb9667 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc @@ -0,0 +1,1135 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#define EIGEN_USE_THREADS + +#include <atomic> +#include <utility> + +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/captured_function.h" +#include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/inplace_ops_functor.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/numa.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace data { +namespace { + +// kWindowSize is the fixed constant controlling the number of batch outputs +// each NumaWorkerBlock may be processing at a time. 
This is currently a +// constant and not user configurable to enable future performance optimizations +// in the implementation. +const int64 kWindowSize = 10; + +// Define a helper for more consistent logging. +#define WORKER_VLOG(verbose_level) \ + VLOG(verbose_level) << "WorkerThread (" << numa_node << ", " << thread_num \ + << "): " + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. + +class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel { + public: + explicit NumaMapAndBatchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + int64 batch_size; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size)); + OP_REQUIRES( + ctx, batch_size > 0, + errors::InvalidArgument("batch_size must be greater than zero.")); + + int64 num_parallel_calls; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", + &num_parallel_calls)); + OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune, + errors::InvalidArgument( + "num_parallel_calls must be greater than zero.")); + + bool drop_remainder; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "drop_remainder", &drop_remainder)); + + std::unique_ptr<CapturedFunction> captured_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + /* use_inter_op_parallelism = */ false, + &captured_func)); + + *output = new Dataset(ctx, input, batch_size, num_parallel_calls, + drop_remainder, output_types_, output_shapes_, func_, + std::move(captured_func)); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, + int64 num_parallel_calls, bool drop_remainder, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes, + const NameAttrList& func, + std::unique_ptr<CapturedFunction> captured_func) + : DatasetBase(DatasetContext(ctx)), + input_(input), + batch_size_(batch_size), + num_parallel_calls_(num_parallel_calls), + drop_remainder_(drop_remainder), + output_types_(output_types), + output_shapes_(output_shapes), + func_(func), + captured_func_(std::move(captured_func)) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::NumaMapAndBatch")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "NumaMapAndBatchDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* batch_size_node; + TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node)); + Node* num_parallel_calls_node; + TF_RETURN_IF_ERROR( + 
b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); + Node* drop_remainder_node; + TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node)); + + DataTypeVector other_arguments_types; + other_arguments_types.reserve(captured_func_->captured_inputs().size()); + std::vector<Node*> other_arguments; + other_arguments.reserve(captured_func_->captured_inputs().size()); + for (const Tensor& t : captured_func_->captured_inputs()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + other_arguments.emplace_back(node); + other_arguments_types.emplace_back(t.dtype()); + } + AttrValue f; + b->BuildAttrValue(func_, &f); + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {std::make_pair(0, input_graph_node), + std::make_pair(2, batch_size_node), + std::make_pair(3, num_parallel_calls_node), + std::make_pair(4, drop_remainder_node)}, // Single tensor inputs. + {std::make_pair(1, other_arguments)}, // Tensor list inputs. + {std::make_pair("f", f), + std::make_pair("Targuments", other_arguments_types_attr)}, // Attrs + output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + mu_(std::make_shared<mutex>()), + autotune_cond_var_(std::make_shared<condition_variable>()), + num_parallel_calls_(std::make_shared<model::SharedState>( + params.dataset->num_parallel_calls_, mu_, autotune_cond_var_)) { + } + + ~Iterator() override { + mutex_lock l(*mu_); + cancelled_ = true; + VLOG(3) << "NumaMapAndBatchIterator::~Iterator: cancelling operations."; + for (size_t i = 0; i < workers_.size(); ++i) { + workers_[i]->manager.Cancel(); + } + VLOG(3) << "NumaMapAndBatchIterator::~Iterator: waiting for threads to " + "shut down."; + } + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(*mu_); + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); + if (num_parallel_calls_->value == kAutoTune) { + num_parallel_calls_->value = std::max(1, port::NUMANumNodes()); + AddTunableParameter(ctx, + /* name = */ "parallelism", + /* state = */ num_parallel_calls_, + /* min = */ num_parallel_calls_->value, + /* max = */ port::NumSchedulableCPUs()); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value); + } + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(ctx)); + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + auto cleanup = gtl::MakeCleanup( + [] { VLOG(3) << "GetNextInternal call returning."; }); + NumaWorkerBlock* worker = nullptr; + { + mutex_lock l(*mu_); + VLOG(3) << "GetNextInternal call; current block: " << cur_block_; + if (global_end_of_input_) { + *end_of_sequence = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR(EnsureBackgroundThreadsStarted(ctx)); + worker = workers_[cur_block_].get(); + cur_block_ = (cur_block_ + 1) % workers_.size(); + } + TF_RETURN_IF_ERROR(worker->manager.GetBatch( + ctx, dataset()->drop_remainder_, &global_end_of_input_, out_tensors, + end_of_sequence)); + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(*mu_); + for (size_t i = 0; i < workers_.size(); ++i) { + if (!workers_[i]->manager.Quiesce()) { + 
return errors::Cancelled( + "The iterator was deleted before it could reach a " + "checkpointable state."); + } + } + + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("num_workers"), workers_.size())); + + for (size_t i = 0; i < workers_.size(); ++i) { + size_t index = (cur_block_ + i) % workers_.size(); + TF_RETURN_IF_ERROR(workers_[index]->manager.Save(writer, this, i)); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(*mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + int64 num_workers = -1; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("num_workers"), &num_workers)); + // Note: num_workers can be 0 if the iterator wasn't started when + // first checkpointed. + if (num_workers < 0) { + return errors::DataLoss( + "When restoring from checkpoint, we encountered a data " + "consistency error: num_workers has an invalid value: ", + num_workers); + } + if (port::NUMAEnabled()) { + int actual_numa_domains = port::NUMANumNodes(); + if (actual_numa_domains != num_workers && num_workers > 0) { + LOG(WARNING) << "# NUMA domains mismatch when restoring from " + "checkpoint: checkpoint has " + << num_workers + << " NUMA domains, while this host has: " + << actual_numa_domains << " NUMA domains."; + } + } + if (num_workers > 1 && !port::NUMAEnabled()) { + LOG(WARNING) << "NUMA is not enabled for this process, but restoring " + "a checkpoint that assumes " + << num_workers << " NUMA domains."; + } + workers_.resize(num_workers); + for (size_t i = 0; i < num_workers; ++i) { + workers_[i] = MakeUnique<NumaWorkerBlock>(this); + TF_RETURN_IF_ERROR( + workers_[i]->manager.Restore(ctx, reader, this, i)); + } + cur_block_ = 0; + return Status::OK(); + } + + private: + // NumaBlockManager manages all the state for a set of threads pinned to a + // single NUMA domain. + // + // The methods can be divided into 3 categories based on who should call + // them: + // + // (1) RunnerThread: WaitForInputSpace, PushInputs, SetEndOfInput. + // (2) WorkerThread: RetrieveInput, GetBatchTensors. + // RecordBatchEntryComplete + // (3) Client threads: GetBatch, Cancel, Save, Restore. + // + // Internally, we manage state in a circular buffer of size `kWindowSize`. + // There are 3 pointers into the circular buffer, and must maintain the + // following order: (1) next_input_batch_ (corresponding to the next input + // batch to be pulled from the input iterator), (2) next_input_ + // (corresponding to the batch the WorkerThreads should pull from for + // their next inputs), and (3) next_output_ corresponding to the next + // value to be consumed by the output iterator. + // + // Methods return errors::Cancelled if the iteration is cancelled before + // completing. + // + // NumaBlockManager is thread safe. + class NumaBlockManager { + public: + explicit NumaBlockManager(Iterator* itr) : itr_(itr) {} + + // WaitForInputSpace blocks until there is space in the circular buffer + // to begin processing a new batch of elements. + // + // Returns true when there is space, false if the Iterator is cancelled. + bool WaitForInputSpace(IteratorContext* ctx) { + mutex_lock l(mu_); + + size_t next = (next_input_batch_ + 1) % kWindowSize; + DCHECK(next < kWindowSize) << next; + + // Wait for space in the circular buffer. 
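// The runner may only reuse slot `next` once the client has drained it:
// GetBatch() resets a consumed slot back to BatchState::kEmpty and notifies
// runner_cond_var_, which is the signal this wait is blocked on.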
+ while (!cancelled_ && batches_[next].state != BatchState::kEmpty) { + VLOG(3) << "Waiting for input space; next: " << next + << ", next_output_: " << next_output_ + << ", next_input_batch_: " << next_input_batch_; + itr_->RecordStop(ctx); + runner_cond_var_.wait(l); + itr_->RecordStart(ctx); + } + if (cancelled_) { + VLOG(3) << "WaitForInputSpace cancelled."; + return false; + } + + DCHECK(batches_[next].state == BatchState::kEmpty); + + next_input_batch_ = next; + return true; + } + + // PushInputs sets the inputs for the next batch as retrieved from the + // input iterator. + void PushInputs(const Status& status, + std::vector<std::vector<Tensor>> inputs) { + mutex_lock l(mu_); + + DCHECK(next_input_ < kWindowSize) << next_input_; + DCHECK(batches_[next_input_batch_].state == BatchState::kEmpty); + DCHECK(batches_[next_input_batch_].next_input_to_process == 0) + << batches_[next_input_batch_].next_input_to_process; + DCHECK(batches_[next_input_batch_].status.ok()) + << batches_[next_input_batch_].status; + + batches_[next_input_batch_].inputs.swap(inputs); + batches_[next_input_batch_].state = BatchState::kInputsFilled; + batches_[next_input_batch_].status.Update(status); + if (batches_[next_input_batch_].status.ok()) { + worker_cond_var_.notify_all(); + } else { + client_cond_var_.notify_all(); + batches_[next_input_batch_].error_index = 0; + } + } + + // SetEndOfInput records the fact that we have reached the end of the + // input iterator, and that we should return end_of_sequence = true when + // we have exhaused all buffered batches. + void SetEndOfInput() { + mutex_lock l(mu_); + reached_eof_ = true; + worker_cond_var_.notify_all(); + client_cond_var_.notify_all(); + } + + // RetrieveInput gets the next input tuple to be mapped by a worker + // thread. + // + // Returns true if an input was retrieved, false if the iterator has + // been cancelled. + bool RetrieveInput(IteratorContext* ctx, std::vector<Tensor>* input, + uint64* index, size_t* sequence_number) { + mutex_lock l(mu_); + + // Wait for inputs to be ready. + while (!cancelled_ && + batches_[next_input_].state != BatchState::kInputsFilled) { + itr_->RecordStop(ctx); + worker_cond_var_.wait(l); + itr_->RecordStart(ctx); + } + + if (cancelled_) { + return false; + } + + DCHECK(batches_[next_input_].next_input_to_process < + batches_[next_input_].inputs.size()) + << "next_input_: " << next_input_ << ", next_input_to_process: " + << batches_[next_input_].next_input_to_process + << ", inputs.size(): " << batches_[next_input_].inputs.size() + << ", state: " << static_cast<int32>(batches_[next_input_].state) + << ", this: " << this; + *index = batches_[next_input_].next_input_to_process; + *sequence_number = next_input_; + input->swap(batches_[next_input_] + .inputs[batches_[next_input_].next_input_to_process]); + // Increment pointers. + batches_[next_input_].next_input_to_process++; + + if (batches_[next_input_].next_input_to_process == + batches_[next_input_].inputs.size()) { + batches_[next_input_].state = BatchState::kAllMapsStarted; + next_input_ = (next_input_ + 1) % kWindowSize; + } + return true; + } + + // GetBatchTensors returns a pointer to the output batch tensors for the + // worker thread to copy into. + // + // allocate_output is a function taking a batch size, and a pointer to + // the output tuple of Tensors to allocate them. The allocate_output + // function is called at most once per output batch. 
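// (The "at most once" guarantee comes from doing the allocation lazily under
// mu_: the first worker to reach this call for a given sequence_number finds
// `outputs` empty and runs allocate_output; later workers for the same batch
// reuse the tensors it allocated.)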
+ std::vector<Tensor>* GetBatchTensors( + size_t sequence_number, + std::function<void(size_t, std::vector<Tensor>*)> allocate_output) { + mutex_lock l(mu_); + DCHECK(sequence_number < kWindowSize) << sequence_number; + DCHECK(batches_[sequence_number].state == BatchState::kInputsFilled || + batches_[sequence_number].state == BatchState::kAllMapsStarted) + << sequence_number; + + if (batches_[sequence_number].outputs.empty()) { + allocate_output(batches_[sequence_number].inputs.size(), + &batches_[sequence_number].outputs); + } + return &batches_[sequence_number].outputs; + } + + // RecordBatchEntryComplete records an element of the batch has finished + // copying into the output tensors. + void RecordBatchEntryComplete(size_t sequence_number, uint64 index, + Status s) { + mutex_lock l(mu_); + DCHECK(sequence_number < kWindowSize) << sequence_number; + DCHECK(batches_[sequence_number].state == BatchState::kInputsFilled || + batches_[sequence_number].state == BatchState::kAllMapsStarted) + << sequence_number; + + batches_[sequence_number].num_outputs_complete++; + if (!s.ok() && batches_[sequence_number].error_index > index) { + batches_[sequence_number].status = s; + batches_[sequence_number].error_index = index; + } + + if (batches_[sequence_number].num_outputs_complete == + batches_[sequence_number].inputs.size()) { + DCHECK(batches_[sequence_number].state == + BatchState::kAllMapsStarted); + batches_[sequence_number].state = BatchState::kOutputsComplete; + batches_[sequence_number].inputs.clear(); // Eagerly save memory. + batches_[sequence_number].inputs.shrink_to_fit(); + client_cond_var_.notify_all(); + } + } + + // GetBatch retrieves the next output batch tensors. + Status GetBatch(IteratorContext* ctx, bool drop_remainder, + bool* global_eof, std::vector<Tensor>* out_tensor, + bool* end_of_sequence) { + mutex_lock l(mu_); + // Wait until one of 3 conditions occurs: + // (1) we're cancelled. + // (2) the state becomes kOutputsComplete + // (3) state is empty && reached_eof. + while (!cancelled_ && + batches_[next_output_].state != BatchState::kOutputsComplete && + !(reached_eof_ && + batches_[next_output_].state == BatchState::kEmpty)) { + VLOG(3) << "Waiting in GetBatch."; + itr_->RecordStop(ctx); + client_cond_var_.wait(l); + itr_->RecordStart(ctx); + } + + if (cancelled_) { + return errors::Cancelled( + "Cancelled in NumaMapAndBatch::GetNext call."); + } + + if (reached_eof_ && + batches_[next_output_].state == BatchState::kEmpty) { + VLOG(4) << "GetBatch returning end of sequence."; + *end_of_sequence = true; + *global_eof = true; + return Status::OK(); + } + + VLOG(3) << "Returning output index: " << next_output_ + << ", this: " << this; + + *end_of_sequence = false; + Status s = batches_[next_output_].status; + if (s.ok()) { + out_tensor->swap(batches_[next_output_].outputs); + } + // Handle early termination. 
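// An OutOfRange status means the batch was cut short at error_index: rows
// [0, error_index) hold valid results. Unless drop_remainder is set or
// nothing was produced at all (error_index == 0), each output component is
// sliced down to error_index rows via CopyPartialBatch before being
// returned; e.g. error_index == 1 on a batch_size == 3 batch yields a final
// one-row batch.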
+ if (errors::IsOutOfRange(s)) { + *global_eof = true; + s = Status::OK(); + if (drop_remainder || batches_[next_output_].error_index == 0) { + *end_of_sequence = true; + } else { + std::vector<Tensor> true_outputs; + for (size_t i = 0; i < batches_[next_output_].outputs.size(); + ++i) { + TensorShape component_shape( + batches_[next_output_].outputs[i].shape()); + component_shape.set_dim(0, batches_[next_output_].error_index); + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + Tensor component(ctx->allocator(attr), + batches_[next_output_].outputs[i].dtype(), + component_shape); + TF_RETURN_IF_ERROR(CopyPartialBatch( + &component, batches_[next_output_].outputs[i], + batches_[next_output_].error_index)); + true_outputs.emplace_back(std::move(component)); + } + out_tensor->swap(true_outputs); + } + } + + batches_[next_output_].Reset(); + next_output_ = (next_output_ + 1) % kWindowSize; + runner_cond_var_.notify_all(); + + return s; + } + + void Cancel() { + mutex_lock l(mu_); + VLOG(3) << "Cancelling NUMA block."; + cancelled_ = true; + runner_cond_var_.notify_all(); + worker_cond_var_.notify_all(); + client_cond_var_.notify_all(); + } + + // Waits until all the worker threads have completed their work and all + // internal state has reached a "safe-point" where we can safely + // checkpoint. + // + // Returns true if completed successfully, false if cancelled while + // waiting. + bool Quiesce() { + mutex_lock l(mu_); + VLOG(3) << "Waiting until the operations have quiesced."; + while (!cancelled_ && !AllMapOperationsFinished()) { + client_cond_var_.wait(l); + } + if (cancelled_) { + return false; + } + return true; + } + + Status Save(IteratorStateWriter* writer, Iterator* itr, size_t index) { + mutex_lock l(mu_); + string prefix = itr->full_name(strings::StrCat("numa_block_", index)); + if (reached_eof_) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + strings::StrCat(prefix, "_end_of_input"), "")); + } + for (size_t i = 0; i < kWindowSize; ++i) { + size_t index = (next_output_ + i) % kWindowSize; + if (batches_[index].state == BatchState::kEmpty) { + break; + } + string batch_prefix = strings::StrCat(prefix, "_batch_", i); + TF_RETURN_IF_ERROR(writer->WriteScalar( + strings::StrCat(batch_prefix, "_code"), + static_cast<int64>(batches_[index].status.code()))); + if (!batches_[index].status.ok()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(strings::StrCat(batch_prefix, "_msg"), + batches_[index].status.error_message())); + TF_RETURN_IF_ERROR(writer->WriteScalar( + strings::StrCat(batch_prefix, "_error_index"), + batches_[index].error_index)); + } + + TF_RETURN_IF_ERROR(writer->WriteScalar( + strings::StrCat(batch_prefix, "_output_size"), + batches_[index].outputs.size())); + for (size_t j = 0; j < batches_[index].outputs.size(); ++j) { + string tensor_prefix = + strings::StrCat(batch_prefix, "_output_", j); + if (!batches_[index].status.ok()) { + DCHECK(batches_[index].error_index >= 0 && + batches_[index].error_index < + itr_->dataset()->batch_size_); + // If the batch is not full, we only store the first + // `error_index` values. The rest of the batch tensor might not + // be initialized, and accessing that will raise msan errors. 
+ TF_RETURN_IF_ERROR(writer->WriteTensor( + tensor_prefix, batches_[index].outputs[j].Slice( + 0, batches_[index].error_index))); + } else { + TF_RETURN_IF_ERROR(writer->WriteTensor( + tensor_prefix, batches_[index].outputs[j])); + } + } + } + return Status::OK(); + } + + Status Restore(IteratorContext* ctx, IteratorStateReader* reader, + Iterator* itr, size_t index) { + mutex_lock l(mu_); + if (reached_eof_) { + return errors::FailedPrecondition( + "Already reached the end of the sequence."); + } + string prefix = itr->full_name(strings::StrCat("numa_block_", index)); + reached_eof_ = + reader->Contains(strings::StrCat(prefix, "_end_of_input")); + for (size_t i = 0; i < kWindowSize; ++i) { + string batch_prefix = strings::StrCat(prefix, "_batch_", i); + if (!reader->Contains(strings::StrCat(batch_prefix, "_code"))) { + break; + } + Batch batch; + batch.state = BatchState::kOutputsComplete; + int64 code_int; + TF_RETURN_IF_ERROR(reader->ReadScalar( + strings::StrCat(batch_prefix, "_code"), &code_int)); + error::Code code = static_cast<error::Code>(code_int); + if (code != error::Code::OK) { + string error_message; + TF_RETURN_IF_ERROR(reader->ReadScalar( + strings::StrCat(batch_prefix, "_msg"), &error_message)); + batch.status = Status(code, error_message); + int64 error_index_int = -1; + TF_RETURN_IF_ERROR(reader->ReadScalar( + strings::StrCat(batch_prefix, "_error_index"), + &error_index_int)); + if (error_index_int < 0 || + error_index_int > itr->dataset()->batch_size_) { + return errors::FailedPrecondition( + "Error index out of bounds when restoring from checkpoint; " + "error index: ", + error_index_int); + } + batch.error_index = static_cast<size_t>(error_index_int); + } + int64 output_size = -1; + TF_RETURN_IF_ERROR(reader->ReadScalar( + strings::StrCat(batch_prefix, "_output_size"), &output_size)); + batch.outputs.reserve(output_size); + for (size_t j = 0; j < output_size; ++j) { + string tensor_name = strings::StrCat(batch_prefix, "_output_", j); + Tensor t; + TF_RETURN_IF_ERROR(reader->ReadTensor(tensor_name, &t)); + batch.outputs.emplace_back(std::move(t)); + } + batches_[i] = std::move(batch); + } + return Status::OK(); + } + + private: + bool AllMapOperationsFinished() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + for (size_t i = 0; i < kWindowSize; ++i) { + if (batches_[i].state == BatchState::kInputsFilled || + batches_[i].state == BatchState::kAllMapsStarted) { + return false; + } + if (batches_[i].state != BatchState::kOutputsComplete && + !reached_eof_) { + return false; + } + } + return true; + } + + // Batches begin in the `kEmpty` state. Once the RunnerThread has + // filled the `inputs` to a `Batch`, it transitions to the + // `kInputsFilled` state. At this point, the Worker threads run the map + // function and copy the outputs appropriately. Once all worker threads + // have started, it transitions to `kAllMapsStarted`. After the outputs + // are complete, the GetNext call can consume the outputs, and return + // the batch to the kEmpty state. + enum class BatchState { + kEmpty, + kInputsFilled, + kAllMapsStarted, + kOutputsComplete, + }; + + // Batch captures all the state of an output batch as it progresses + // through the machinery. Once the RunnerThread fills inputs, it + // transitions to `kInputsFilled`. At this point, the worker threads can + // work on it, incrementing outputs_complete for every element of the + // input set that is copied into the output Tensors. Once all the input + // tuples have been processed (i.e. 
num_outputs_complete == + // inputs.size()), it transitions to the `kOutputsComplete` stage, where + // it is ready to be returned by a `GetBatch` call (called from + // `GetNextInternal`). + struct Batch { + BatchState state; + // Aggregates the Status of the input iterator's GetNext + // calls, in addition to the Status of the map function invocations. + // + // In the case where multiple non-OK statuses are encountered, we + // return the first one encountered. + Status status; + // In order to return the correct error status, we keep track of the + // error_index. + size_t error_index; + // The batch_size input tuples (or fewer in the case of the last + // batch). + // TODO(saeta): Avoid re-allocating vectors all the time! + std::vector<std::vector<Tensor>> inputs; + std::vector<Tensor> outputs; + size_t next_input_to_process; + size_t num_outputs_complete; + + Batch() { Reset(); } + + // Resets the Batch state (e.g. after consuming the outputs). + void Reset() { + state = BatchState::kEmpty; + status = Status::OK(); + inputs.clear(); + inputs.shrink_to_fit(); + outputs.clear(); + outputs.shrink_to_fit(); + next_input_to_process = 0; + num_outputs_complete = 0; + error_index = -1; + } + }; + + Iterator* itr_; // Not owned. + mutex mu_; + Batch batches_[kWindowSize] GUARDED_BY(mu_); + size_t next_input_batch_ GUARDED_BY(mu_) = -1; + size_t next_input_ GUARDED_BY(mu_) = 0; + size_t next_output_ GUARDED_BY(mu_) = 0; + bool cancelled_ GUARDED_BY(mu_) = false; + bool reached_eof_ GUARDED_BY(mu_) = false; + + // The runner thread waits on this condition variable for space to be + // available. When the client thread takes a value out of the circular + // buffer, it notifies this condition variable that space is now + // available. + condition_variable runner_cond_var_ GUARDED_BY(mu_); + // The worker threads wait on this condition variable for available + // inputs. When the runner thread makes new inputs available, it + // notifies this condition variable. + condition_variable worker_cond_var_ GUARDED_BY(mu_); + // The client threads wait on this condition variable for avaiable + // batched outputs. When worker threads complete a batch, they notify + // this condition variable. + condition_variable client_cond_var_ GUARDED_BY(mu_); + }; + // Mark NumaBlockManager as a friend of Iterator in order to call + // protected Iterator methods during checkpointing. + friend NumaBlockManager; + + struct NumaWorkerBlock { + NumaBlockManager manager; + // TODO(saeta): Migrate to BackgroundWorker. 
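// NumaWorkerBlock objects may be placement-constructed into memory obtained
// from port::NUMAMalloc for the block's node (see
// EnsureBackgroundThreadsStarted below); the matching
// CustomNumaWorkerBlockDeleter therefore invokes the destructor explicitly
// and releases the memory with port::NUMAFree instead of operator delete.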
+ std::vector<std::unique_ptr<Thread>> threads; + + explicit NumaWorkerBlock(Iterator* itr) : manager(itr) {} + }; + + static void CustomNumaWorkerBlockDeleter(NumaWorkerBlock* ptr) { + ptr->~NumaWorkerBlock(); + port::NUMAFree(ptr, sizeof(NumaWorkerBlock)); + } + static void DefaultNumaWorkerBlockDeleter(NumaWorkerBlock* ptr) { + delete ptr; + } + + static Status CopyPartialBatch(Tensor* output, const Tensor& value, + int64 num_elements) { + switch (value.dtype()) { +#define HANDLE_TYPE(type) \ + case DataTypeToEnum<type>::value: { \ + auto output_t = output->flat_outer_dims<type>(); \ + auto value_t = value.flat_outer_dims<type>(); \ + for (size_t i = 0; i < num_elements; i++) { \ + output_t.template chip<0>(i) = value_t.template chip<0>(i); \ + } \ + return Status::OK(); \ + } + TF_CALL_DATASET_TYPES(HANDLE_TYPE); +#undef HANDLE_TYPE + default: + return errors::InvalidArgument("Unsupported data type: ", + DataTypeString(value.dtype())); + } + return Status::OK(); + } + + Status EnsureBackgroundThreadsStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + if (curr_num_parallel_calls_ >= num_parallel_calls_->value) { + // All necessary threads have been started. + curr_num_parallel_calls_ = num_parallel_calls_->value; + return Status::OK(); + } + + VLOG(4) << "Starting workers"; + bool numa_enabled = port::NUMAEnabled(); + + if (!numa_enabled) { + LOG(INFO) << "NUMA not enabled on this host."; + } + + int num_numa_nodes = port::NUMANumNodes(); + if (num_numa_nodes < 1) { + return errors::Internal("The number of NUMA nodes is invalid: ", + num_numa_nodes); + } + + // Only resize when empty to support restoring from checkpoints. + if (workers_.empty()) { + VLOG(3) << "# NUMA Nodes: " << num_numa_nodes + << ", # Parallel Calls: " << num_parallel_calls_->value; + workers_.resize(num_numa_nodes); + } else { + num_numa_nodes = workers_.size(); + } + + // Round up num_parallel_calls, with a minimum of 1. + const size_t num_threads_per_block = + std::max(1LL, (num_parallel_calls_->value + num_numa_nodes - 1) / + num_numa_nodes); + + VLOG(3) << "Starting " << num_threads_per_block * num_numa_nodes + << " worker threads, with " << num_threads_per_block + << " threads per block."; + + // Only allocate new_ctx if required. + std::shared_ptr<IteratorContext> new_ctx; + + for (int i = 0; i < num_numa_nodes; ++i) { + if (!workers_[i]) { + if (numa_enabled) { + // Allocate in appropriate NUMA domain. + // 4k page align. + void* ptr = port::NUMAMalloc(i, sizeof(NumaWorkerBlock), 0); + if (ptr != nullptr) { + NumaWorkerBlock* block = new (ptr) NumaWorkerBlock(this); + workers_[i] = + std::unique_ptr<NumaWorkerBlock, + std::function<void(NumaWorkerBlock*)>>( + block, CustomNumaWorkerBlockDeleter); + } else { + LOG(ERROR) << "Could not NUMA-allocate worker block: " << i; + } + } + // If the NUMA allocation fails, or NUMA is not enabled. + if (!workers_[i]) { + workers_[i] = + std::unique_ptr<NumaWorkerBlock, + std::function<void(NumaWorkerBlock*)>>( + new NumaWorkerBlock(this), DefaultNumaWorkerBlockDeleter); + } + } + // Be sure to start threads if num_parallel_calls_ has changed. 
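// num_threads_per_block rounds the requested parallelism up to a whole
// number of threads per node: e.g. num_parallel_calls == 5 on a two-node
// host gives ceil(5 / 2) == 3 threads per block, i.e. 6 worker threads in
// total.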
+ for (size_t j = workers_[i]->threads.size(); + j < num_threads_per_block; ++j) { + VLOG(3) << "Starting worker " << i << ", " << j; + if (!new_ctx) { + new_ctx = std::make_shared<IteratorContext>(*ctx); + } + workers_[i]->threads.emplace_back(ctx->env()->StartThread( + {}, + strings::StrCat("numa_map_and_batch_block_", i, "_thread_", j), + [this, new_ctx, i, j]() { WorkerThread(new_ctx, i, j); })); + VLOG(3) << "Worker " << i << ", " << j << " successfully started."; + } + } + if (!runner_thread_) { + if (!new_ctx) { + new_ctx = std::make_shared<IteratorContext>(*ctx); + } + runner_thread_.reset(ctx->env()->StartThread( + {}, "numa_map_runner_thread", + [this, new_ctx] { RunnerThread(new_ctx); })); + } + VLOG(3) << "All workers & runner thread started."; + return Status::OK(); + } + + void AllocateOutput(IteratorContext* ctx, size_t batch_size, + const std::vector<Tensor>& map_fn_outputs, + std::vector<Tensor>* batch_outputs) { + DCHECK(dataset()->output_dtypes().size() == + dataset()->output_shapes().size()); + DCHECK(map_fn_outputs.size() == dataset()->output_dtypes().size()); + for (size_t i = 0; i < dataset()->output_dtypes().size(); ++i) { + TensorShape component_shape({static_cast<uint32>(batch_size)}); + component_shape.AppendShape(map_fn_outputs.at(i).shape()); + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + Tensor component(ctx->allocator(attr), map_fn_outputs.at(i).dtype(), + component_shape); + batch_outputs->emplace_back(std::move(component)); + } + } + + void RunnerThread(std::shared_ptr<IteratorContext> ctx) + LOCKS_EXCLUDED(mu_) { + RecordStart(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, &ctx] { + // Set end of input on all the managers in order to clean up in an + // orderly fashion. + VLOG(3) << "Setting End of Input on workers_[*]->manager"; + for (size_t i = 0; i < workers_.size(); ++i) { + workers_[i]->manager.SetEndOfInput(); + } + RecordStop(ctx.get()); + }); + + const size_t num_blocks = workers_.size(); + + while (true) { + for (size_t block = 0; block < num_blocks; ++block) { + VLOG(4) << "RunnerThread waiting for input space in block: " + << block; + if (TF_PREDICT_FALSE( + !workers_[block]->manager.WaitForInputSpace(ctx.get()))) { + VLOG(3) << "RunnerThread exiting due to cancellation."; + return; + } + VLOG(4) << "RunnerThread has space; pulling on upstream for block " + << block; + + Status s; + std::vector<std::vector<Tensor>> inputs; + bool end_of_sequence = false; + for (size_t i = 0; i < dataset()->batch_size_; ++i) { + std::vector<Tensor> tuple; + s.Update( + input_impl_->GetNext(ctx.get(), &tuple, &end_of_sequence)); + if (!s.ok()) { + break; + } + if (end_of_sequence) { + VLOG(4) << "Runner thread encountered end of sequence."; + if (dataset()->drop_remainder_) { + return; + } + break; + } + inputs.push_back(std::move(tuple)); + } + + VLOG(4) << "Moving inputs to block " << block + << ", which has size: " << inputs.size(); + if (!s.ok() || !inputs.empty()) { + workers_[block]->manager.PushInputs(s, std::move(inputs)); + VLOG(4) << "Inputs moved into block " << block; + } + if (end_of_sequence) { + return; + } + } + } + } + + void WorkerThread(std::shared_ptr<IteratorContext> ctx, + const int numa_node, const int thread_num) { + RecordStart(ctx.get()); + WORKER_VLOG(3) << "started."; + auto stop_cleanup = + gtl::MakeCleanup([this, numa_node, thread_num, &ctx]() { + RecordStop(ctx.get()); + WORKER_VLOG(3) << "exiting."; + }); + + NumaWorkerBlock* block = workers_[numa_node].get(); + port::NUMASetThreadNodeAffinity(numa_node); + 
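// Thread j of every block shares the activation threshold
// j * num_numa_nodes, so the autotuner wakes workers evenly across nodes:
// on a two-node host, threads with thread_num == 1 stay asleep until the
// parallelism value exceeds 2 (i.e. reaches 3).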
const int num_numa_nodes = port::NUMANumNodes(); + const int minimum_num_parallel_calls = thread_num * num_numa_nodes; + + while (true) { + // Put threads to sleep based on autotuner. + { + mutex_lock l(*mu_); + while (minimum_num_parallel_calls >= num_parallel_calls_->value && + !cancelled_) { + RecordStop(ctx.get()); + autotune_cond_var_->wait(l); + RecordStart(ctx.get()); + } + if (cancelled_) { + return; + } + } + + std::vector<Tensor> input; + uint64 index = 0; + size_t sequence_number = 0; + WORKER_VLOG(4) << "retrieving input."; + { + tracing::ScopedActivity trace( + "NumaMapAndBatch::Iterator::Worker::RetrieveInput"); + if (!block->manager.RetrieveInput(ctx.get(), &input, &index, + &sequence_number)) { + return; + } + } + + WORKER_VLOG(4) << "retrieved input; index: " << index + << ", sequence_number: " << sequence_number; + + std::vector<Tensor> return_values; + Status s; + { + tracing::ScopedActivity trace( + "NumaMapAndBatch::Iterator::Worker::FunctionExecution"); + s = dataset()->captured_func_->Run(ctx.get(), std::move(input), + &return_values); + } + WORKER_VLOG(4) << "ran function for index: " << index + << ", sequence_number: " << sequence_number; + + if (s.ok()) { + std::vector<Tensor>* output = block->manager.GetBatchTensors( + sequence_number, + [this, ctx, &return_values](size_t batch_size, + std::vector<Tensor>* output) { + AllocateOutput(ctx.get(), batch_size, return_values, output); + }); + WORKER_VLOG(4) << "copying tensors to batch output."; + { + tracing::ScopedActivity trace( + "NumaMapAndBatch::Iterator::Worker::BatchCopy"); + for (size_t i = 0; i < return_values.size() && s.ok(); ++i) { + Tensor& tensor = return_values.at(i); + Tensor* batch = &output->at(i); + if (tensor.NumElements() != + (batch->NumElements() / batch->dim_size(0))) { + s.Update(errors::InvalidArgument( + "Cannot add tensor to the batch: number of elements does " + "not match. Shapes are: [tensor]: ", + tensor.shape().DebugString(), + ", [batch]: ", batch->shape().DebugString())); + break; + } + s.Update(batch_util::CopyElementToSlice(std::move(tensor), + batch, index)); + } + } + } + + block->manager.RecordBatchEntryComplete(sequence_number, index, s); + WORKER_VLOG(4) << "finished index: " << index + << ", sequence_number: " << sequence_number; + } + } + + // mu_ protects shared internal state and is used to coordinate between + // the auto-tuner, client threads, worker threads, and the runner thread. + const std::shared_ptr<mutex> mu_; + const std::shared_ptr<condition_variable> autotune_cond_var_; + // The maximum number of parallel calls (can be auto-tuned). + const std::shared_ptr<model::SharedState> num_parallel_calls_; + + // Caches the last-seen value of num_parallel_calls_->value to + // short-circuit starting workers. + int64 curr_num_parallel_calls_ GUARDED_BY(*mu_) = 0; + + std::unique_ptr<IteratorBase> input_impl_; + int64 cur_block_ GUARDED_BY(*mu_) = 0; + bool global_end_of_input_ GUARDED_BY(*mu_) = false; + bool cancelled_ GUARDED_BY(*mu_) = false; + std::vector<std::unique_ptr<NumaWorkerBlock, + std::function<void(NumaWorkerBlock*)>>> + workers_; // Const after initialization. 
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); + }; + + const DatasetBase* const input_; + const int64 batch_size_; + const int64 num_parallel_calls_; + const bool drop_remainder_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + const NameAttrList func_; + const std::unique_ptr<CapturedFunction> captured_func_; + }; + + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + NameAttrList func_; +}; + +REGISTER_KERNEL_BUILDER( + Name("ExperimentalNumaMapAndBatchDataset").Device(DEVICE_CPU), + NumaMapAndBatchDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index f45a239793..bae56828dc 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -324,6 +324,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } private: + // BatchResult encapsulates the output batch, as well as anciliary + // metadata required to execute the fused map-and-batch operation. struct BatchResult { explicit BatchResult(int64 batch_size) { end_of_input = false; @@ -331,11 +333,23 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { num_elements = 0; output_allocated = false; status = Status::OK(); + status_offset = -1; } - void UpdateStatus(const Status& s) { - mutex_lock l(mu); - status.Update(s); + // UpdateStatus updates the batch's aggregate Status. + // + // In order to ensure that exactly the first non-OK status is returned + // (required to make the behavior is observably identical to a + // sequential execution of map followed by batch), we must also keep + // track of the offset into the batch that produced `s`. + void UpdateStatus(const Status& s, int64 offset) { + if (TF_PREDICT_FALSE(!s.ok())) { + mutex_lock l(mu); + if (status.ok() || offset < status_offset) { + status = s; + status_offset = offset; + } + } } mutex mu; @@ -344,6 +358,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::vector<Tensor> output; bool output_allocated GUARDED_BY(mu); Status status GUARDED_BY(mu); + int64 status_offset GUARDED_BY(mu); // Counts the number of outstanding calls for this batch. int64 num_calls; // access guarded by owner's mutex }; @@ -379,7 +394,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::shared_ptr<std::vector<Tensor>> return_values = std::make_shared<std::vector<Tensor>>(); auto done = [this, ctx, result, return_values, offset](Status status) { - result->UpdateStatus(status); + result->UpdateStatus(status, offset); if (status.ok()) { EnsureOutputAllocated(ctx, result, return_values); for (size_t i = 0; i < return_values->size(); ++i) { @@ -389,11 +404,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { (batch->NumElements() / batch->dim_size(0))) { TensorShape batch_shape = batch->shape(); batch_shape.RemoveDim(0); - result->UpdateStatus(errors::InvalidArgument( - "Cannot add tensor to the batch: number of elements does " - "not match. Shapes are: [tensor]: ", - tensor.shape().DebugString(), - ", [batch]: ", batch_shape.DebugString())); + result->UpdateStatus( + errors::InvalidArgument( + "Cannot add tensor to the batch: number of elements " + "does " + "not match. 
diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index f6bd5dce26..bbbecc50f8 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -138,6 +138,32 @@ REGISTER_OP("ExperimentalAssertNextDataset") return shape_inference::ScalarShape(c); }); +REGISTER_OP("ExperimentalNumaMapAndBatchDataset") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("batch_size: int64") + .Input("num_parallel_calls: int64") + .Input("drop_remainder: bool") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + // Use index from the end to retrieve the input shapes, + // so as to avoid guessing the length of "other_arguments". + // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars. + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR( + c->WithRank(c->input(c->num_inputs() - 3), 0, &unused)); + TF_RETURN_IF_ERROR( + c->WithRank(c->input(c->num_inputs() - 2), 0, &unused)); + TF_RETURN_IF_ERROR( + c->WithRank(c->input(c->num_inputs() - 1), 0, &unused)); + + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("ExperimentalLMDBDataset") .Input("filenames: string") .Output("handle: variant") diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py index d444c4082e..5ead6d1c75 100644 --- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -38,12 +39,17 @@ from tensorflow.python.platform import test class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( - ("Default", None, None), - ("SequentialCalls", 1, None), - ("ParallelCalls", 2, None), - ("ParallelBatches", None, 10), + ("Default", None, None, False), + ("SequentialCalls", 1, None, False), + ("ParallelCalls", 2, None, False), + ("ParallelBatches", None, 10, False), + ("DefaultNUMA", None, None, True), + ("SequentialCallsNUMA", 1, None, True), + ("ParallelCallsNUMA", 2, None, True), + ("ParallelBatchesNUMA", None, 10, True), ) - def testMapAndBatch(self, num_parallel_calls, num_parallel_batches): + def testMapAndBatch(self, num_parallel_calls, num_parallel_batches, + numa_aware): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> 
# RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). @@ -57,14 +63,20 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = ( + dataset = ( dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply( batching.map_and_batch( map_func=_map_fn, batch_size=batch_size, num_parallel_calls=num_parallel_calls, - num_parallel_batches=num_parallel_batches)) - .make_initializable_iterator()) + num_parallel_batches=num_parallel_batches))) + + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + + iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -115,16 +127,25 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) @parameterized.named_parameters( - ("Even", False), - ("Uneven", True), + ("Even", False, False), + ("Uneven", True, False), + ("EvenNUMA", False, True), + ("UnevenNUMA", True, True), ) - def testMapAndBatchPartialBatch(self, drop_remainder): - iterator = ( + def testMapAndBatchPartialBatch(self, drop_remainder, numa_aware): + dataset = ( dataset_ops.Dataset.range(10).apply( batching.map_and_batch( lambda x: array_ops.reshape(x * x, [1]), batch_size=4, - drop_remainder=drop_remainder)).make_one_shot_iterator()) + drop_remainder=drop_remainder))) + + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + iterator = dataset.make_one_shot_iterator() + if drop_remainder: self.assertEqual([4, 1], iterator.output_shapes.as_list()) else: @@ -138,11 +159,21 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) - def testMapAndBatchYieldsPartialBatch(self): - iterator = (dataset_ops.Dataset.range(10) - .apply(batching.map_and_batch( - lambda x: array_ops.reshape(x * x, [1]), 4)) - .make_one_shot_iterator()) + @parameterized.named_parameters( + ("Normal", False), + ("NUMA", True), + ) + def testMapAndBatchYieldsPartialBatch(self, numa_aware): + dataset = ( + dataset_ops.Dataset.range(10).apply( + batching.map_and_batch(lambda x: array_ops.reshape(x * x, [1]), 4))) + + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + + iterator = dataset.make_one_shot_iterator() self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() with self.cached_session() as sess: @@ -152,10 +183,19 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) - def testMapAndBatchParallelGetNext(self): - iterator = (dataset_ops.Dataset.range(50000) - .apply(batching.map_and_batch(lambda x: x, batch_size=100)) - .make_one_shot_iterator()) + @parameterized.named_parameters( + ("Normal", False), + ("NUMA", True), + ) + def testMapAndBatchParallelGetNext(self, numa_aware): + dataset = dataset_ops.Dataset.range(50000).apply( + batching.map_and_batch(lambda x: x, batch_size=100)) + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + iterator = dataset.make_one_shot_iterator() + elements = [] for _ in 
range(100): elements.append(iterator.get_next()) @@ -165,17 +205,26 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): got.sort(key=lambda x: x[0]) expected = [] for j in range(100): - expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100)) self.assertAllEqual(got, expected) with self.assertRaises(errors.OutOfRangeError): sess.run(elements) - def testMapAndBatchParallelGetNextDropRemainder(self): - iterator = ( - dataset_ops.Dataset.range(49999).apply( - batching.map_and_batch( - lambda x: x, batch_size=100, drop_remainder=True)) - .make_one_shot_iterator()) + @parameterized.named_parameters( + ("Normal", False), + ("NUMA", True), + ) + def testMapAndBatchParallelGetNextDropRemainder(self, numa_aware): + dataset = dataset_ops.Dataset.range(49999).apply( + batching.map_and_batch( + lambda x: x, batch_size=100, drop_remainder=True)) + + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + iterator = dataset.make_one_shot_iterator() + elements = [] for _ in range(100): elements.append(iterator.get_next()) @@ -185,19 +234,29 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): got.sort(key=lambda x: x[0]) expected = [] for j in range(100): - expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100)) self.assertAllEqual(got, expected) with self.assertRaises(errors.OutOfRangeError): sess.run(elements) - def testMapAndBatchSparse(self): + @parameterized.named_parameters( + ("Normal", False), + ("NUMA", True), + ) + def testMapAndBatchSparse(self, numa_aware): def _sparse(i): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = dataset_ops.Dataset.range(10).apply( - batching.map_and_batch(_sparse, 5)).make_initializable_iterator() + dataset = dataset_ops.Dataset.range(10).apply( + batching.map_and_batch(_sparse, 5)) + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer get_next = iterator.get_next() @@ -214,21 +273,33 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testMapAndBatchFails(self): + @parameterized.named_parameters( + ("Normal", False), + ("NUMA", True), + ) + def testMapAndBatchFails(self, numa_aware): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( - dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) - .make_initializable_iterator()) + dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer with self.cached_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) - def testMapAndBatchShapeMismatch(self): + @parameterized.named_parameters( + ("Normal", False), + 
("NUMA", True), + ) + def testMapAndBatchShapeMismatch(self, numa_aware): """Test a dataset that maps a TF function across its input elements.""" def generator(): @@ -240,9 +311,13 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset_ops.Dataset.from_generator( generator, output_types=dtypes.int32) batch_size = 4 - iterator = ( - dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) - .make_initializable_iterator()) + dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -251,7 +326,11 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): "number of elements does not match"): sess.run(get_next) - def testMapAndBatchImplicitDispose(self): + @parameterized.named_parameters( + ("Normal", False), + ("NUMA", True), + ) + def testMapAndBatchImplicitDispose(self, numa_aware): # Tests whether a map and batch dataset will be cleaned up correctly when # the pipeline does not run it until exhaustion. # The pipeline is TensorSliceDataset -> RepeatDataset(1000) -> @@ -266,6 +345,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat( 1000).apply(batching.map_and_batch(_map_fn, batch_size=100)) dataset = dataset.prefetch(5) + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() @@ -274,26 +357,38 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): sess.run(get_next) @parameterized.named_parameters( - ("1", 0), - ("2", 5), - ("3", 10), - ("4", 90), - ("5", 95), - ("6", 99), + ("1", 0, False), + ("2", 5, False), + ("3", 10, False), + ("4", 90, False), + ("5", 95, False), + ("6", 99, False), + ("1NUMA", 0, True), + ("2NUMA", 5, True), + ("3NUMA", 10, True), + ("4NUMA", 90, True), + ("5NUMA", 95, True), + ("6NUMA", 99, True), ) - def testMapAndBatchOutOfRangeError(self, threshold): + def testMapAndBatchOutOfRangeError(self, threshold, numa_aware): def raising_py_fn(i): - if i >= threshold: + if i == threshold: raise StopIteration() + elif i > threshold: + raise RuntimeError("Alternate error; you shouldn't see me! 
(i: %s)" % i) else: return i - iterator = ( - dataset_ops.Dataset.range(100).apply( - batching.map_and_batch( - lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64), - batch_size=10)).make_one_shot_iterator()) + dataset = dataset_ops.Dataset.range(100).apply( + batching.map_and_batch( + lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64), + batch_size=10)) + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() with self.cached_session() as sess: @@ -307,25 +402,42 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): sess.run(get_next) @parameterized.named_parameters( - ("1", False, dtypes.bool), - ("2", -42, dtypes.int8), - ("3", -42, dtypes.int16), - ("4", -42, dtypes.int32), - ("5", -42, dtypes.int64), - ("6", 42, dtypes.uint8), - ("7", 42, dtypes.uint16), - ("8", 42.0, dtypes.float16), - ("9", 42.0, dtypes.float32), - ("10", 42.0, dtypes.float64), - ("11", b"hello", dtypes.string), + ("1", False, dtypes.bool, False), + ("2", -42, dtypes.int8, False), + ("3", -42, dtypes.int16, False), + ("4", -42, dtypes.int32, False), + ("5", -42, dtypes.int64, False), + ("6", 42, dtypes.uint8, False), + ("7", 42, dtypes.uint16, False), + ("8", 42.0, dtypes.float16, False), + ("9", 42.0, dtypes.float32, False), + ("10", 42.0, dtypes.float64, False), + ("11", b"hello", dtypes.string, False), + ("1NUMA", False, dtypes.bool, True), + ("2NUMA", -42, dtypes.int8, True), + ("3NUMA", -42, dtypes.int16, True), + ("4NUMA", -42, dtypes.int32, True), + ("5NUMA", -42, dtypes.int64, True), + ("6NUMA", 42, dtypes.uint8, True), + ("7NUMA", 42, dtypes.uint16, True), + ("8NUMA", 42.0, dtypes.float16, True), + ("9NUMA", 42.0, dtypes.float32, True), + ("10NUMA", 42.0, dtypes.float64, True), + ("11NUMA", b"hello", dtypes.string, True), ) - def testMapAndBatchTypes(self, element, dtype): + def testMapAndBatchTypes(self, element, dtype, numa_aware): + def gen(): yield element dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply( batching.map_and_batch(lambda x: x, batch_size=10)) + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + get_next = dataset.make_one_shot_iterator().get_next() with self.cached_session() as sess: @@ -363,6 +475,40 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): sess.run(iterator.initializer, feed_dict={captured_t: 42}) self.assertAllEqual([42] * 10, sess.run(get_next)) + @parameterized.named_parameters( + ("Normal", False), + ("NUMA", True), + ) + def testMapAndBatchControlFlow(self, numa_aware): + + def map_fn(x): + previous_cond_v2_value = control_flow_ops.ENABLE_COND_V2 + control_flow_ops.ENABLE_COND_V2 = True + return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x) + control_flow_ops.ENABLE_COND_V2 = previous_cond_v2_value + return return_value + + dataset = dataset_ops.Dataset.range(100).apply( + batching.map_and_batch(map_fn, batch_size=10)) + if numa_aware: + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.cached_session() as sess: + for i in range(10): + print("Case %d" % i) + if i < 5: + self.assertAllEqual([i * 10 + j + 1 for j in range(10)], + sess.run(get_next)) + else: + 
self.assertAllEqual( + [((i * 10) + j) * ((i * 10) + j) for j in range(10)], + sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD index c92bb8b9bc..5a0a73fd83 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -161,6 +161,7 @@ py_test( "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -199,6 +200,7 @@ py_test( deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/experimental/ops:batching", "//tensorflow/python/data/experimental/ops:optimization", "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py index 82516356df..d38255a6ea 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import time +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.ops import batching @@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class ModelDatasetTest(test_base.DatasetTestBase): +class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testModelMap(self): k = 1024 * 1024 @@ -82,7 +83,11 @@ class ModelDatasetTest(test_base.DatasetTestBase): (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), np.max(deltas))) - def testModelMapAndBatch(self): + @parameterized.named_parameters( + ("Default", False), + ("NUMA", True), + ) + def testModelMapAndBatch(self, numa_aware): batch_size = 16 k = 1024 * 1024 dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), @@ -95,6 +100,8 @@ class ModelDatasetTest(test_base.DatasetTestBase): batch_size=batch_size)) options = dataset_ops.Options() options.experimental_autotune = True + if numa_aware: + options.experimental_numa_aware = True iterator = dataset.with_options(options).make_one_shot_iterator() get_next = iterator.get_next() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py index 760cd8cc4e..2ef29796ab 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops @@ -59,6 +60,21 @@ class OptimizeDatasetTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def 
testNumaAwareRewrite(self): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next(["NumaMapAndBatch"])).apply( + batching.map_and_batch(lambda x: x * x, 10)) + options = dataset_ops.Options() + options.experimental_numa_aware = True + dataset = dataset.with_options(options) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testOptimizationStatefulFunction(self): dataset = dataset_ops.Dataset.range(10).map( lambda _: random_ops.random_uniform([])).batch(10) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD index e556b65b7c..a97cff9fbb 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD @@ -307,6 +307,21 @@ py_test( ) py_test( + name = "numa_map_and_batch_dataset_serialization_test", + size = "medium", + srcs = ["numa_map_and_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( name = "map_dataset_serialization_test", size = "medium", srcs = ["map_dataset_serialization_test.py"], diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py new file mode 100644 index 0000000000..04aab329cd --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the MapAndBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapAndBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testNumParallelBatches(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_batches = 2 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + ds = dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches, + drop_remainder=drop_remainder)) + options = dataset_ops.Options() + options.experimental_numa_aware = True + return ds.with_options(options) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + ds = dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + options = dataset_ops.Options() + options.experimental_numa_aware = True + return ds.with_options(options) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + +if __name__ == "__main__": + test.main() + diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index 915d399f1b..46a9552b61 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -122,6 +122,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:tensor_shape", diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index cf52f7529a..6195747671 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1410,6 +1410,8 @@ class Options(object): "Whether to eliminate no-op transformations."), 
("experimental_shuffle_and_repeat_fusion", bool, "Whether to fuse shuffle and repeat transformations."), + ("experimental_numa_aware", bool, + "Whether to use NUMA-aware operations."), ]: def _make_getter(name): # pylint: disable=no-self-argument @@ -1458,6 +1460,9 @@ class Options(object): for exp_opt in experimental_optimizations: if getattr(self, "experimental_" + exp_opt): result.append(exp_opt) + + if getattr(self, "experimental_numa_aware"): + result.append("map_and_batch_numa_aware_replacement") return result def merge(self, options): @@ -1485,7 +1490,7 @@ class Options(object): "experimental_map_and_filter_fusion", "experimental_map_fusion", "experimental_map_parallelization", "experimental_map_vectorization", "experimental_noop_elimination", - "experimental_shuffle_and_repeat_fusion" + "experimental_shuffle_and_repeat_fusion", "experimental_numa_aware", ]: this = getattr(result, name) that = getattr(other, name) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt index d15dccc173..22256996d3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt @@ -43,6 +43,10 @@ tf_class { mtype: "<type \'property\'>" } member { + name: "experimental_numa_aware" + mtype: "<type \'property\'>" + } + member { name: "experimental_shuffle_and_repeat_fusion" mtype: "<type \'property\'>" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt index d15dccc173..22256996d3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt @@ -43,6 +43,10 @@ tf_class { mtype: "<type \'property\'>" } member { + name: "experimental_numa_aware" + mtype: "<type \'property\'>" + } + member { name: "experimental_shuffle_and_repeat_fusion" mtype: "<type \'property\'>" } |
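At the Python level, opting in is a one-line change on top of an existing map_and_batch pipeline: setting the new experimental_numa_aware option enables the map_and_batch_numa_aware_replacement Grappler pass, which swaps in the ExperimentalNumaMapAndBatchDataset kernel (the rewrite that testNumaAwareRewrite asserts via assert_next(["NumaMapAndBatch"])). Below is a minimal graph-mode sketch using the same internal modules as the tests in this change; the pipeline shape and values are illustrative, not taken from the commit.

import tensorflow as tf
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops

# A fused map-and-batch pipeline, as exercised by the tests above.
dataset = dataset_ops.Dataset.range(10000).apply(
    batching.map_and_batch(lambda x: x * x, batch_size=100))

# Opt in to the NUMA-aware replacement via the new dataset option.
options = dataset_ops.Options()
options.experimental_numa_aware = True
dataset = dataset.with_options(options)

iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()

with tf.Session() as sess:
  first_batch = sess.run(get_next)  # [0, 1, 4, ..., 9801]

Because the rewrite is applied by Grappler at graph-optimization time, the user-visible contract of map_and_batch (element values, batch shapes, and which error is reported first) is intended to be unchanged; only the thread placement differs.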