aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers')
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD35
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.h6
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc62
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h48
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc112
6 files changed, 279 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index ee7c14e3ab..1c553044a8 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -415,6 +415,40 @@ tf_cc_test(
)
cc_library(
+ name = "map_and_batch_numa_aware_replacement",
+ srcs = ["map_and_batch_numa_aware_replacement.cc"],
+ hdrs = ["map_and_batch_numa_aware_replacement.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "map_and_batch_numa_aware_replacement_test",
+ srcs = ["map_and_batch_numa_aware_replacement_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_test_utils",
+ ":graph_utils",
+ ":map_and_batch_numa_aware_replacement",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ ],
+)
+
+cc_library(
name = "noop_elimination",
srcs = ["noop_elimination.cc"],
hdrs = [
@@ -490,6 +524,7 @@ cc_library(
":hoist_random_uniform",
":latency_all_edges",
":map_and_batch_fusion",
+ ":map_and_batch_numa_aware_replacement",
":map_and_filter_fusion",
":map_fusion",
":map_parallelization",
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
index b2eec7220e..1f03c6515c 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
namespace grappler {
@@ -44,6 +45,21 @@ NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
{"output_types", gtl::ArraySlice<TensorShape>{}}});
}
+NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece batch_size_node_name,
+ StringPiece num_parallel_calls_node_name,
+ StringPiece drop_remainder_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapAndBatchDatasetV2",
+ {string(input_node_name), "", string(batch_size_node_name),
+ string(num_parallel_calls_node_name), string(drop_remainder_node_name)},
+ {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<TensorShape>{}}});
+}
+
} // end namespace graph_tests_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
index ca0fde997d..f7891d5e1f 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -29,6 +29,12 @@ NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
StringPiece function_name = "IsZero");
+NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece batch_size_node_name,
+ StringPiece num_parallel_calls_node_name,
+ StringPiece drop_remainder_node_name,
+ StringPiece function_name = "XTimesTwo");
+
} // end namespace graph_tests_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc
new file mode 100644
index 0000000000..452089eb67
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeNumaAware(const NodeDef& node, MutableGraphView* graph) {
+ NodeDef numa_aware_node = node;
+ graph_utils::SetUniqueGraphNodeName("map_and_batch_numa_aware",
+ graph->GetGraph(), &numa_aware_node);
+ numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset");
+ return numa_aware_node;
+}
+
+} // namespace
+
+Status MapAndBatchNumaAwareReplacement::Optimize(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+
+ for (const NodeDef& node : item.graph.node()) {
+ if (node.op() != "MapAndBatchDatasetV2") continue;
+
+ auto* numa_node = graph.AddNode(MakeNumaAware(node, &graph));
+ graph.ReplaceInput(node, *numa_node);
+ nodes_to_delete.insert(node.name());
+ }
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapAndBatchNumaAwareReplacement,
+ "map_and_batch_numa_aware_replacement");
+
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h
new file mode 100644
index 0000000000..3b2acd288b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class MapAndBatchNumaAwareReplacement : public CustomGraphOptimizer {
+ public:
+ MapAndBatchNumaAwareReplacement() = default;
+ ~MapAndBatchNumaAwareReplacement() override = default;
+
+ string name() const override {
+ return "map_and_batch_numa_aware_replacement";
+ }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override {}
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc
new file mode 100644
index 0000000000..3c5c61d1c2
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(MapAndBatchNumaAwareReplacementTest, ReplaceSimple) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {
+ NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ NDef("batch_size", "Const", {}, {{"value", 3}, {"dtype", DT_INT32}}),
+ NDef("num_parallel_calls", "Const", {},
+ {{"value", 5}, {"dtype", DT_INT32}}),
+ NDef("drop_remainder", "Const", {},
+ {{"value", 0}, {"dtype", DT_BOOL}}),
+ graph_tests_utils::MakeMapAndBatchNode(
+ "map_and_batch", "range", "batch_size", "num_parallel_calls",
+ "drop_remainder"),
+ },
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+
+ MapAndBatchNumaAwareReplacement optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map_and_batch", output));
+ EXPECT_FALSE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp(
+ "ExperimentalNumaMapAndBatchDataset", output));
+}
+
+TEST(MapAndBatchNumaAawareReplacementTest, ReplaceWithExtraChild) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {
+ NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ NDef("batch_size", "Const", {}, {{"value", 3}, {"dtype", DT_INT32}}),
+ NDef("num_parallel_calls", "Const", {},
+ {{"value", 5}, {"dtype", DT_INT32}}),
+ NDef("drop_remainder", "Const", {},
+ {{"value", 0}, {"dtype", DT_BOOL}}),
+ graph_tests_utils::MakeMapAndBatchNode(
+ "map_and_batch", "range", "batch_size", "num_parallel_calls",
+ "drop_remainder"),
+ NDef("cache", "CacheDataset", {"map_and_batch"}, {}),
+ },
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+
+ MapAndBatchNumaAwareReplacement optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map_and_batch", output));
+ EXPECT_FALSE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp(
+ "ExperimentalNumaMapAndBatchDataset", output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("CacheDataset", output));
+
+ int numa_map_and_batch_component_id = graph_utils::FindGraphNodeWithOp(
+ "ExperimentalNumaMapAndBatchDataset", output);
+ auto& numa_map_and_batch_component =
+ output.node(numa_map_and_batch_component_id);
+ EXPECT_EQ(numa_map_and_batch_component.input(0), "range");
+
+ int cache_id = graph_utils::FindGraphNodeWithOp("CacheDataset", output);
+ auto& cache_node = output.node(cache_id);
+ EXPECT_EQ(cache_node.input(0), numa_map_and_batch_component.name());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow