aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-08-20 12:33:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 12:38:18 -0700
commit5e1cd6a15f6b65ccb5660714274368b71486c7f6 (patch)
treec3cbc5b5c3dda6ba176546bcbb09091733a86f8c
parentd0f50579bfed36597d3c33f83e02009d7993ea41 (diff)
Minor NFC cleanups to deadness analysis.
- Expand SetPred to SetPredicate - Use stable ordering during RPO PiperOrigin-RevId: 209465102
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc45
1 files changed, 25 insertions, 20 deletions
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 309aeffc18..0ca0f949dc 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -508,8 +508,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
// Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th
// bit of `should_revisit` if `pred` is different from the current predicate
// for the `output_idx` output of `n`.
- void SetPred(Node* n, int output_idx, Predicate* pred,
- std::vector<bool>* should_revisit) {
+ void SetPredicate(Node* n, int output_idx, Predicate* pred,
+ std::vector<bool>* should_revisit) {
auto insert_result =
predicate_map_.insert({TensorId(n->name(), output_idx), pred});
if (!insert_result.second && insert_result.first->second != pred) {
@@ -526,10 +526,10 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
}
}
- void SetPred(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred,
- std::vector<bool>* should_revisit) {
+ void SetPredicate(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred,
+ std::vector<bool>* should_revisit) {
for (int output_idx : output_idxs) {
- SetPred(n, output_idx, pred, should_revisit);
+ SetPredicate(n, output_idx, pred, should_revisit);
}
}
@@ -580,19 +580,20 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
// Output 0 is alive iff all inputs are alive and the condition is false.
input_preds.push_back(false_switch);
- SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
- should_revisit);
+ SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
input_preds.pop_back();
// Output 1 is alive iff all inputs are alive and the condition is true.
input_preds.push_back(true_switch);
- SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
- should_revisit);
+ SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
input_preds.pop_back();
// Control is alive iff all inputs are alive.
- SetPred(n, Graph::kControlSlot,
- predicate_factory_.MakeAndPredicate(input_preds), should_revisit);
+ SetPredicate(n, Graph::kControlSlot,
+ predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
return Status::OK();
}
@@ -682,14 +683,16 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
// backedge.
Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate(
TensorId(n->name(), 0), /*must_be_true=*/false);
- SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit);
+ SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
+ should_revisit);
return Status::OK();
}
// We're visiting this merge for the first time and it is a acyclic merge.
Predicate* input_data_pred = predicate_factory_.MakeOrPredicate(
GetIncomingPreds(n, EdgeKind::kDataOnly));
- SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit);
+ SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
+ should_revisit);
return Status::OK();
}
@@ -717,7 +720,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
predicate_factory_.MakeOrPredicate(non_recurrent_inputs);
Predicate* and_rec =
predicate_factory_.MakeAndRecurrencePredicate(start, step);
- SetPred(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
+ SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
return Status::OK();
}
}
@@ -733,8 +736,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n,
GetIncomingPreds(n, EdgeKind::kDataAndControl);
input_preds.push_back(predicate_factory_.MakeSymbolPredicate(
TensorId(n->name(), 0), /*must_be_true=*/false));
- SetPred(n, {0, Graph::kControlSlot},
- predicate_factory_.MakeAndPredicate(input_preds), should_revisit);
+ SetPredicate(n, {0, Graph::kControlSlot},
+ predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
return Status::OK();
}
@@ -744,9 +748,9 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
Predicate* pred = predicate_factory_.MakeAndPredicate(
GetIncomingPreds(n, EdgeKind::kDataAndControl));
for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
- SetPred(n, output_idx, pred, should_revisit);
+ SetPredicate(n, output_idx, pred, should_revisit);
}
- SetPred(n, Graph::kControlSlot, pred, should_revisit);
+ SetPredicate(n, Graph::kControlSlot, pred, should_revisit);
return Status::OK();
}
@@ -757,7 +761,8 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
} else if (n->IsMerge()) {
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit));
} else if (n->IsControlTrigger()) {
- SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue(), nullptr);
+ SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
+ nullptr);
} else if (n->IsRecv() || n->IsHostRecv()) {
TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit));
} else if (n->IsNextIteration()) {
@@ -770,7 +775,7 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
Status DeadnessAnalysisImpl::Populate() {
std::vector<Node*> rpo;
- GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{},
+ GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});