diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc index 3ce238a30a..63945b8b9e 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc @@ -32,9 +32,8 @@ namespace { constexpr char kFusedOpName[] = "MapAndBatchDatasetV2"; -NodeDef make_map_and_batch_node(const NodeDef& map_node, - const NodeDef& batch_node, - MutableGraphView* graph) { +NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node, + MutableGraphView* graph) { NodeDef new_node; new_node.set_op(kFusedOpName); graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(), @@ -104,8 +103,8 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item, // Use a more descriptive variable name now that we know the node type. const NodeDef& batch_node = node; - GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0); - NodeDef* node2 = graph.GetRegularFanin(input_port).node; + NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph); + if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") { continue; } @@ -113,7 +112,7 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item, NodeDef* map_node = node2; auto* new_node = - graph.AddNode(make_map_and_batch_node(*map_node, batch_node, &graph)); + graph.AddNode(MakeMapAndBatchNode(*map_node, batch_node, &graph)); graph.ReplaceInput(batch_node, *new_node); // Mark the `Map` and `Batch` nodes for removal. |