diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc | 75 |
1 files changed, 37 insertions, 38 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc index 0df73b33ed..7c7161c5b2 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" -#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" @@ -38,63 +38,62 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { *output = item.graph; - GraphView graph(output); + MutableGraphView graph(output); std::set<string> nodes_to_delete; - for (const NodeDef& node : item.graph.node()) { - if (node.op() != "RepeatDataset") { - continue; - } - // Use a more descriptive variable name now that we know the node type. - const NodeDef repeat_node(node); - GraphView::InputPort input_port = graph.GetInputPort(repeat_node.name(), 0); - NodeDef* node2 = graph.GetRegularFanin(input_port).node; - if (node2->op() != "ShuffleDataset") { - continue; - } - - NodeDef* new_node = output->add_node(); - new_node->set_op(kFusedOpName); - graph_utils::SetUniqueName(kFusedOpName, output, new_node); - - // Use a more descriptive variable name now that we know the node type. - NodeDef* shuffle_node = node2; + auto make_shuffle_and_repeat_node = [&output](const NodeDef& shuffle_node, + const NodeDef& repeat_node) { + NodeDef new_node; + new_node.set_op(kFusedOpName); + graph_utils::SetUniqueGraphNodeName(kFusedOpName, output, &new_node); // Set the `input` input argument. - new_node->add_input(shuffle_node->input(0)); + new_node.add_input(shuffle_node.input(0)); // Set the `buffer_size` input argument. - new_node->add_input(shuffle_node->input(1)); + new_node.add_input(shuffle_node.input(1)); // Set the `seed` input argument. - new_node->add_input(shuffle_node->input(2)); + new_node.add_input(shuffle_node.input(2)); // Set the `seed2` input argument. - new_node->add_input(shuffle_node->input(3)); + new_node.add_input(shuffle_node.input(3)); // Set the `count` input argument. - new_node->add_input(repeat_node.input(1)); + new_node.add_input(repeat_node.input(1)); // Set `output_types` and `output_shapes` attributes. for (auto key : {"output_shapes", "output_types"}) { - (*new_node->mutable_attr())[key] = repeat_node.attr().at(key); + (*new_node.mutable_attr())[key] = repeat_node.attr().at(key); } + return new_node; + }; - // Mark the `Shuffle` and `Repeat` nodes for removal. - nodes_to_delete.insert(shuffle_node->name()); - nodes_to_delete.insert(repeat_node.name()); + for (const NodeDef& node : item.graph.node()) { + if (node.op() != "RepeatDataset") { + continue; + } - // Update the input of the outputs of the `Repeat` node to use - // `ShuffleAndRepeat`. - GraphView::OutputPort output_port = - graph.GetOutputPort(repeat_node.name(), 0); - auto fanout = graph.GetFanout(output_port); - for (auto it = fanout.begin(); it != fanout.end(); ++it) { - NodeDef* node = it->node; - node->set_input(0, new_node->name()); + // Use a more descriptive variable name now that we know the node type. + const NodeDef& repeat_node = node; + GraphView::InputPort input_port = graph.GetInputPort(repeat_node.name(), 0); + NodeDef* node2 = graph.GetRegularFanin(input_port).node; + if (node2->op() != "ShuffleDataset") { + continue; } + // Use a more descriptive variable name now that we know the node type. + const NodeDef& shuffle_node = *node2; + + NodeDef* shuffle_and_repeat_node = + graph.AddNode(make_shuffle_and_repeat_node(shuffle_node, repeat_node)); + graph.ReplaceInput(repeat_node, *shuffle_and_repeat_node); + + // Mark the `Shuffle` and `Repeat` nodes for removal. + nodes_to_delete.insert(shuffle_node.name()); + nodes_to_delete.insert(repeat_node.name()); } - TF_RETURN_IF_ERROR(graph_utils::DeleteNodes(nodes_to_delete, output)); + + graph.DeleteNodes(nodes_to_delete); return Status::OK(); } |