aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc75
1 files changed, 37 insertions, 38 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index 0df73b33ed..7c7161c5b2 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
@@ -38,63 +38,62 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
const GrapplerItem& item,
GraphDef* output) {
*output = item.graph;
- GraphView graph(output);
+ MutableGraphView graph(output);
std::set<string> nodes_to_delete;
- for (const NodeDef& node : item.graph.node()) {
- if (node.op() != "RepeatDataset") {
- continue;
- }
- // Use a more descriptive variable name now that we know the node type.
- const NodeDef repeat_node(node);
- GraphView::InputPort input_port = graph.GetInputPort(repeat_node.name(), 0);
- NodeDef* node2 = graph.GetRegularFanin(input_port).node;
- if (node2->op() != "ShuffleDataset") {
- continue;
- }
-
- NodeDef* new_node = output->add_node();
- new_node->set_op(kFusedOpName);
- graph_utils::SetUniqueName(kFusedOpName, output, new_node);
-
- // Use a more descriptive variable name now that we know the node type.
- NodeDef* shuffle_node = node2;
+ auto make_shuffle_and_repeat_node = [&output](const NodeDef& shuffle_node,
+ const NodeDef& repeat_node) {
+ NodeDef new_node;
+ new_node.set_op(kFusedOpName);
+ graph_utils::SetUniqueGraphNodeName(kFusedOpName, output, &new_node);
// Set the `input` input argument.
- new_node->add_input(shuffle_node->input(0));
+ new_node.add_input(shuffle_node.input(0));
// Set the `buffer_size` input argument.
- new_node->add_input(shuffle_node->input(1));
+ new_node.add_input(shuffle_node.input(1));
// Set the `seed` input argument.
- new_node->add_input(shuffle_node->input(2));
+ new_node.add_input(shuffle_node.input(2));
// Set the `seed2` input argument.
- new_node->add_input(shuffle_node->input(3));
+ new_node.add_input(shuffle_node.input(3));
// Set the `count` input argument.
- new_node->add_input(repeat_node.input(1));
+ new_node.add_input(repeat_node.input(1));
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node->mutable_attr())[key] = repeat_node.attr().at(key);
+ (*new_node.mutable_attr())[key] = repeat_node.attr().at(key);
}
+ return new_node;
+ };
- // Mark the `Shuffle` and `Repeat` nodes for removal.
- nodes_to_delete.insert(shuffle_node->name());
- nodes_to_delete.insert(repeat_node.name());
+ for (const NodeDef& node : item.graph.node()) {
+ if (node.op() != "RepeatDataset") {
+ continue;
+ }
- // Update the input of the outputs of the `Repeat` node to use
- // `ShuffleAndRepeat`.
- GraphView::OutputPort output_port =
- graph.GetOutputPort(repeat_node.name(), 0);
- auto fanout = graph.GetFanout(output_port);
- for (auto it = fanout.begin(); it != fanout.end(); ++it) {
- NodeDef* node = it->node;
- node->set_input(0, new_node->name());
+ // Use a more descriptive variable name now that we know the node type.
+ const NodeDef& repeat_node = node;
+ GraphView::InputPort input_port = graph.GetInputPort(repeat_node.name(), 0);
+ NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ if (node2->op() != "ShuffleDataset") {
+ continue;
}
+ // Use a more descriptive variable name now that we know the node type.
+ const NodeDef& shuffle_node = *node2;
+
+ NodeDef* shuffle_and_repeat_node =
+ graph.AddNode(make_shuffle_and_repeat_node(shuffle_node, repeat_node));
+ graph.ReplaceInput(repeat_node, *shuffle_and_repeat_node);
+
+ // Mark the `Shuffle` and `Repeat` nodes for removal.
+ nodes_to_delete.insert(shuffle_node.name());
+ nodes_to_delete.insert(repeat_node.name());
}
- TF_RETURN_IF_ERROR(graph_utils::DeleteNodes(nodes_to_delete, output));
+
+ graph.DeleteNodes(nodes_to_delete);
return Status::OK();
}