author     Brennan Saeta <saeta@google.com>                  2018-10-09 11:54:20 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-09 11:58:43 -0700
commit     072fcb995a3fd658ee2461b59b159498c710513d (patch)
tree       f3def3d3ac6e270ad32e428889a79d662c8bc9cf
parent     12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (diff)
[tf.data] NUMA-aware MapAndBatch dataset.
PiperOrigin-RevId: 216395709
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalNumaMapAndBatchDataset.pbtxt  58
-rw-r--r--  tensorflow/core/framework/model.h  2
-rw-r--r--  tensorflow/core/grappler/optimizers/data/BUILD  35
-rw-r--r--  tensorflow/core/grappler/optimizers/data/graph_test_utils.cc  16
-rw-r--r--  tensorflow/core/grappler/optimizers/data/graph_test_utils.h  6
-rw-r--r--  tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc  62
-rw-r--r--  tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h  48
-rw-r--r--  tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc  112
-rw-r--r--  tensorflow/core/kernels/data/experimental/BUILD  17
-rw-r--r--  tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc  1135
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc  38
-rw-r--r--  tensorflow/core/ops/experimental_dataset_ops.cc  26
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py  280
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/BUILD  2
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py  11
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py  16
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/BUILD  15
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py  95
-rw-r--r--  tensorflow/python/data/experimental/ops/BUILD  1
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py  7
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt  4
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt  4
22 files changed, 1909 insertions(+), 81 deletions(-)
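Usage sketch (not part of this change): judging from the dataset_ops.py and tf.data.Options golden-API updates listed above, the NUMA-aware rewrite is opt-in through a dataset option. The attribute name used below, experimental_numa_aware, is an assumption to be confirmed against dataset_ops.py; the rest of the pipeline uses standard tf.data APIs.

    import tensorflow as tf

    def parse_fn(x):
      return x * 2

    dataset = tf.data.Dataset.range(1024)
    # Fused map-and-batch: the transformation the new grappler pass rewrites
    # from MapAndBatchDatasetV2 to ExperimentalNumaMapAndBatchDataset.
    dataset = dataset.apply(
        tf.data.experimental.map_and_batch(parse_fn, batch_size=32,
                                           num_parallel_calls=8))

    options = tf.data.Options()
    # Assumed option name; enables the "map_and_batch_numa_aware_replacement"
    # rewrite registered in this change.
    options.experimental_numa_aware = True
    dataset = dataset.with_options(options)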
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalNumaMapAndBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalNumaMapAndBatchDataset.pbtxt
new file mode 100644
index 0000000000..243922d969
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalNumaMapAndBatchDataset.pbtxt
@@ -0,0 +1,58 @@
+op {
+ graph_op_name: "ExperimentalNumaMapAndBatchDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when building a closure
+for `f`.
+END
+ }
+ in_arg {
+ name: "batch_size"
+ description: <<END
+A scalar representing the number of elements to accumulate in a
+batch, i.e. the size of each batch produced by this dataset.
+END
+ }
+ in_arg {
+ name: "num_parallel_calls"
+ description: <<END
+A scalar representing the maximum number of parallel invocations of the
+function `f`. Applying `f` to consecutive input elements in parallel has the
+potential to improve input pipeline throughput.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether the last batch should be dropped in case its size
+is smaller than desired.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+A function to apply to the outputs of `input_dataset`.
+END
+ }
+ summary: "Creates a dataset that fuses mapping with batching."
+ description: <<END
+Creates a dataset that applies `f` to the outputs of `input_dataset` and then
+batches `batch_size` of them.
+
+Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
+to `num_parallel_calls` copies of `f` in parallel.
+
+Unlike "MapAndBatchDatasetV2", this dataset uses a NUMA-aware thread scheduling
+policy. Because it uses the single-threaded executor, it only supports
+function-based control flow ops.
+END
+}
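The summary above describes the fused form of an unfused map-then-batch pipeline. A rough illustration of the equivalence (sketch only; the fused kernel additionally runs `f` in parallel and, in this variant, pins its worker threads to NUMA nodes):

    import tensorflow as tf

    def f(x):
      return x + 1

    ds = tf.data.Dataset.range(100)

    # Unfused pipeline: map, then batch.
    unfused = ds.map(f, num_parallel_calls=4).batch(32, drop_remainder=False)

    # Fused pipeline, which MapAndBatchDatasetV2 (and the NUMA-aware variant
    # introduced here) implements as a single dataset op.
    fused = ds.apply(
        tf.data.experimental.map_and_batch(f, batch_size=32,
                                           num_parallel_calls=4))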
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index eae0fa70e8..9596252664 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -335,7 +335,7 @@ class Model {
if (name_ == "Map") {
return Type::MAP;
}
- if (name_ == "MapAndBatch") {
+ if (name_ == "MapAndBatch" || name_ == "NumaMapAndBatch") {
return Type::MAP_AND_BATCH;
}
if (name_ == "PaddedBatch") {
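This one-line change makes the autotuning model treat the NUMA-aware op exactly like the existing fused op, so the tunable "parallelism" parameter registered by the new iterator (see Iterator::Initialize below) is modeled the same way. Conceptually the mapping is just a name-to-type table; a sketch, not the model.h API:

    NODE_TYPES = {
        "Map": "MAP",
        "MapAndBatch": "MAP_AND_BATCH",
        "NumaMapAndBatch": "MAP_AND_BATCH",  # added by this change
        "PaddedBatch": "PADDED_BATCH",
    }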
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index ee7c14e3ab..1c553044a8 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -415,6 +415,40 @@ tf_cc_test(
)
cc_library(
+ name = "map_and_batch_numa_aware_replacement",
+ srcs = ["map_and_batch_numa_aware_replacement.cc"],
+ hdrs = ["map_and_batch_numa_aware_replacement.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "map_and_batch_numa_aware_replacement_test",
+ srcs = ["map_and_batch_numa_aware_replacement_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_test_utils",
+ ":graph_utils",
+ ":map_and_batch_numa_aware_replacement",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ ],
+)
+
+cc_library(
name = "noop_elimination",
srcs = ["noop_elimination.cc"],
hdrs = [
@@ -490,6 +524,7 @@ cc_library(
":hoist_random_uniform",
":latency_all_edges",
":map_and_batch_fusion",
+ ":map_and_batch_numa_aware_replacement",
":map_and_filter_fusion",
":map_fusion",
":map_parallelization",
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
index b2eec7220e..1f03c6515c 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
namespace grappler {
@@ -44,6 +45,21 @@ NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
{"output_types", gtl::ArraySlice<TensorShape>{}}});
}
+NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece batch_size_node_name,
+ StringPiece num_parallel_calls_node_name,
+ StringPiece drop_remainder_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapAndBatchDatasetV2",
+ {string(input_node_name), "", string(batch_size_node_name),
+ string(num_parallel_calls_node_name), string(drop_remainder_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<TensorShape>{}}});
+}
+
} // end namespace graph_tests_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
index ca0fde997d..f7891d5e1f 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -29,6 +29,12 @@ NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
StringPiece function_name = "IsZero");
+NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece batch_size_node_name,
+ StringPiece num_parallel_calls_node_name,
+ StringPiece drop_remainder_node_name,
+ StringPiece function_name = "XTimesTwo");
+
} // end namespace graph_tests_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc
new file mode 100644
index 0000000000..452089eb67
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeNumaAware(const NodeDef& node, MutableGraphView* graph) {
+ NodeDef numa_aware_node = node;
+ graph_utils::SetUniqueGraphNodeName("map_and_batch_numa_aware",
+ graph->GetGraph(), &numa_aware_node);
+ numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset");
+ return numa_aware_node;
+}
+
+} // namespace
+
+Status MapAndBatchNumaAwareReplacement::Optimize(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+
+ for (const NodeDef& node : item.graph.node()) {
+ if (node.op() != "MapAndBatchDatasetV2") continue;
+
+ auto* numa_node = graph.AddNode(MakeNumaAware(node, &graph));
+ graph.ReplaceInput(node, *numa_node);
+ nodes_to_delete.insert(node.name());
+ }
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapAndBatchNumaAwareReplacement,
+ "map_and_batch_numa_aware_replacement");
+
+} // namespace grappler
+} // namespace tensorflow
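A registered tf.data rewrite like this one is selected by name from the Python layer; the dataset_ops.py change in this CL presumably appends that name to the list of static optimizations when the corresponding option is set. A hypothetical sketch of that plumbing (the function and option names are assumptions, not copied from the CL):

    def _static_optimizations(options):
      """Grappler rewrites to request for a dataset with these options."""
      result = []
      if getattr(options, "experimental_numa_aware", False):
        # Must match the name passed to REGISTER_GRAPH_OPTIMIZER_AS above.
        result.append("map_and_batch_numa_aware_replacement")
      return result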
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h
new file mode 100644
index 0000000000..3b2acd288b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class MapAndBatchNumaAwareReplacement : public CustomGraphOptimizer {
+ public:
+ MapAndBatchNumaAwareReplacement() = default;
+ ~MapAndBatchNumaAwareReplacement() override = default;
+
+ string name() const override {
+ return "map_and_batch_numa_aware_replacement";
+ }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override {}
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc
new file mode 100644
index 0000000000..3c5c61d1c2
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(MapAndBatchNumaAwareReplacementTest, ReplaceSimple) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {
+ NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ NDef("batch_size", "Const", {}, {{"value", 3}, {"dtype", DT_INT32}}),
+ NDef("num_parallel_calls", "Const", {},
+ {{"value", 5}, {"dtype", DT_INT32}}),
+ NDef("drop_remainder", "Const", {},
+ {{"value", 0}, {"dtype", DT_BOOL}}),
+ graph_tests_utils::MakeMapAndBatchNode(
+ "map_and_batch", "range", "batch_size", "num_parallel_calls",
+ "drop_remainder"),
+ },
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+
+ MapAndBatchNumaAwareReplacement optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map_and_batch", output));
+ EXPECT_FALSE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp(
+ "ExperimentalNumaMapAndBatchDataset", output));
+}
+
+TEST(MapAndBatchNumaAwareReplacementTest, ReplaceWithExtraChild) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {
+ NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ NDef("batch_size", "Const", {}, {{"value", 3}, {"dtype", DT_INT32}}),
+ NDef("num_parallel_calls", "Const", {},
+ {{"value", 5}, {"dtype", DT_INT32}}),
+ NDef("drop_remainder", "Const", {},
+ {{"value", 0}, {"dtype", DT_BOOL}}),
+ graph_tests_utils::MakeMapAndBatchNode(
+ "map_and_batch", "range", "batch_size", "num_parallel_calls",
+ "drop_remainder"),
+ NDef("cache", "CacheDataset", {"map_and_batch"}, {}),
+ },
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+
+ MapAndBatchNumaAwareReplacement optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map_and_batch", output));
+ EXPECT_FALSE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp(
+ "ExperimentalNumaMapAndBatchDataset", output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("CacheDataset", output));
+
+ int numa_map_and_batch_component_id = graph_utils::FindGraphNodeWithOp(
+ "ExperimentalNumaMapAndBatchDataset", output);
+ auto& numa_map_and_batch_component =
+ output.node(numa_map_and_batch_component_id);
+ EXPECT_EQ(numa_map_and_batch_component.input(0), "range");
+
+ int cache_id = graph_utils::FindGraphNodeWithOp("CacheDataset", output);
+ auto& cache_node = output.node(cache_id);
+ EXPECT_EQ(cache_node.input(0), numa_map_and_batch_component.name());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index 43406db3ed..4cf5643bc0 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -103,6 +103,22 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "numa_map_and_batch_dataset_op",
+ srcs = ["numa_map_and_batch_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:inplace_ops",
+ "//tensorflow/core/kernels/data:captured_function",
+ "//tensorflow/core/kernels/data:dataset",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+tf_kernel_library(
name = "unique_dataset_op",
srcs = ["unique_dataset_op.cc"],
deps = [
@@ -132,6 +148,7 @@ tf_kernel_library(
":ignore_errors_dataset_op",
":indexed_dataset",
":lmdb_dataset_op",
+ ":numa_map_and_batch_dataset_op",
":prefetching_kernels",
":threadpool_dataset_op",
":unique_dataset_op",
diff --git a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc
new file mode 100644
index 0000000000..d83edb9667
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc
@@ -0,0 +1,1135 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#define EIGEN_USE_THREADS
+
+#include <atomic>
+#include <utility>
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/inplace_ops_functor.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/numa.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// kWindowSize is the fixed constant controlling the number of batch outputs
+// each NumaWorkerBlock may be processing at a time. It is currently a constant
+// rather than user-configurable, to leave room for future performance
+// optimizations in the implementation.
+const int64 kWindowSize = 10;
+
+// Define a helper for more consistent logging.
+#define WORKER_VLOG(verbose_level) \
+ VLOG(verbose_level) << "WorkerThread (" << numa_node << ", " << thread_num \
+ << "): "
+
+// See documentation in ../../../ops/experimental_dataset_ops.cc for a
+// high-level description of the following op.
+
+class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit NumaMapAndBatchDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 batch_size;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size));
+ OP_REQUIRES(
+ ctx, batch_size > 0,
+ errors::InvalidArgument("batch_size must be greater than zero."));
+
+ int64 num_parallel_calls;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+ &num_parallel_calls));
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+
+ bool drop_remainder;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "drop_remainder", &drop_remainder));
+
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ /* use_inter_op_parallelism = */ false,
+ &captured_func));
+
+ *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
+ drop_remainder, output_types_, output_shapes_, func_,
+ std::move(captured_func));
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
+ int64 num_parallel_calls, bool drop_remainder,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ batch_size_(batch_size),
+ num_parallel_calls_(num_parallel_calls),
+ drop_remainder_(drop_remainder),
+ output_types_(output_types),
+ output_shapes_(output_shapes),
+ func_(func),
+ captured_func_(std::move(captured_func)) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::NumaMapAndBatch")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "NumaMapAndBatchDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* batch_size_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
+ Node* num_parallel_calls_node;
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
+ Node* drop_remainder_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
+
+ DataTypeVector other_arguments_types;
+ other_arguments_types.reserve(captured_func_->captured_inputs().size());
+ std::vector<Node*> other_arguments;
+ other_arguments.reserve(captured_func_->captured_inputs().size());
+ for (const Tensor& t : captured_func_->captured_inputs()) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ other_arguments.emplace_back(node);
+ other_arguments_types.emplace_back(t.dtype());
+ }
+ AttrValue f;
+ b->BuildAttrValue(func_, &f);
+ AttrValue other_arguments_types_attr;
+ b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {std::make_pair(0, input_graph_node),
+ std::make_pair(2, batch_size_node),
+ std::make_pair(3, num_parallel_calls_node),
+ std::make_pair(4, drop_remainder_node)}, // Single tensor inputs.
+ {std::make_pair(1, other_arguments)}, // Tensor list inputs.
+ {std::make_pair("f", f),
+ std::make_pair("Targuments", other_arguments_types_attr)}, // Attrs
+ output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ mu_(std::make_shared<mutex>()),
+ autotune_cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, autotune_cond_var_)) {
+ }
+
+ ~Iterator() override {
+ mutex_lock l(*mu_);
+ cancelled_ = true;
+ VLOG(3) << "NumaMapAndBatchIterator::~Iterator: cancelling operations.";
+ for (size_t i = 0; i < workers_.size(); ++i) {
+ workers_[i]->manager.Cancel();
+ }
+ VLOG(3) << "NumaMapAndBatchIterator::~Iterator: waiting for threads to "
+ "shut down.";
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(*mu_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = std::max(1, port::NUMANumNodes());
+ AddTunableParameter(ctx,
+ /* name = */ "parallelism",
+ /* state = */ num_parallel_calls_,
+ /* min = */ num_parallel_calls_->value,
+ /* max = */ port::NumSchedulableCPUs());
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
+ }
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(ctx));
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ auto cleanup = gtl::MakeCleanup(
+ [] { VLOG(3) << "GetNextInternal call returning."; });
+ NumaWorkerBlock* worker = nullptr;
+ {
+ mutex_lock l(*mu_);
+ VLOG(3) << "GetNextInternal call; current block: " << cur_block_;
+ if (global_end_of_input_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(EnsureBackgroundThreadsStarted(ctx));
+ worker = workers_[cur_block_].get();
+ cur_block_ = (cur_block_ + 1) % workers_.size();
+ }
+ TF_RETURN_IF_ERROR(worker->manager.GetBatch(
+ ctx, dataset()->drop_remainder_, &global_end_of_input_, out_tensors,
+ end_of_sequence));
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(*mu_);
+ for (size_t i = 0; i < workers_.size(); ++i) {
+ if (!workers_[i]->manager.Quiesce()) {
+ return errors::Cancelled(
+ "The iterator was deleted before it could reach a "
+ "checkpointable state.");
+ }
+ }
+
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("num_workers"), workers_.size()));
+
+ for (size_t i = 0; i < workers_.size(); ++i) {
+ size_t index = (cur_block_ + i) % workers_.size();
+ TF_RETURN_IF_ERROR(workers_[index]->manager.Save(writer, this, i));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(*mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ int64 num_workers = -1;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("num_workers"), &num_workers));
+ // Note: num_workers can be 0 if the iterator wasn't started when
+ // first checkpointed.
+ if (num_workers < 0) {
+ return errors::DataLoss(
+ "When restoring from checkpoint, we encountered a data "
+ "consistency error: num_workers has an invalid value: ",
+ num_workers);
+ }
+ if (port::NUMAEnabled()) {
+ int actual_numa_domains = port::NUMANumNodes();
+ if (actual_numa_domains != num_workers && num_workers > 0) {
+ LOG(WARNING) << "# NUMA domains mismatch when restoring from "
+ "checkpoint: checkpoint has "
+ << num_workers
+ << " NUMA domains, while this host has: "
+ << actual_numa_domains << " NUMA domains.";
+ }
+ }
+ if (num_workers > 1 && !port::NUMAEnabled()) {
+ LOG(WARNING) << "NUMA is not enabled for this process, but restoring "
+ "a checkpoint that assumes "
+ << num_workers << " NUMA domains.";
+ }
+ workers_.resize(num_workers);
+ for (size_t i = 0; i < num_workers; ++i) {
+ workers_[i] = MakeUnique<NumaWorkerBlock>(this);
+ TF_RETURN_IF_ERROR(
+ workers_[i]->manager.Restore(ctx, reader, this, i));
+ }
+ cur_block_ = 0;
+ return Status::OK();
+ }
+
+ private:
+ // NumaBlockManager manages all the state for a set of threads pinned to a
+ // single NUMA domain.
+ //
+ // The methods can be divided into 3 categories based on who should call
+ // them:
+ //
+ // (1) RunnerThread: WaitForInputSpace, PushInputs, SetEndOfInput.
+ // (2) WorkerThread: RetrieveInput, GetBatchTensors,
+ // RecordBatchEntryComplete.
+ // (3) Client threads: GetBatch, Cancel, Save, Restore.
+ //
+ // Internally, we manage state in a circular buffer of size `kWindowSize`.
+ // There are 3 pointers into the circular buffer, which must maintain the
+ // following order: (1) next_input_batch_ (corresponding to the next input
+ // batch to be pulled from the input iterator), (2) next_input_
+ // (corresponding to the batch the WorkerThreads should pull from for
+ // their next inputs), and (3) next_output_ (corresponding to the next
+ // value to be consumed by the output iterator).
+ //
+ // Methods return errors::Cancelled if the iteration is cancelled before
+ // completing.
+ //
+ // NumaBlockManager is thread safe.
+ class NumaBlockManager {
+ public:
+ explicit NumaBlockManager(Iterator* itr) : itr_(itr) {}
+
+ // WaitForInputSpace blocks until there is space in the circular buffer
+ // to begin processing a new batch of elements.
+ //
+ // Returns true when there is space, false if the Iterator is cancelled.
+ bool WaitForInputSpace(IteratorContext* ctx) {
+ mutex_lock l(mu_);
+
+ size_t next = (next_input_batch_ + 1) % kWindowSize;
+ DCHECK(next < kWindowSize) << next;
+
+ // Wait for space in the circular buffer.
+ while (!cancelled_ && batches_[next].state != BatchState::kEmpty) {
+ VLOG(3) << "Waiting for input space; next: " << next
+ << ", next_output_: " << next_output_
+ << ", next_input_batch_: " << next_input_batch_;
+ itr_->RecordStop(ctx);
+ runner_cond_var_.wait(l);
+ itr_->RecordStart(ctx);
+ }
+ if (cancelled_) {
+ VLOG(3) << "WaitForInputSpace cancelled.";
+ return false;
+ }
+
+ DCHECK(batches_[next].state == BatchState::kEmpty);
+
+ next_input_batch_ = next;
+ return true;
+ }
+
+ // PushInputs sets the inputs for the next batch as retrieved from the
+ // input iterator.
+ void PushInputs(const Status& status,
+ std::vector<std::vector<Tensor>> inputs) {
+ mutex_lock l(mu_);
+
+ DCHECK(next_input_ < kWindowSize) << next_input_;
+ DCHECK(batches_[next_input_batch_].state == BatchState::kEmpty);
+ DCHECK(batches_[next_input_batch_].next_input_to_process == 0)
+ << batches_[next_input_batch_].next_input_to_process;
+ DCHECK(batches_[next_input_batch_].status.ok())
+ << batches_[next_input_batch_].status;
+
+ batches_[next_input_batch_].inputs.swap(inputs);
+ batches_[next_input_batch_].state = BatchState::kInputsFilled;
+ batches_[next_input_batch_].status.Update(status);
+ if (batches_[next_input_batch_].status.ok()) {
+ worker_cond_var_.notify_all();
+ } else {
+ client_cond_var_.notify_all();
+ batches_[next_input_batch_].error_index = 0;
+ }
+ }
+
+ // SetEndOfInput records the fact that we have reached the end of the
+ // input iterator, and that we should return end_of_sequence = true when
+ // we have exhausted all buffered batches.
+ void SetEndOfInput() {
+ mutex_lock l(mu_);
+ reached_eof_ = true;
+ worker_cond_var_.notify_all();
+ client_cond_var_.notify_all();
+ }
+
+ // RetrieveInput gets the next input tuple to be mapped by a worker
+ // thread.
+ //
+ // Returns true if an input was retrieved, false if the iterator has
+ // been cancelled.
+ bool RetrieveInput(IteratorContext* ctx, std::vector<Tensor>* input,
+ uint64* index, size_t* sequence_number) {
+ mutex_lock l(mu_);
+
+ // Wait for inputs to be ready.
+ while (!cancelled_ &&
+ batches_[next_input_].state != BatchState::kInputsFilled) {
+ itr_->RecordStop(ctx);
+ worker_cond_var_.wait(l);
+ itr_->RecordStart(ctx);
+ }
+
+ if (cancelled_) {
+ return false;
+ }
+
+ DCHECK(batches_[next_input_].next_input_to_process <
+ batches_[next_input_].inputs.size())
+ << "next_input_: " << next_input_ << ", next_input_to_process: "
+ << batches_[next_input_].next_input_to_process
+ << ", inputs.size(): " << batches_[next_input_].inputs.size()
+ << ", state: " << static_cast<int32>(batches_[next_input_].state)
+ << ", this: " << this;
+ *index = batches_[next_input_].next_input_to_process;
+ *sequence_number = next_input_;
+ input->swap(batches_[next_input_]
+ .inputs[batches_[next_input_].next_input_to_process]);
+ // Increment pointers.
+ batches_[next_input_].next_input_to_process++;
+
+ if (batches_[next_input_].next_input_to_process ==
+ batches_[next_input_].inputs.size()) {
+ batches_[next_input_].state = BatchState::kAllMapsStarted;
+ next_input_ = (next_input_ + 1) % kWindowSize;
+ }
+ return true;
+ }
+
+ // GetBatchTensors returns a pointer to the output batch tensors for the
+ // worker thread to copy into.
+ //
+ // allocate_output is a function taking a batch size, and a pointer to
+ // the output tuple of Tensors to allocate them. The allocate_output
+ // function is called at most once per output batch.
+ std::vector<Tensor>* GetBatchTensors(
+ size_t sequence_number,
+ std::function<void(size_t, std::vector<Tensor>*)> allocate_output) {
+ mutex_lock l(mu_);
+ DCHECK(sequence_number < kWindowSize) << sequence_number;
+ DCHECK(batches_[sequence_number].state == BatchState::kInputsFilled ||
+ batches_[sequence_number].state == BatchState::kAllMapsStarted)
+ << sequence_number;
+
+ if (batches_[sequence_number].outputs.empty()) {
+ allocate_output(batches_[sequence_number].inputs.size(),
+ &batches_[sequence_number].outputs);
+ }
+ return &batches_[sequence_number].outputs;
+ }
+
+ // RecordBatchEntryComplete records an element of the batch has finished
+ // copying into the output tensors.
+ void RecordBatchEntryComplete(size_t sequence_number, uint64 index,
+ Status s) {
+ mutex_lock l(mu_);
+ DCHECK(sequence_number < kWindowSize) << sequence_number;
+ DCHECK(batches_[sequence_number].state == BatchState::kInputsFilled ||
+ batches_[sequence_number].state == BatchState::kAllMapsStarted)
+ << sequence_number;
+
+ batches_[sequence_number].num_outputs_complete++;
+ if (!s.ok() && batches_[sequence_number].error_index > index) {
+ batches_[sequence_number].status = s;
+ batches_[sequence_number].error_index = index;
+ }
+
+ if (batches_[sequence_number].num_outputs_complete ==
+ batches_[sequence_number].inputs.size()) {
+ DCHECK(batches_[sequence_number].state ==
+ BatchState::kAllMapsStarted);
+ batches_[sequence_number].state = BatchState::kOutputsComplete;
+ batches_[sequence_number].inputs.clear(); // Eagerly save memory.
+ batches_[sequence_number].inputs.shrink_to_fit();
+ client_cond_var_.notify_all();
+ }
+ }
+
+ // GetBatch retrieves the next output batch tensors.
+ Status GetBatch(IteratorContext* ctx, bool drop_remainder,
+ bool* global_eof, std::vector<Tensor>* out_tensor,
+ bool* end_of_sequence) {
+ mutex_lock l(mu_);
+ // Wait until one of 3 conditions occurs:
+ // (1) we're cancelled.
+ // (2) the state becomes kOutputsComplete
+ // (3) state is empty && reached_eof.
+ while (!cancelled_ &&
+ batches_[next_output_].state != BatchState::kOutputsComplete &&
+ !(reached_eof_ &&
+ batches_[next_output_].state == BatchState::kEmpty)) {
+ VLOG(3) << "Waiting in GetBatch.";
+ itr_->RecordStop(ctx);
+ client_cond_var_.wait(l);
+ itr_->RecordStart(ctx);
+ }
+
+ if (cancelled_) {
+ return errors::Cancelled(
+ "Cancelled in NumaMapAndBatch::GetNext call.");
+ }
+
+ if (reached_eof_ &&
+ batches_[next_output_].state == BatchState::kEmpty) {
+ VLOG(4) << "GetBatch returning end of sequence.";
+ *end_of_sequence = true;
+ *global_eof = true;
+ return Status::OK();
+ }
+
+ VLOG(3) << "Returning output index: " << next_output_
+ << ", this: " << this;
+
+ *end_of_sequence = false;
+ Status s = batches_[next_output_].status;
+ if (s.ok()) {
+ out_tensor->swap(batches_[next_output_].outputs);
+ }
+ // Handle early termination.
+ if (errors::IsOutOfRange(s)) {
+ *global_eof = true;
+ s = Status::OK();
+ if (drop_remainder || batches_[next_output_].error_index == 0) {
+ *end_of_sequence = true;
+ } else {
+ std::vector<Tensor> true_outputs;
+ for (size_t i = 0; i < batches_[next_output_].outputs.size();
+ ++i) {
+ TensorShape component_shape(
+ batches_[next_output_].outputs[i].shape());
+ component_shape.set_dim(0, batches_[next_output_].error_index);
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ Tensor component(ctx->allocator(attr),
+ batches_[next_output_].outputs[i].dtype(),
+ component_shape);
+ TF_RETURN_IF_ERROR(CopyPartialBatch(
+ &component, batches_[next_output_].outputs[i],
+ batches_[next_output_].error_index));
+ true_outputs.emplace_back(std::move(component));
+ }
+ out_tensor->swap(true_outputs);
+ }
+ }
+
+ batches_[next_output_].Reset();
+ next_output_ = (next_output_ + 1) % kWindowSize;
+ runner_cond_var_.notify_all();
+
+ return s;
+ }
+
+ void Cancel() {
+ mutex_lock l(mu_);
+ VLOG(3) << "Cancelling NUMA block.";
+ cancelled_ = true;
+ runner_cond_var_.notify_all();
+ worker_cond_var_.notify_all();
+ client_cond_var_.notify_all();
+ }
+
+ // Waits until all the worker threads have completed their work and all
+ // internal state has reached a "safe-point" where we can safely
+ // checkpoint.
+ //
+ // Returns true if completed successfully, false if cancelled while
+ // waiting.
+ bool Quiesce() {
+ mutex_lock l(mu_);
+ VLOG(3) << "Waiting until the operations have quiesced.";
+ while (!cancelled_ && !AllMapOperationsFinished()) {
+ client_cond_var_.wait(l);
+ }
+ if (cancelled_) {
+ return false;
+ }
+ return true;
+ }
+
+ Status Save(IteratorStateWriter* writer, Iterator* itr, size_t index) {
+ mutex_lock l(mu_);
+ string prefix = itr->full_name(strings::StrCat("numa_block_", index));
+ if (reached_eof_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ strings::StrCat(prefix, "_end_of_input"), ""));
+ }
+ for (size_t i = 0; i < kWindowSize; ++i) {
+ size_t index = (next_output_ + i) % kWindowSize;
+ if (batches_[index].state == BatchState::kEmpty) {
+ break;
+ }
+ string batch_prefix = strings::StrCat(prefix, "_batch_", i);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ strings::StrCat(batch_prefix, "_code"),
+ static_cast<int64>(batches_[index].status.code())));
+ if (!batches_[index].status.ok()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(strings::StrCat(batch_prefix, "_msg"),
+ batches_[index].status.error_message()));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ strings::StrCat(batch_prefix, "_error_index"),
+ batches_[index].error_index));
+ }
+
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ strings::StrCat(batch_prefix, "_output_size"),
+ batches_[index].outputs.size()));
+ for (size_t j = 0; j < batches_[index].outputs.size(); ++j) {
+ string tensor_prefix =
+ strings::StrCat(batch_prefix, "_output_", j);
+ if (!batches_[index].status.ok()) {
+ DCHECK(batches_[index].error_index >= 0 &&
+ batches_[index].error_index <
+ itr_->dataset()->batch_size_);
+ // If the batch is not full, we only store the first
+ // `error_index` values. The rest of the batch tensor might not
+ // be initialized, and accessing that will raise msan errors.
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ tensor_prefix, batches_[index].outputs[j].Slice(
+ 0, batches_[index].error_index)));
+ } else {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ tensor_prefix, batches_[index].outputs[j]));
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Restore(IteratorContext* ctx, IteratorStateReader* reader,
+ Iterator* itr, size_t index) {
+ mutex_lock l(mu_);
+ if (reached_eof_) {
+ return errors::FailedPrecondition(
+ "Already reached the end of the sequence.");
+ }
+ string prefix = itr->full_name(strings::StrCat("numa_block_", index));
+ reached_eof_ =
+ reader->Contains(strings::StrCat(prefix, "_end_of_input"));
+ for (size_t i = 0; i < kWindowSize; ++i) {
+ string batch_prefix = strings::StrCat(prefix, "_batch_", i);
+ if (!reader->Contains(strings::StrCat(batch_prefix, "_code"))) {
+ break;
+ }
+ Batch batch;
+ batch.state = BatchState::kOutputsComplete;
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat(batch_prefix, "_code"), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat(batch_prefix, "_msg"), &error_message));
+ batch.status = Status(code, error_message);
+ int64 error_index_int = -1;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat(batch_prefix, "_error_index"),
+ &error_index_int));
+ if (error_index_int < 0 ||
+ error_index_int > itr->dataset()->batch_size_) {
+ return errors::FailedPrecondition(
+ "Error index out of bounds when restoring from checkpoint; "
+ "error index: ",
+ error_index_int);
+ }
+ batch.error_index = static_cast<size_t>(error_index_int);
+ }
+ int64 output_size = -1;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat(batch_prefix, "_output_size"), &output_size));
+ batch.outputs.reserve(output_size);
+ for (size_t j = 0; j < output_size; ++j) {
+ string tensor_name = strings::StrCat(batch_prefix, "_output_", j);
+ Tensor t;
+ TF_RETURN_IF_ERROR(reader->ReadTensor(tensor_name, &t));
+ batch.outputs.emplace_back(std::move(t));
+ }
+ batches_[i] = std::move(batch);
+ }
+ return Status::OK();
+ }
+
+ private:
+ bool AllMapOperationsFinished() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ for (size_t i = 0; i < kWindowSize; ++i) {
+ if (batches_[i].state == BatchState::kInputsFilled ||
+ batches_[i].state == BatchState::kAllMapsStarted) {
+ return false;
+ }
+ if (batches_[i].state != BatchState::kOutputsComplete &&
+ !reached_eof_) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Batches begin in the `kEmpty` state. Once the RunnerThread has
+ // filled the `inputs` to a `Batch`, it transitions to the
+ // `kInputsFilled` state. At this point, the Worker threads run the map
+ // function and copy the outputs appropriately. Once all worker threads
+ // have started, it transitions to `kAllMapsStarted`. After the outputs
+ // are complete, the GetNext call can consume the outputs, and return
+ // the batch to the kEmpty state.
+ enum class BatchState {
+ kEmpty,
+ kInputsFilled,
+ kAllMapsStarted,
+ kOutputsComplete,
+ };
+
+ // Batch captures all the state of an output batch as it progresses
+ // through the machinery. Once the RunnerThread fills inputs, it
+ // transitions to `kInputsFilled`. At this point, the worker threads can
+ // work on it, incrementing outputs_complete for every element of the
+ // input set that is copied into the output Tensors. Once all the input
+ // tuples have been processed (i.e. num_outputs_complete ==
+ // inputs.size()), it transitions to the `kOutputsComplete` stage, where
+ // it is ready to be returned by a `GetBatch` call (called from
+ // `GetNextInternal`).
+ struct Batch {
+ BatchState state;
+ // Aggregates the Status of the input iterator's GetNext
+ // calls, in addition to the Status of the map function invocations.
+ //
+ // In the case where multiple non-OK statuses are encountered, we
+ // return the first one encountered.
+ Status status;
+ // In order to return the correct error status, we keep track of the
+ // error_index.
+ size_t error_index;
+ // The batch_size input tuples (or fewer in the case of the last
+ // batch).
+ // TODO(saeta): Avoid re-allocating vectors all the time!
+ std::vector<std::vector<Tensor>> inputs;
+ std::vector<Tensor> outputs;
+ size_t next_input_to_process;
+ size_t num_outputs_complete;
+
+ Batch() { Reset(); }
+
+ // Resets the Batch state (e.g. after consuming the outputs).
+ void Reset() {
+ state = BatchState::kEmpty;
+ status = Status::OK();
+ inputs.clear();
+ inputs.shrink_to_fit();
+ outputs.clear();
+ outputs.shrink_to_fit();
+ next_input_to_process = 0;
+ num_outputs_complete = 0;
+ error_index = -1;
+ }
+ };
+
+ Iterator* itr_; // Not owned.
+ mutex mu_;
+ Batch batches_[kWindowSize] GUARDED_BY(mu_);
+ size_t next_input_batch_ GUARDED_BY(mu_) = -1;
+ size_t next_input_ GUARDED_BY(mu_) = 0;
+ size_t next_output_ GUARDED_BY(mu_) = 0;
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ bool reached_eof_ GUARDED_BY(mu_) = false;
+
+ // The runner thread waits on this condition variable for space to be
+ // available. When the client thread takes a value out of the circular
+ // buffer, it notifies this condition variable that space is now
+ // available.
+ condition_variable runner_cond_var_ GUARDED_BY(mu_);
+ // The worker threads wait on this condition variable for available
+ // inputs. When the runner thread makes new inputs available, it
+ // notifies this condition variable.
+ condition_variable worker_cond_var_ GUARDED_BY(mu_);
+ // The client threads wait on this condition variable for available
+ // batched outputs. When worker threads complete a batch, they notify
+ // this condition variable.
+ condition_variable client_cond_var_ GUARDED_BY(mu_);
+ };
+ // Mark NumaBlockManager as a friend of Iterator in order to call
+ // protected Iterator methods during checkpointing.
+ friend NumaBlockManager;
+
+ struct NumaWorkerBlock {
+ NumaBlockManager manager;
+ // TODO(saeta): Migrate to BackgroundWorker.
+ std::vector<std::unique_ptr<Thread>> threads;
+
+ explicit NumaWorkerBlock(Iterator* itr) : manager(itr) {}
+ };
+
+ static void CustomNumaWorkerBlockDeleter(NumaWorkerBlock* ptr) {
+ ptr->~NumaWorkerBlock();
+ port::NUMAFree(ptr, sizeof(NumaWorkerBlock));
+ }
+ static void DefaultNumaWorkerBlockDeleter(NumaWorkerBlock* ptr) {
+ delete ptr;
+ }
+
+ static Status CopyPartialBatch(Tensor* output, const Tensor& value,
+ int64 num_elements) {
+ switch (value.dtype()) {
+#define HANDLE_TYPE(type) \
+ case DataTypeToEnum<type>::value: { \
+ auto output_t = output->flat_outer_dims<type>(); \
+ auto value_t = value.flat_outer_dims<type>(); \
+ for (size_t i = 0; i < num_elements; i++) { \
+ output_t.template chip<0>(i) = value_t.template chip<0>(i); \
+ } \
+ return Status::OK(); \
+ }
+ TF_CALL_DATASET_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+ default:
+ return errors::InvalidArgument("Unsupported data type: ",
+ DataTypeString(value.dtype()));
+ }
+ return Status::OK();
+ }
+
+ Status EnsureBackgroundThreadsStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+ if (curr_num_parallel_calls_ >= num_parallel_calls_->value) {
+ // All necessary threads have been started.
+ curr_num_parallel_calls_ = num_parallel_calls_->value;
+ return Status::OK();
+ }
+
+ VLOG(4) << "Starting workers";
+ bool numa_enabled = port::NUMAEnabled();
+
+ if (!numa_enabled) {
+ LOG(INFO) << "NUMA not enabled on this host.";
+ }
+
+ int num_numa_nodes = port::NUMANumNodes();
+ if (num_numa_nodes < 1) {
+ return errors::Internal("The number of NUMA nodes is invalid: ",
+ num_numa_nodes);
+ }
+
+ // Only resize when empty to support restoring from checkpoints.
+ if (workers_.empty()) {
+ VLOG(3) << "# NUMA Nodes: " << num_numa_nodes
+ << ", # Parallel Calls: " << num_parallel_calls_->value;
+ workers_.resize(num_numa_nodes);
+ } else {
+ num_numa_nodes = workers_.size();
+ }
+
+ // Round up num_parallel_calls, with a minimum of 1.
+ const size_t num_threads_per_block =
+ std::max(1LL, (num_parallel_calls_->value + num_numa_nodes - 1) /
+ num_numa_nodes);
+
+ VLOG(3) << "Starting " << num_threads_per_block * num_numa_nodes
+ << " worker threads, with " << num_threads_per_block
+ << " threads per block.";
+
+ // Only allocate new_ctx if required.
+ std::shared_ptr<IteratorContext> new_ctx;
+
+ for (int i = 0; i < num_numa_nodes; ++i) {
+ if (!workers_[i]) {
+ if (numa_enabled) {
+ // Allocate in appropriate NUMA domain.
+ // 4k page align.
+ void* ptr = port::NUMAMalloc(i, sizeof(NumaWorkerBlock), 0);
+ if (ptr != nullptr) {
+ NumaWorkerBlock* block = new (ptr) NumaWorkerBlock(this);
+ workers_[i] =
+ std::unique_ptr<NumaWorkerBlock,
+ std::function<void(NumaWorkerBlock*)>>(
+ block, CustomNumaWorkerBlockDeleter);
+ } else {
+ LOG(ERROR) << "Could not NUMA-allocate worker block: " << i;
+ }
+ }
+ // If the NUMA allocation fails, or NUMA is not enabled.
+ if (!workers_[i]) {
+ workers_[i] =
+ std::unique_ptr<NumaWorkerBlock,
+ std::function<void(NumaWorkerBlock*)>>(
+ new NumaWorkerBlock(this), DefaultNumaWorkerBlockDeleter);
+ }
+ }
+ // Be sure to start threads if num_parallel_calls_ has changed.
+ for (size_t j = workers_[i]->threads.size();
+ j < num_threads_per_block; ++j) {
+ VLOG(3) << "Starting worker " << i << ", " << j;
+ if (!new_ctx) {
+ new_ctx = std::make_shared<IteratorContext>(*ctx);
+ }
+ workers_[i]->threads.emplace_back(ctx->env()->StartThread(
+ {},
+ strings::StrCat("numa_map_and_batch_block_", i, "_thread_", j),
+ [this, new_ctx, i, j]() { WorkerThread(new_ctx, i, j); }));
+ VLOG(3) << "Worker " << i << ", " << j << " successfully started.";
+ }
+ }
+ if (!runner_thread_) {
+ if (!new_ctx) {
+ new_ctx = std::make_shared<IteratorContext>(*ctx);
+ }
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "numa_map_runner_thread",
+ [this, new_ctx] { RunnerThread(new_ctx); }));
+ }
+ VLOG(3) << "All workers & runner thread started.";
+ return Status::OK();
+ }
+
+ void AllocateOutput(IteratorContext* ctx, size_t batch_size,
+ const std::vector<Tensor>& map_fn_outputs,
+ std::vector<Tensor>* batch_outputs) {
+ DCHECK(dataset()->output_dtypes().size() ==
+ dataset()->output_shapes().size());
+ DCHECK(map_fn_outputs.size() == dataset()->output_dtypes().size());
+ for (size_t i = 0; i < dataset()->output_dtypes().size(); ++i) {
+ TensorShape component_shape({static_cast<uint32>(batch_size)});
+ component_shape.AppendShape(map_fn_outputs.at(i).shape());
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ Tensor component(ctx->allocator(attr), map_fn_outputs.at(i).dtype(),
+ component_shape);
+ batch_outputs->emplace_back(std::move(component));
+ }
+ }
+
+ void RunnerThread(std::shared_ptr<IteratorContext> ctx)
+ LOCKS_EXCLUDED(mu_) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, &ctx] {
+ // Set end of input on all the managers in order to clean up in an
+ // orderly fashion.
+ VLOG(3) << "Setting End of Input on workers_[*]->manager";
+ for (size_t i = 0; i < workers_.size(); ++i) {
+ workers_[i]->manager.SetEndOfInput();
+ }
+ RecordStop(ctx.get());
+ });
+
+ const size_t num_blocks = workers_.size();
+
+ while (true) {
+ for (size_t block = 0; block < num_blocks; ++block) {
+ VLOG(4) << "RunnerThread waiting for input space in block: "
+ << block;
+ if (TF_PREDICT_FALSE(
+ !workers_[block]->manager.WaitForInputSpace(ctx.get()))) {
+ VLOG(3) << "RunnerThread exiting due to cancellation.";
+ return;
+ }
+ VLOG(4) << "RunnerThread has space; pulling on upstream for block "
+ << block;
+
+ Status s;
+ std::vector<std::vector<Tensor>> inputs;
+ bool end_of_sequence = false;
+ for (size_t i = 0; i < dataset()->batch_size_; ++i) {
+ std::vector<Tensor> tuple;
+ s.Update(
+ input_impl_->GetNext(ctx.get(), &tuple, &end_of_sequence));
+ if (!s.ok()) {
+ break;
+ }
+ if (end_of_sequence) {
+ VLOG(4) << "Runner thread encountered end of sequence.";
+ if (dataset()->drop_remainder_) {
+ return;
+ }
+ break;
+ }
+ inputs.push_back(std::move(tuple));
+ }
+
+ VLOG(4) << "Moving inputs to block " << block
+ << ", which has size: " << inputs.size();
+ if (!s.ok() || !inputs.empty()) {
+ workers_[block]->manager.PushInputs(s, std::move(inputs));
+ VLOG(4) << "Inputs moved into block " << block;
+ }
+ if (end_of_sequence) {
+ return;
+ }
+ }
+ }
+ }
+
+ void WorkerThread(std::shared_ptr<IteratorContext> ctx,
+ const int numa_node, const int thread_num) {
+ RecordStart(ctx.get());
+ WORKER_VLOG(3) << "started.";
+ auto stop_cleanup =
+ gtl::MakeCleanup([this, numa_node, thread_num, &ctx]() {
+ RecordStop(ctx.get());
+ WORKER_VLOG(3) << "exiting.";
+ });
+
+ NumaWorkerBlock* block = workers_[numa_node].get();
+ port::NUMASetThreadNodeAffinity(numa_node);
+ const int num_numa_nodes = port::NUMANumNodes();
+ const int minimum_num_parallel_calls = thread_num * num_numa_nodes;
+
+ while (true) {
+ // Put threads to sleep based on autotuner.
+ {
+ mutex_lock l(*mu_);
+ while (minimum_num_parallel_calls >= num_parallel_calls_->value &&
+ !cancelled_) {
+ RecordStop(ctx.get());
+ autotune_cond_var_->wait(l);
+ RecordStart(ctx.get());
+ }
+ if (cancelled_) {
+ return;
+ }
+ }
+
+ std::vector<Tensor> input;
+ uint64 index = 0;
+ size_t sequence_number = 0;
+ WORKER_VLOG(4) << "retrieving input.";
+ {
+ tracing::ScopedActivity trace(
+ "NumaMapAndBatch::Iterator::Worker::RetrieveInput");
+ if (!block->manager.RetrieveInput(ctx.get(), &input, &index,
+ &sequence_number)) {
+ return;
+ }
+ }
+
+ WORKER_VLOG(4) << "retrieved input; index: " << index
+ << ", sequence_number: " << sequence_number;
+
+ std::vector<Tensor> return_values;
+ Status s;
+ {
+ tracing::ScopedActivity trace(
+ "NumaMapAndBatch::Iterator::Worker::FunctionExecution");
+ s = dataset()->captured_func_->Run(ctx.get(), std::move(input),
+ &return_values);
+ }
+ WORKER_VLOG(4) << "ran function for index: " << index
+ << ", sequence_number: " << sequence_number;
+
+ if (s.ok()) {
+ std::vector<Tensor>* output = block->manager.GetBatchTensors(
+ sequence_number,
+ [this, ctx, &return_values](size_t batch_size,
+ std::vector<Tensor>* output) {
+ AllocateOutput(ctx.get(), batch_size, return_values, output);
+ });
+ WORKER_VLOG(4) << "copying tensors to batch output.";
+ {
+ tracing::ScopedActivity trace(
+ "NumaMapAndBatch::Iterator::Worker::BatchCopy");
+ for (size_t i = 0; i < return_values.size() && s.ok(); ++i) {
+ Tensor& tensor = return_values.at(i);
+ Tensor* batch = &output->at(i);
+ if (tensor.NumElements() !=
+ (batch->NumElements() / batch->dim_size(0))) {
+ s.Update(errors::InvalidArgument(
+ "Cannot add tensor to the batch: number of elements does "
+ "not match. Shapes are: [tensor]: ",
+ tensor.shape().DebugString(),
+ ", [batch]: ", batch->shape().DebugString()));
+ break;
+ }
+ s.Update(batch_util::CopyElementToSlice(std::move(tensor),
+ batch, index));
+ }
+ }
+ }
+
+ block->manager.RecordBatchEntryComplete(sequence_number, index, s);
+ WORKER_VLOG(4) << "finished index: " << index
+ << ", sequence_number: " << sequence_number;
+ }
+ }
+
+ // mu_ protects shared internal state and is used to coordinate between
+ // the auto-tuner, client threads, worker threads, and the runner thread.
+ const std::shared_ptr<mutex> mu_;
+ const std::shared_ptr<condition_variable> autotune_cond_var_;
+ // The maximum number of parallel calls (can be auto-tuned).
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
+
+ // Caches the last-seen value of num_parallel_calls_->value to
+ // short-circuit starting workers.
+ int64 curr_num_parallel_calls_ GUARDED_BY(*mu_) = 0;
+
+ std::unique_ptr<IteratorBase> input_impl_;
+ int64 cur_block_ GUARDED_BY(*mu_) = 0;
+ bool global_end_of_input_ GUARDED_BY(*mu_) = false;
+ bool cancelled_ GUARDED_BY(*mu_) = false;
+ std::vector<std::unique_ptr<NumaWorkerBlock,
+ std::function<void(NumaWorkerBlock*)>>>
+ workers_; // Const after initialization.
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ };
+
+ const DatasetBase* const input_;
+ const int64 batch_size_;
+ const int64 num_parallel_calls_;
+ const bool drop_remainder_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const NameAttrList func_;
+ const std::unique_ptr<CapturedFunction> captured_func_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ NameAttrList func_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalNumaMapAndBatchDataset").Device(DEVICE_CPU),
+ NumaMapAndBatchDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
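Note on how the kernel above is reached: ExperimentalNumaMapAndBatchDataset is never constructed directly by users. It is substituted for the regular fused map-and-batch by the map_and_batch_numa_aware_replacement rewrite added elsewhere in this change, which is in turn enabled through tf.data options. A minimal sketch of that opt-in, mirroring the Python tests later in this diff (the module paths are the internal ones those tests import, so the public API surface may differ):

# Sketch of opting into the NUMA-aware kernel; mirrors the tests below.
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(1000).apply(
    batching.map_and_batch(lambda x: x * x, batch_size=100,
                           num_parallel_calls=4))

options = dataset_ops.Options()
options.experimental_numa_aware = True  # requests the grappler rewrite
dataset = dataset.with_options(options)

# When the iterator's graph is optimized, the MapAndBatch node is replaced
# by ExperimentalNumaMapAndBatchDataset.
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()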
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index f45a239793..bae56828dc 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -324,6 +324,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
+ // BatchResult encapsulates the output batch, as well as ancillary
+ // metadata required to execute the fused map-and-batch operation.
struct BatchResult {
explicit BatchResult(int64 batch_size) {
end_of_input = false;
@@ -331,11 +333,23 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
num_elements = 0;
output_allocated = false;
status = Status::OK();
+ status_offset = -1;
}
- void UpdateStatus(const Status& s) {
- mutex_lock l(mu);
- status.Update(s);
+ // UpdateStatus updates the batch's aggregate Status.
+ //
+ // In order to ensure that exactly the first non-OK status is returned
+ // (required to keep the behavior observably identical to a sequential
+ // execution of map followed by batch), we must also keep track of the
+ // offset into the batch that produced `s`.
+ void UpdateStatus(const Status& s, int64 offset) {
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ mutex_lock l(mu);
+ if (status.ok() || offset < status_offset) {
+ status = s;
+ status_offset = offset;
+ }
+ }
}
mutex mu;
@@ -344,6 +358,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor> output;
bool output_allocated GUARDED_BY(mu);
Status status GUARDED_BY(mu);
+ int64 status_offset GUARDED_BY(mu);
// Counts the number of outstanding calls for this batch.
int64 num_calls; // access guarded by owner's mutex
};
@@ -379,7 +394,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::shared_ptr<std::vector<Tensor>> return_values =
std::make_shared<std::vector<Tensor>>();
auto done = [this, ctx, result, return_values, offset](Status status) {
- result->UpdateStatus(status);
+ result->UpdateStatus(status, offset);
if (status.ok()) {
EnsureOutputAllocated(ctx, result, return_values);
for (size_t i = 0; i < return_values->size(); ++i) {
@@ -389,11 +404,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
(batch->NumElements() / batch->dim_size(0))) {
TensorShape batch_shape = batch->shape();
batch_shape.RemoveDim(0);
- result->UpdateStatus(errors::InvalidArgument(
- "Cannot add tensor to the batch: number of elements does "
- "not match. Shapes are: [tensor]: ",
- tensor.shape().DebugString(),
- ", [batch]: ", batch_shape.DebugString()));
+ result->UpdateStatus(
+ errors::InvalidArgument(
+ "Cannot add tensor to the batch: number of elements "
+ "does "
+ "not match. Shapes are: [tensor]: ",
+ tensor.shape().DebugString(),
+ ", [batch]: ", batch_shape.DebugString()),
+ offset);
break;
}
// TODO(mrry): Add a version of DoParallelConcat that allows us to
@@ -402,7 +420,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status copy_status = ::tensorflow::functor::DoParallelConcat(
*dataset()->device_, tensor, offset, batch);
if (!copy_status.ok()) {
- result->UpdateStatus(copy_status);
+ result->UpdateStatus(copy_status, offset);
break;
}
}
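The UpdateStatus change above makes a batch surface the error produced at the smallest offset, so the parallel, fused execution stays observably identical to a sequential map followed by batch. A purely illustrative Python analogue of that aggregation rule (the class and names here are hypothetical, not part of TensorFlow):

# Hypothetical sketch of the first-error-by-offset rule used by
# BatchResult::UpdateStatus: an error only replaces the stored one if it
# occurred at an earlier offset within the batch.
class BatchStatus(object):

  def __init__(self):
    self.error = None        # an Exception, or None while everything is OK
    self.error_offset = -1   # batch offset that produced the stored error

  def update(self, error, offset):
    if error is None:
      return
    if self.error is None or offset < self.error_offset:
      self.error = error
      self.error_offset = offset


status = BatchStatus()
status.update(ValueError("offset 7 failed"), 7)
status.update(ValueError("offset 3 failed"), 3)
assert status.error_offset == 3  # the earliest failure wins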
diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
index f6bd5dce26..bbbecc50f8 100644
--- a/tensorflow/core/ops/experimental_dataset_ops.cc
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -138,6 +138,32 @@ REGISTER_OP("ExperimentalAssertNextDataset")
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("ExperimentalNumaMapAndBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("batch_size: int64")
+ .Input("num_parallel_calls: int64")
+ .Input("drop_remainder: bool")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // Use indices from the end to retrieve the input shapes,
+ // so as to avoid guessing the length of "other_arguments".
+ // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars.
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
+
+ return shape_inference::ScalarShape(c);
+ });
+
REGISTER_OP("ExperimentalLMDBDataset")
.Input("filenames: string")
.Output("handle: variant")
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
index d444c4082e..5ead6d1c75 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
@@ -38,12 +39,17 @@ from tensorflow.python.platform import test
class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
- ("Default", None, None),
- ("SequentialCalls", 1, None),
- ("ParallelCalls", 2, None),
- ("ParallelBatches", None, 10),
+ ("Default", None, None, False),
+ ("SequentialCalls", 1, None, False),
+ ("ParallelCalls", 2, None, False),
+ ("ParallelBatches", None, 10, False),
+ ("DefaultNUMA", None, None, True),
+ ("SequentialCallsNUMA", 1, None, True),
+ ("ParallelCallsNUMA", 2, None, True),
+ ("ParallelBatchesNUMA", None, 10, True),
)
- def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
+ def testMapAndBatch(self, num_parallel_calls, num_parallel_batches,
+ numa_aware):
"""Test a dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset ->
# RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
@@ -57,14 +63,20 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
- iterator = (
+ dataset = (
dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
batching.map_and_batch(
map_func=_map_fn,
batch_size=batch_size,
num_parallel_calls=num_parallel_calls,
- num_parallel_batches=num_parallel_batches))
- .make_initializable_iterator())
+ num_parallel_batches=num_parallel_batches)))
+
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+
+ iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -115,16 +127,25 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
@parameterized.named_parameters(
- ("Even", False),
- ("Uneven", True),
+ ("Even", False, False),
+ ("Uneven", True, False),
+ ("EvenNUMA", False, True),
+ ("UnevenNUMA", True, True),
)
- def testMapAndBatchPartialBatch(self, drop_remainder):
- iterator = (
+ def testMapAndBatchPartialBatch(self, drop_remainder, numa_aware):
+ dataset = (
dataset_ops.Dataset.range(10).apply(
batching.map_and_batch(
lambda x: array_ops.reshape(x * x, [1]),
batch_size=4,
- drop_remainder=drop_remainder)).make_one_shot_iterator())
+ drop_remainder=drop_remainder)))
+
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+
if drop_remainder:
self.assertEqual([4, 1], iterator.output_shapes.as_list())
else:
@@ -138,11 +159,21 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
- def testMapAndBatchYieldsPartialBatch(self):
- iterator = (dataset_ops.Dataset.range(10)
- .apply(batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]), 4))
- .make_one_shot_iterator())
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchYieldsPartialBatch(self, numa_aware):
+ dataset = (
+ dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(lambda x: array_ops.reshape(x * x, [1]), 4)))
+
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+
+ iterator = dataset.make_one_shot_iterator()
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
@@ -152,10 +183,19 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
- def testMapAndBatchParallelGetNext(self):
- iterator = (dataset_ops.Dataset.range(50000)
- .apply(batching.map_and_batch(lambda x: x, batch_size=100))
- .make_one_shot_iterator())
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchParallelGetNext(self, numa_aware):
+ dataset = dataset_ops.Dataset.range(50000).apply(
+ batching.map_and_batch(lambda x: x, batch_size=100))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+
elements = []
for _ in range(100):
elements.append(iterator.get_next())
@@ -165,17 +205,26 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(elements)
- def testMapAndBatchParallelGetNextDropRemainder(self):
- iterator = (
- dataset_ops.Dataset.range(49999).apply(
- batching.map_and_batch(
- lambda x: x, batch_size=100, drop_remainder=True))
- .make_one_shot_iterator())
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchParallelGetNextDropRemainder(self, numa_aware):
+ dataset = dataset_ops.Dataset.range(49999).apply(
+ batching.map_and_batch(
+ lambda x: x, batch_size=100, drop_remainder=True))
+
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+
elements = []
for _ in range(100):
elements.append(iterator.get_next())
@@ -185,19 +234,29 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(elements)
- def testMapAndBatchSparse(self):
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchSparse(self, numa_aware):
def _sparse(i):
return sparse_tensor.SparseTensorValue(
indices=[[0]], values=(i * [1]), dense_shape=[1])
- iterator = dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
+ dataset = dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(_sparse, 5))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_initializable_iterator()
+
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -214,21 +273,33 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testMapAndBatchFails(self):
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchFails(self, numa_aware):
"""Test a dataset that maps a TF function across its input elements."""
dataset = dataset_ops.Dataset.from_tensors(
array_ops.check_numerics(
constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
+ dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_initializable_iterator()
+
init_op = iterator.initializer
with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
- def testMapAndBatchShapeMismatch(self):
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchShapeMismatch(self, numa_aware):
"""Test a dataset that maps a TF function across its input elements."""
def generator():
@@ -240,9 +311,13 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int32)
batch_size = 4
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
+ dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_initializable_iterator()
+
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
@@ -251,7 +326,11 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
"number of elements does not match"):
sess.run(get_next)
- def testMapAndBatchImplicitDispose(self):
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchImplicitDispose(self, numa_aware):
# Tests whether a map and batch dataset will be cleaned up correctly when
# the pipeline does not run it until exhaustion.
# The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
@@ -266,6 +345,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
dataset = dataset.prefetch(5)
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
@@ -274,26 +357,38 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(get_next)
@parameterized.named_parameters(
- ("1", 0),
- ("2", 5),
- ("3", 10),
- ("4", 90),
- ("5", 95),
- ("6", 99),
+ ("1", 0, False),
+ ("2", 5, False),
+ ("3", 10, False),
+ ("4", 90, False),
+ ("5", 95, False),
+ ("6", 99, False),
+ ("1NUMA", 0, True),
+ ("2NUMA", 5, True),
+ ("3NUMA", 10, True),
+ ("4NUMA", 90, True),
+ ("5NUMA", 95, True),
+ ("6NUMA", 99, True),
)
- def testMapAndBatchOutOfRangeError(self, threshold):
+ def testMapAndBatchOutOfRangeError(self, threshold, numa_aware):
def raising_py_fn(i):
- if i >= threshold:
+ if i == threshold:
raise StopIteration()
+ elif i > threshold:
+ raise RuntimeError("Alternate error; you shouldn't see me! (i: %s)" % i)
else:
return i
- iterator = (
- dataset_ops.Dataset.range(100).apply(
- batching.map_and_batch(
- lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
- batch_size=10)).make_one_shot_iterator())
+ dataset = dataset_ops.Dataset.range(100).apply(
+ batching.map_and_batch(
+ lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
+ batch_size=10))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
@@ -307,25 +402,42 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(get_next)
@parameterized.named_parameters(
- ("1", False, dtypes.bool),
- ("2", -42, dtypes.int8),
- ("3", -42, dtypes.int16),
- ("4", -42, dtypes.int32),
- ("5", -42, dtypes.int64),
- ("6", 42, dtypes.uint8),
- ("7", 42, dtypes.uint16),
- ("8", 42.0, dtypes.float16),
- ("9", 42.0, dtypes.float32),
- ("10", 42.0, dtypes.float64),
- ("11", b"hello", dtypes.string),
+ ("1", False, dtypes.bool, False),
+ ("2", -42, dtypes.int8, False),
+ ("3", -42, dtypes.int16, False),
+ ("4", -42, dtypes.int32, False),
+ ("5", -42, dtypes.int64, False),
+ ("6", 42, dtypes.uint8, False),
+ ("7", 42, dtypes.uint16, False),
+ ("8", 42.0, dtypes.float16, False),
+ ("9", 42.0, dtypes.float32, False),
+ ("10", 42.0, dtypes.float64, False),
+ ("11", b"hello", dtypes.string, False),
+ ("1NUMA", False, dtypes.bool, True),
+ ("2NUMA", -42, dtypes.int8, True),
+ ("3NUMA", -42, dtypes.int16, True),
+ ("4NUMA", -42, dtypes.int32, True),
+ ("5NUMA", -42, dtypes.int64, True),
+ ("6NUMA", 42, dtypes.uint8, True),
+ ("7NUMA", 42, dtypes.uint16, True),
+ ("8NUMA", 42.0, dtypes.float16, True),
+ ("9NUMA", 42.0, dtypes.float32, True),
+ ("10NUMA", 42.0, dtypes.float64, True),
+ ("11NUMA", b"hello", dtypes.string, True),
)
- def testMapAndBatchTypes(self, element, dtype):
+ def testMapAndBatchTypes(self, element, dtype, numa_aware):
+
def gen():
yield element
dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
batching.map_and_batch(lambda x: x, batch_size=10))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
@@ -363,6 +475,40 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(iterator.initializer, feed_dict={captured_t: 42})
self.assertAllEqual([42] * 10, sess.run(get_next))
+ @parameterized.named_parameters(
+ ("Normal", False),
+ ("NUMA", True),
+ )
+ def testMapAndBatchControlFlow(self, numa_aware):
+
+ def map_fn(x):
+ previous_cond_v2_value = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = True
+ return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x)
+ control_flow_ops.ENABLE_COND_V2 = previous_cond_v2_value
+ return return_value
+
+ dataset = dataset_ops.Dataset.range(100).apply(
+ batching.map_and_batch(map_fn, batch_size=10))
+ if numa_aware:
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ for i in range(10):
+ print("Case %d" % i)
+ if i < 5:
+ self.assertAllEqual([i * 10 + j + 1 for j in range(10)],
+ sess.run(get_next))
+ else:
+ self.assertAllEqual(
+ [((i * 10) + j) * ((i * 10) + j) for j in range(10)],
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
index c92bb8b9bc..5a0a73fd83 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -161,6 +161,7 @@ py_test(
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -199,6 +200,7 @@ py_test(
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
index 82516356df..d38255a6ea 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import time
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import batching
@@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class ModelDatasetTest(test_base.DatasetTestBase):
+class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def testModelMap(self):
k = 1024 * 1024
@@ -82,7 +83,11 @@ class ModelDatasetTest(test_base.DatasetTestBase):
(np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
np.max(deltas)))
- def testModelMapAndBatch(self):
+ @parameterized.named_parameters(
+ ("Default", False),
+ ("NUMA", True),
+ )
+ def testModelMapAndBatch(self, numa_aware):
batch_size = 16
k = 1024 * 1024
dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
@@ -95,6 +100,8 @@ class ModelDatasetTest(test_base.DatasetTestBase):
batch_size=batch_size))
options = dataset_ops.Options()
options.experimental_autotune = True
+ if numa_aware:
+ options.experimental_numa_aware = True
iterator = dataset.with_options(options).make_one_shot_iterator()
get_next = iterator.get_next()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
index 760cd8cc4e..2ef29796ab 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@@ -59,6 +60,21 @@ class OptimizeDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testNumaAwareRewrite(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(["NumaMapAndBatch"])).apply(
+ batching.map_and_batch(lambda x: x * x, 10))
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
def testOptimizationStatefulFunction(self):
dataset = dataset_ops.Dataset.range(10).map(
lambda _: random_ops.random_uniform([])).batch(10)
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
index e556b65b7c..a97cff9fbb 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -307,6 +307,21 @@ py_test(
)
py_test(
+ name = "numa_map_and_batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["numa_map_and_batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
name = "map_dataset_serialization_test",
size = "medium",
srcs = ["map_dataset_serialization_test.py"],
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..04aab329cd
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/numa_map_and_batch_dataset_serialization_test.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapAndBatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapAndBatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testNumParallelBatches(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+ num_parallel_batches = 2
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ ds = dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_batches=num_parallel_batches,
+ drop_remainder=drop_remainder))
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ return ds.with_options(options)
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
+ def testNumParallelCalls(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+ num_parallel_calls = 7
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ ds = dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
+ drop_remainder=drop_remainder))
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = True
+ return ds.with_options(options)
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
+
+if __name__ == "__main__":
+ test.main()
+
diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD
index 915d399f1b..46a9552b61 100644
--- a/tensorflow/python/data/experimental/ops/BUILD
+++ b/tensorflow/python/data/experimental/ops/BUILD
@@ -122,6 +122,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index cf52f7529a..6195747671 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1410,6 +1410,8 @@ class Options(object):
"Whether to eliminate no-op transformations."),
("experimental_shuffle_and_repeat_fusion", bool,
"Whether to fuse shuffle and repeat transformations."),
+ ("experimental_numa_aware", bool,
+ "Whether to use NUMA-aware operations."),
]:
def _make_getter(name): # pylint: disable=no-self-argument
@@ -1458,6 +1460,9 @@ class Options(object):
for exp_opt in experimental_optimizations:
if getattr(self, "experimental_" + exp_opt):
result.append(exp_opt)
+
+ if getattr(self, "experimental_numa_aware"):
+ result.append("map_and_batch_numa_aware_replacement")
return result
def merge(self, options):
@@ -1485,7 +1490,7 @@ class Options(object):
"experimental_map_and_filter_fusion", "experimental_map_fusion",
"experimental_map_parallelization", "experimental_map_vectorization",
"experimental_noop_elimination",
- "experimental_shuffle_and_repeat_fusion"
+ "experimental_shuffle_and_repeat_fusion", "experimental_numa_aware",
]:
this = getattr(result, name)
that = getattr(other, name)
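With the Options plumbing above, experimental_numa_aware behaves like the other experimental flags: setting it appends map_and_batch_numa_aware_replacement to the static optimizations handed to the tf.data graph rewriter, and the flag participates in Options.merge. A hedged sketch of the user-visible behavior, using only attributes that appear in this diff (the merge semantics are those of Options.merge as defined in dataset_ops.py):

# Sketch: the new flag merges like any other experimental Options flag.
from tensorflow.python.data.ops import dataset_ops

base = dataset_ops.Options()
base.experimental_autotune = True

numa = dataset_ops.Options()
numa.experimental_numa_aware = True

merged = base.merge(numa)
assert merged.experimental_autotune
assert merged.experimental_numa_aware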
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt
index d15dccc173..22256996d3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt
@@ -43,6 +43,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "experimental_numa_aware"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "experimental_shuffle_and_repeat_fusion"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt
index d15dccc173..22256996d3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt
@@ -43,6 +43,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "experimental_numa_aware"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "experimental_shuffle_and_repeat_fusion"
mtype: "<type \'property\'>"
}