diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-08-20 12:33:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-20 12:38:18 -0700 |
commit | 5e1cd6a15f6b65ccb5660714274368b71486c7f6 (patch) | |
tree | c3cbc5b5c3dda6ba176546bcbb09091733a86f8c | |
parent | d0f50579bfed36597d3c33f83e02009d7993ea41 (diff) |
Minor NFC cleanups to deadness analysis.
- Expand SetPred to SetPredicate
- Use stable ordering during RPO
PiperOrigin-RevId: 209465102
-rw-r--r-- | tensorflow/compiler/jit/deadness_analysis.cc | 45 |
1 files changed, 25 insertions, 20 deletions
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 309aeffc18..0ca0f949dc 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -508,8 +508,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th // bit of `should_revisit` if `pred` is different from the current predicate // for the `output_idx` output of `n`. - void SetPred(Node* n, int output_idx, Predicate* pred, - std::vector<bool>* should_revisit) { + void SetPredicate(Node* n, int output_idx, Predicate* pred, + std::vector<bool>* should_revisit) { auto insert_result = predicate_map_.insert({TensorId(n->name(), output_idx), pred}); if (!insert_result.second && insert_result.first->second != pred) { @@ -526,10 +526,10 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { } } - void SetPred(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred, - std::vector<bool>* should_revisit) { + void SetPredicate(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred, + std::vector<bool>* should_revisit) { for (int output_idx : output_idxs) { - SetPred(n, output_idx, pred, should_revisit); + SetPredicate(n, output_idx, pred, should_revisit); } } @@ -580,19 +580,20 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, // Output 0 is alive iff all inputs are alive and the condition is false. input_preds.push_back(false_switch); - SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds), - should_revisit); + SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); input_preds.pop_back(); // Output 1 is alive iff all inputs are alive and the condition is true. input_preds.push_back(true_switch); - SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds), - should_revisit); + SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); input_preds.pop_back(); // Control is alive iff all inputs are alive. - SetPred(n, Graph::kControlSlot, - predicate_factory_.MakeAndPredicate(input_preds), should_revisit); + SetPredicate(n, Graph::kControlSlot, + predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); return Status::OK(); } @@ -682,14 +683,16 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, // backedge. Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate( TensorId(n->name(), 0), /*must_be_true=*/false); - SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, + should_revisit); return Status::OK(); } // We're visiting this merge for the first time and it is a acyclic merge. Predicate* input_data_pred = predicate_factory_.MakeOrPredicate( GetIncomingPreds(n, EdgeKind::kDataOnly)); - SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, + should_revisit); return Status::OK(); } @@ -717,7 +720,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, predicate_factory_.MakeOrPredicate(non_recurrent_inputs); Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate(start, step); - SetPred(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit); + SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit); return Status::OK(); } } @@ -733,8 +736,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, GetIncomingPreds(n, EdgeKind::kDataAndControl); input_preds.push_back(predicate_factory_.MakeSymbolPredicate( TensorId(n->name(), 0), /*must_be_true=*/false)); - SetPred(n, {0, Graph::kControlSlot}, - predicate_factory_.MakeAndPredicate(input_preds), should_revisit); + SetPredicate(n, {0, Graph::kControlSlot}, + predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); return Status::OK(); } @@ -744,9 +748,9 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n, Predicate* pred = predicate_factory_.MakeAndPredicate( GetIncomingPreds(n, EdgeKind::kDataAndControl)); for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { - SetPred(n, output_idx, pred, should_revisit); + SetPredicate(n, output_idx, pred, should_revisit); } - SetPred(n, Graph::kControlSlot, pred, should_revisit); + SetPredicate(n, Graph::kControlSlot, pred, should_revisit); return Status::OK(); } @@ -757,7 +761,8 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n, } else if (n->IsMerge()) { TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit)); } else if (n->IsControlTrigger()) { - SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue(), nullptr); + SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(), + nullptr); } else if (n->IsRecv() || n->IsHostRecv()) { TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit)); } else if (n->IsNextIteration()) { @@ -770,7 +775,7 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n, Status DeadnessAnalysisImpl::Populate() { std::vector<Node*> rpo; - GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{}, + GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(), /*edge_filter=*/[](const Edge& edge) { return !edge.src()->IsNextIteration(); }); |