diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-20 16:05:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-20 16:08:25 -0700 |
commit | a0071844d0af47f22ab512363b56383acf762dff (patch) | |
tree | 1a35ee263240d9dbffa96c22f854d8745651b093 | |
parent | c015a45646029f8c116028505f2da9e023b5c2b7 (diff) |
Remove protected data members from GraphOptimizerStage.
PiperOrigin-RevId: 193737654
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 54 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/graph_optimizer_stage.h | 5 |
2 files changed, 31 insertions, 28 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 232132e1e8..ed199c1ac8 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -294,8 +294,8 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { for (int i = src->input_size() - 1; i >= 0; --i) { if (IsControlInput(src->input(i))) { *target_node->add_input() = src->input(i); - ctx_.node_map->AddOutput(NodeName(src->input(i)), - target_node->name()); + ctx().node_map->AddOutput(NodeName(src->input(i)), + target_node->name()); } else { break; } @@ -442,7 +442,7 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { // TODO(ezhulenev): move to GraphOptimizerStage? bool DrivesControlDependency(const NodeDef& node) const { int position; - for (const NodeDef* output : ctx_.node_map->GetOutputs(node.name())) { + for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) { for (int i = 0; i < output->input_size(); ++i) { auto input = output->input(i); string name = ParseNodeName(input, &position); @@ -476,8 +476,8 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { } bool IsInPreserveSet(const NodeDef& node) const { - return ctx_.nodes_to_preserve->find(node.name()) != - ctx_.nodes_to_preserve->end(); + return ctx().nodes_to_preserve->find(node.name()) != + ctx().nodes_to_preserve->end(); } bool IsAlreadyOptimized(const NodeDef& node) const { @@ -546,7 +546,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { // with a single output data consumer (presumably if we reach this node from // previously absorbed or a root node, it means that this node is not used // as an input to any other op, outside of the group) - if (NumNonControlDataOutputs(node, *ctx_.node_map) != 1) { + if (NumNonControlDataOutputs(node, *ctx().node_map) != 1) { return false; } // All input shapes must be broadcastable to the node shape @@ -685,7 +685,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { (*node->mutable_attr())["N"].set_i(inputs.size()); for (const auto& inputAndShape : inputs) { - ctx_.node_map->AddOutput(inputAndShape.input, node_name); + ctx().node_map->AddOutput(inputAndShape.input, node_name); node->add_input(inputAndShape.input); } @@ -707,8 +707,8 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { node->set_device(root_node.device()); (*node->mutable_attr())["T"].set_type(dtype); - ctx_.node_map->AddOutput(left.input, node_name); - ctx_.node_map->AddOutput(right.input, node_name); + ctx().node_map->AddOutput(left.input, node_name); + ctx().node_map->AddOutput(right.input, node_name); node->add_input(left.input); node->add_input(right.input); @@ -784,20 +784,20 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { new_outer_node->set_input(1, new_add_node->name()); } - ctx_.node_map->AddOutput(common_factor, new_outer_node->name()); - ctx_.node_map->AddOutput(new_add_node->name(), new_outer_node->name()); + ctx().node_map->AddOutput(common_factor, new_outer_node->name()); + ctx().node_map->AddOutput(new_add_node->name(), new_outer_node->name()); // Hoist non-shared factors up into the new AddN node. for (int i = 0; i < unique_factors.size(); ++i) { const string& unique_factor_i = unique_factors[i]; new_add_node->set_input(i, unique_factor_i); - ctx_.node_map->AddOutput(unique_factor_i, new_add_node->name()); + ctx().node_map->AddOutput(unique_factor_i, new_add_node->name()); } // Add control deps on add node for (const string& ctrl_dep : ctrl_deps) { *new_add_node->add_input() = ctrl_dep; - ctx_.node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name()); + ctx().node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name()); } // optimize new inner aggregation node @@ -931,8 +931,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { // if graph rewrite happens in multiple passes without graph pruning between // them, it's possible that rewritten node already exists in a graph return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() || - ctx_.node_map->NodeExists(OuterNodeName(node, false)) || - ctx_.node_map->NodeExists(OuterNodeName(node, true)); + ctx().node_map->NodeExists(OuterNodeName(node, false)) || + ctx().node_map->NodeExists(OuterNodeName(node, true)); } // keep names of the nodes that were optimized by this stage @@ -996,7 +996,7 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { } // Optimized nodes updated in place, and that would break the graph, if the // node has multiple output consumers - if (NumNonControlOutputs(node, *ctx_.node_map) != 1) { + if (NumNonControlOutputs(node, *ctx().node_map) != 1) { return false; } // All input shapes must be broadcastable to the node shape @@ -1120,13 +1120,13 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { node->set_input(0, input_0); node->set_input(1, input_1); // Invalidate node properties (shape) - ctx_.graph_properties->ClearOutputProperties(node->name()); - ctx_.graph_properties->ClearInputProperties(node->name()); + ctx().graph_properties->ClearOutputProperties(node->name()); + ctx().graph_properties->ClearInputProperties(node->name()); // Update the node map - ctx_.node_map->RemoveOutput(NodeName(old_input_0), node->name()); - ctx_.node_map->RemoveOutput(NodeName(old_input_1), node->name()); - ctx_.node_map->AddOutput(NodeName(input_0), node->name()); - ctx_.node_map->AddOutput(NodeName(input_1), node->name()); + ctx().node_map->RemoveOutput(NodeName(old_input_0), node->name()); + ctx().node_map->RemoveOutput(NodeName(old_input_1), node->name()); + ctx().node_map->AddOutput(NodeName(input_0), node->name()); + ctx().node_map->AddOutput(NodeName(input_1), node->name()); // Add updated node to optimization queue AddToOptimizationQueue(node); } @@ -1257,8 +1257,8 @@ class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage { // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2) bitcast->set_input(0, operand->input(0)); SetSourceDataType(GetSourceDataType(*operand), bitcast); - ctx_.node_map->UpdateInput(bitcast->name(), bitcast->input(0), - operand->input(0)); + ctx().node_map->UpdateInput(bitcast->name(), bitcast->input(0), + operand->input(0)); AddToOptimizationQueue(bitcast); *simplified_node_name = bitcast->name(); } @@ -1313,14 +1313,14 @@ class RemoveNegationStage : public ArithmeticOptimizerStage { node->mutable_input()->SwapElements(0, 1); node->set_input(1, x->input(0)); node->add_input(AsControlDependency(x->name())); - ctx_.node_map->AddOutput(NodeName(x->input(0)), node_name); + ctx().node_map->AddOutput(NodeName(x->input(0)), node_name); updated = true; } else if (IsNeg(*y)) { // a + (-b) = a - b node->set_op("Sub"); node->set_input(1, y->input(0)); node->add_input(AsControlDependency(y->name())); - ctx_.node_map->AddOutput(NodeName(y->input(0)), node_name); + ctx().node_map->AddOutput(NodeName(y->input(0)), node_name); updated = true; } } else if (IsSub(*node)) { @@ -1329,7 +1329,7 @@ class RemoveNegationStage : public ArithmeticOptimizerStage { node->set_op("Add"); node->set_input(1, y->input(0)); node->add_input(AsControlDependency(y->name())); - ctx_.node_map->AddOutput(NodeName(y->input(0)), node_name); + ctx().node_map->AddOutput(NodeName(y->input(0)), node_name); updated = true; } } diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h index ed398525f3..089cad36e9 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h @@ -182,7 +182,10 @@ class GraphOptimizerStage { return ::tensorflow::grappler::AddEmptyNode(ctx_, name); } - protected: // Data members + protected: + const GraphOptimizerContext& ctx() const { return ctx_; } + + private: // Data members const string optimizer_name_; const string stage_name_; const GraphOptimizerContext ctx_; |