aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc11
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 3ce238a30a..63945b8b9e 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -32,9 +32,8 @@ namespace {
constexpr char kFusedOpName[] = "MapAndBatchDatasetV2";
-NodeDef make_map_and_batch_node(const NodeDef& map_node,
- const NodeDef& batch_node,
- MutableGraphView* graph) {
+NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
+ MutableGraphView* graph) {
NodeDef new_node;
new_node.set_op(kFusedOpName);
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(),
@@ -104,8 +103,8 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// Use a more descriptive variable name now that we know the node type.
const NodeDef& batch_node = node;
- GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0);
- NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph);
+
if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
continue;
}
@@ -113,7 +112,7 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
NodeDef* map_node = node2;
auto* new_node =
- graph.AddNode(make_map_and_batch_node(*map_node, batch_node, &graph));
+ graph.AddNode(MakeMapAndBatchNode(*map_node, batch_node, &graph));
graph.ReplaceInput(batch_node, *new_node);
// Mark the `Map` and `Batch` nodes for removal.