Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc')
-rw-r--r--  tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc  136
1 file changed, 66 insertions(+), 70 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 1e8cbb9784..3ce238a30a 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
@@ -32,12 +32,70 @@ namespace {
constexpr char kFusedOpName[] = "MapAndBatchDatasetV2";
+NodeDef make_map_and_batch_node(const NodeDef& map_node,
+ const NodeDef& batch_node,
+ MutableGraphView* graph) {
+ NodeDef new_node;
+ new_node.set_op(kFusedOpName);
+ graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(),
+ &new_node);
+
+ // Set the `input` input argument.
+ new_node.add_input(map_node.input(0));
+
+ // Set the `other_arguments` input arguments.
+ int num_other_args;
+ if (map_node.op() == "ParallelMapDataset") {
+ num_other_args = map_node.input_size() - 2;
+ } else {
+ num_other_args = map_node.input_size() - 1;
+ }
+ for (int i = 0; i < num_other_args; i++) {
+ new_node.add_input(map_node.input(i + 1));
+ }
+
+ // Set the `batch_size` input argument.
+ new_node.add_input(batch_node.input(1));
+
+ // Set the `num_parallel_calls` input argument.
+ if (map_node.op() == "ParallelMapDataset") {
+ // The type of the `num_parallel_calls` argument in ParallelMapDataset
+ // and MapAndBatchDataset is different (int32 and int64 respectively)
+ // so we cannot reuse the same Const node and thus create a new one.
+ NodeDef* v = graph->GetNode(map_node.input(map_node.input_size() - 1));
+ NodeDef* tmp = graph_utils::AddScalarConstNode<int64>(
+ v->attr().at("value").tensor().int_val(0), graph);
+ new_node.add_input(tmp->name());
+ } else {
+ NodeDef* tmp = graph_utils::AddScalarConstNode<int64>(1, graph);
+ new_node.add_input(tmp->name());
+ }
+
+ // Set the `drop_remainder` input argument.
+ if (batch_node.op() == "BatchDatasetV2") {
+ new_node.add_input(batch_node.input(2));
+ } else {
+ NodeDef* tmp = graph_utils::AddScalarConstNode<bool>(false, graph);
+ new_node.add_input(tmp->name());
+ }
+
+ // Set `f` and `Targuments` attributes.
+ for (auto key : {"f", "Targuments"}) {
+ (*new_node.mutable_attr())[key] = map_node.attr().at(key);
+ }
+ // Set `output_types` and `output_shapes` attributes.
+ for (auto key : {"output_shapes", "output_types"}) {
+ (*new_node.mutable_attr())[key] = batch_node.attr().at(key);
+ }
+ return new_node;
+}
+
} // namespace
Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
*output = item.graph;
- GraphView graph(output);
+ MutableGraphView graph(output);
std::set<string> nodes_to_delete;
for (const NodeDef& node : item.graph.node()) {
if (node.op() != "BatchDataset" && node.op() != "BatchDatasetV2") {
@@ -45,87 +103,25 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
}
// Use a more descriptive variable name now that we know the node type.
- const NodeDef batch_node(node);
+ const NodeDef& batch_node = node;
GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0);
NodeDef* node2 = graph.GetRegularFanin(input_port).node;
if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
continue;
}
-
- NodeDef* new_node = output->add_node();
- new_node->set_op(kFusedOpName);
- graph_utils::SetUniqueName(kFusedOpName, output, new_node);
-
// Use a more descriptive variable name now that we know the node type.
NodeDef* map_node = node2;
- // Set the `input` input argument.
- new_node->add_input(map_node->input(0));
-
- // Set the `other_arguments` input arguments.
- int num_other_args;
- if (map_node->op() == "ParallelMapDataset") {
- num_other_args = map_node->input_size() - 2;
- } else {
- num_other_args = map_node->input_size() - 1;
- }
- for (int i = 0; i < num_other_args; i++) {
- new_node->add_input(map_node->input(i + 1));
- }
-
- // Set the `batch_size` input argument.
- new_node->add_input(batch_node.input(1));
-
- // Set the `num_parallel_calls` input argument.
- if (map_node->op() == "ParallelMapDataset") {
- // The type of the `num_parallel_calls` argument in ParallelMapDataset
- // and MapAndBatchDataset is different (int32 and int64 respectively)
- // so we cannot reuse the same Const node and thus create a new one.
- NodeDef* v = graph.GetNode(map_node->input(map_node->input_size() - 1));
- NodeDef* tmp;
- TF_RETURN_IF_ERROR(graph_utils::AddScalarConstNode<int64>(
- v->attr().at("value").tensor().int_val(0), output, &tmp));
- new_node->add_input(tmp->name());
- } else {
- NodeDef* tmp;
- TF_RETURN_IF_ERROR(
- graph_utils::AddScalarConstNode<int64>(1, output, &tmp));
- new_node->add_input(tmp->name());
- }
-
- // Set the `drop_remainder` input argument.
- if (batch_node.op() == "BatchDatasetV2") {
- new_node->add_input(batch_node.input(2));
- } else {
- NodeDef* tmp;
- TF_RETURN_IF_ERROR(
- graph_utils::AddScalarConstNode<bool>(false, output, &tmp));
- new_node->add_input(tmp->name());
- }
- // Set `f` and `Targuments` attributes.
- for (auto key : {"f", "Targuments"}) {
- (*new_node->mutable_attr())[key] = map_node->attr().at(key);
- }
- // Set `output_types` and `output_shapes` attributes.
- for (auto key : {"output_shapes", "output_types"}) {
- (*new_node->mutable_attr())[key] = batch_node.attr().at(key);
- }
+ auto* new_node =
+ graph.AddNode(make_map_and_batch_node(*map_node, batch_node, &graph));
+ graph.ReplaceInput(batch_node, *new_node);
// Mark the `Map` and `Batch` nodes for removal.
nodes_to_delete.insert(map_node->name());
nodes_to_delete.insert(batch_node.name());
-
- // Update the input of the outputs of the `Batch` node to use
- // `MapAndBatch`.
- GraphView::OutputPort output_port =
- graph.GetOutputPort(batch_node.name(), 0);
- auto fanout = graph.GetFanout(output_port);
- for (auto it = fanout.begin(); it != fanout.end(); ++it) {
- NodeDef* node = it->node;
- node->set_input(0, new_node->name());
- }
}
- TF_RETURN_IF_ERROR(graph_utils::DeleteNodes(nodes_to_delete, output));
+
+ graph.DeleteNodes(nodes_to_delete);
return Status::OK();
}
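
For reference, a minimal sketch (not part of the patch) of the node layout that make_map_and_batch_node produces for the fused MapAndBatchDatasetV2 op. The function and input names below are made up for illustration; only the ordering (input, other_arguments..., batch_size, num_parallel_calls, drop_remainder) and the attribute copying reflect the code in this diff:

#include "tensorflow/core/framework/node_def.pb.h"

// Illustrative only: the argument order the optimizer emits for the fused node.
tensorflow::NodeDef FusedNodeSketch() {
  tensorflow::NodeDef fused;
  fused.set_op("MapAndBatchDatasetV2");
  fused.set_name("map_and_batch/_1");        // unique name picked by SetUniqueGraphNodeName
  fused.add_input("range");                  // `input`: dataset feeding the old MapDataset
  fused.add_input("captured_arg");           // `other_arguments`: zero or more map inputs
  fused.add_input("batch_size");             // taken from the old BatchDataset(V2)
  fused.add_input("num_parallel_calls_i64"); // new int64 Const (ParallelMapDataset uses int32)
  fused.add_input("drop_remainder");         // reused from BatchDatasetV2, else a new false Const
  // `f` and `Targuments` are copied from the map node; `output_shapes` and
  // `output_types` are copied from the batch node.
  return fused;
}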