aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-20 16:05:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-20 16:08:25 -0700
commita0071844d0af47f22ab512363b56383acf762dff (patch)
tree1a35ee263240d9dbffa96c22f854d8745651b093
parentc015a45646029f8c116028505f2da9e023b5c2b7 (diff)
Remove protected data members from GraphOptimizerStage.
PiperOrigin-RevId: 193737654
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc54
-rw-r--r--tensorflow/core/grappler/optimizers/graph_optimizer_stage.h5
2 files changed, 31 insertions, 28 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 232132e1e8..ed199c1ac8 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -294,8 +294,8 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
for (int i = src->input_size() - 1; i >= 0; --i) {
if (IsControlInput(src->input(i))) {
*target_node->add_input() = src->input(i);
- ctx_.node_map->AddOutput(NodeName(src->input(i)),
- target_node->name());
+ ctx().node_map->AddOutput(NodeName(src->input(i)),
+ target_node->name());
} else {
break;
}
@@ -442,7 +442,7 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
// TODO(ezhulenev): move to GraphOptimizerStage?
bool DrivesControlDependency(const NodeDef& node) const {
int position;
- for (const NodeDef* output : ctx_.node_map->GetOutputs(node.name())) {
+ for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
for (int i = 0; i < output->input_size(); ++i) {
auto input = output->input(i);
string name = ParseNodeName(input, &position);
@@ -476,8 +476,8 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
}
bool IsInPreserveSet(const NodeDef& node) const {
- return ctx_.nodes_to_preserve->find(node.name()) !=
- ctx_.nodes_to_preserve->end();
+ return ctx().nodes_to_preserve->find(node.name()) !=
+ ctx().nodes_to_preserve->end();
}
bool IsAlreadyOptimized(const NodeDef& node) const {
@@ -546,7 +546,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
// with a single output data consumer (presumably if we reach this node from
// previously absorbed or a root node, it means that this node is not used
// as an input to any other op, outside of the group)
- if (NumNonControlDataOutputs(node, *ctx_.node_map) != 1) {
+ if (NumNonControlDataOutputs(node, *ctx().node_map) != 1) {
return false;
}
// All input shapes must be broadcastable to the node shape
@@ -685,7 +685,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
(*node->mutable_attr())["N"].set_i(inputs.size());
for (const auto& inputAndShape : inputs) {
- ctx_.node_map->AddOutput(inputAndShape.input, node_name);
+ ctx().node_map->AddOutput(inputAndShape.input, node_name);
node->add_input(inputAndShape.input);
}
@@ -707,8 +707,8 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
node->set_device(root_node.device());
(*node->mutable_attr())["T"].set_type(dtype);
- ctx_.node_map->AddOutput(left.input, node_name);
- ctx_.node_map->AddOutput(right.input, node_name);
+ ctx().node_map->AddOutput(left.input, node_name);
+ ctx().node_map->AddOutput(right.input, node_name);
node->add_input(left.input);
node->add_input(right.input);
@@ -784,20 +784,20 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
new_outer_node->set_input(1, new_add_node->name());
}
- ctx_.node_map->AddOutput(common_factor, new_outer_node->name());
- ctx_.node_map->AddOutput(new_add_node->name(), new_outer_node->name());
+ ctx().node_map->AddOutput(common_factor, new_outer_node->name());
+ ctx().node_map->AddOutput(new_add_node->name(), new_outer_node->name());
// Hoist non-shared factors up into the new AddN node.
for (int i = 0; i < unique_factors.size(); ++i) {
const string& unique_factor_i = unique_factors[i];
new_add_node->set_input(i, unique_factor_i);
- ctx_.node_map->AddOutput(unique_factor_i, new_add_node->name());
+ ctx().node_map->AddOutput(unique_factor_i, new_add_node->name());
}
// Add control deps on add node
for (const string& ctrl_dep : ctrl_deps) {
*new_add_node->add_input() = ctrl_dep;
- ctx_.node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
+ ctx().node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
}
// optimize new inner aggregation node
@@ -931,8 +931,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
// if graph rewrite happens in multiple passes without graph pruning between
// them, it's possible that rewritten node already exists in a graph
return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
- ctx_.node_map->NodeExists(OuterNodeName(node, false)) ||
- ctx_.node_map->NodeExists(OuterNodeName(node, true));
+ ctx().node_map->NodeExists(OuterNodeName(node, false)) ||
+ ctx().node_map->NodeExists(OuterNodeName(node, true));
}
// keep names of the nodes that were optimized by this stage
@@ -996,7 +996,7 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
}
// Optimized nodes updated in place, and that would break the graph, if the
// node has multiple output consumers
- if (NumNonControlOutputs(node, *ctx_.node_map) != 1) {
+ if (NumNonControlOutputs(node, *ctx().node_map) != 1) {
return false;
}
// All input shapes must be broadcastable to the node shape
@@ -1120,13 +1120,13 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
node->set_input(0, input_0);
node->set_input(1, input_1);
// Invalidate node properties (shape)
- ctx_.graph_properties->ClearOutputProperties(node->name());
- ctx_.graph_properties->ClearInputProperties(node->name());
+ ctx().graph_properties->ClearOutputProperties(node->name());
+ ctx().graph_properties->ClearInputProperties(node->name());
// Update the node map
- ctx_.node_map->RemoveOutput(NodeName(old_input_0), node->name());
- ctx_.node_map->RemoveOutput(NodeName(old_input_1), node->name());
- ctx_.node_map->AddOutput(NodeName(input_0), node->name());
- ctx_.node_map->AddOutput(NodeName(input_1), node->name());
+ ctx().node_map->RemoveOutput(NodeName(old_input_0), node->name());
+ ctx().node_map->RemoveOutput(NodeName(old_input_1), node->name());
+ ctx().node_map->AddOutput(NodeName(input_0), node->name());
+ ctx().node_map->AddOutput(NodeName(input_1), node->name());
// Add updated node to optimization queue
AddToOptimizationQueue(node);
}
@@ -1257,8 +1257,8 @@ class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
// Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
bitcast->set_input(0, operand->input(0));
SetSourceDataType(GetSourceDataType(*operand), bitcast);
- ctx_.node_map->UpdateInput(bitcast->name(), bitcast->input(0),
- operand->input(0));
+ ctx().node_map->UpdateInput(bitcast->name(), bitcast->input(0),
+ operand->input(0));
AddToOptimizationQueue(bitcast);
*simplified_node_name = bitcast->name();
}
@@ -1313,14 +1313,14 @@ class RemoveNegationStage : public ArithmeticOptimizerStage {
node->mutable_input()->SwapElements(0, 1);
node->set_input(1, x->input(0));
node->add_input(AsControlDependency(x->name()));
- ctx_.node_map->AddOutput(NodeName(x->input(0)), node_name);
+ ctx().node_map->AddOutput(NodeName(x->input(0)), node_name);
updated = true;
} else if (IsNeg(*y)) {
// a + (-b) = a - b
node->set_op("Sub");
node->set_input(1, y->input(0));
node->add_input(AsControlDependency(y->name()));
- ctx_.node_map->AddOutput(NodeName(y->input(0)), node_name);
+ ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
updated = true;
}
} else if (IsSub(*node)) {
@@ -1329,7 +1329,7 @@ class RemoveNegationStage : public ArithmeticOptimizerStage {
node->set_op("Add");
node->set_input(1, y->input(0));
node->add_input(AsControlDependency(y->name()));
- ctx_.node_map->AddOutput(NodeName(y->input(0)), node_name);
+ ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
updated = true;
}
}
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index ed398525f3..089cad36e9 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -182,7 +182,10 @@ class GraphOptimizerStage {
return ::tensorflow::grappler::AddEmptyNode(ctx_, name);
}
- protected: // Data members
+ protected:
+ const GraphOptimizerContext& ctx() const { return ctx_; }
+
+ private: // Data members
const string optimizer_name_;
const string stage_name_;
const GraphOptimizerContext ctx_;