From 8622f05a62948d8966be8962a6a33e0a8b5a116d Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 4 Oct 2018 10:17:02 -0700 Subject: Don't CHECK-fail on malformed graphs in deadness analysis Instead return a friendlier failed Status from the following two methods which used to CHECK-fail before: GetIncomingPreds, FindUniqueBackedge. While at it, also rename GetIncomingPreds to GetInputPreds to be consistent with the variable names. PiperOrigin-RevId: 215758757 --- tensorflow/compiler/jit/deadness_analysis.cc | 77 ++++++++++++++++++++-------- 1 file changed, 55 insertions(+), 22 deletions(-) (limited to 'tensorflow/compiler') diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index e0b9932d80..b7ae7fbeb3 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/hash/hash.h" @@ -579,7 +580,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { private: enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; - std::vector GetIncomingPreds(Node* n, EdgeKind edge_kind); + Status GetInputPreds(Node* n, EdgeKind edge_kind, + std::vector* result); // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th // bit of `should_revisit` if `pred` is different from the current predicate @@ -625,9 +627,10 @@ TensorId InputEdgeToTensorId(const Edge* e) { return TensorId(e->src()->name(), e->src_output()); } -std::vector DeadnessAnalysisImpl::GetIncomingPreds( - Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) { - std::vector incoming_preds; +Status DeadnessAnalysisImpl::GetInputPreds( + Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind, + std::vector* result) { + result->clear(); for (const Edge* in_edge : n->in_edges()) { bool should_process = edge_kind == EdgeKind::kDataAndControl || @@ -636,17 +639,27 @@ std::vector DeadnessAnalysisImpl::GetIncomingPreds( if (should_process) { auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); - CHECK(it != predicate_map_.end()) << n->name(); - incoming_preds.push_back(it->second); + if (it == predicate_map_.end()) { + GraphCycles graph_cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles)); + + // If we didn't return with an error above then the graph is probably + // fine and we have a bug in deadness analysis. + return errors::Internal("Could not find input ", in_edge->DebugString(), + " to ", n->name(), + " when visiting the graph in post-order. Most " + "likely indicates a bug in deadness analysis."); + } + result->push_back(it->second); } } - return incoming_preds; + return Status::OK(); } Status DeadnessAnalysisImpl::HandleSwitch(Node* n, std::vector* should_revisit) { - std::vector input_preds = - GetIncomingPreds(n, EdgeKind::kDataAndControl); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); const Edge* pred_edge; TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( @@ -675,17 +688,31 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, } namespace { -const Edge* FindUniqueBackedge(Node* merge) { +Status CreateMultipleNextIterationInputsError(Node* merge) { + std::vector backedges; + for (const Edge* backedge : merge->in_edges()) { + if (backedge->src()->IsNextIteration()) { + backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src()))); + } + } + return errors::InvalidArgument( + "Multiple NextIteration inputs to merge node ", SummarizeNode(*merge), + ": \n", absl::StrJoin(backedges, "\n"), + "\nMerge nodes can have at most one incoming NextIteration edge."); +} + +Status FindUniqueBackedge(Node* merge, const Edge** result) { + *result = nullptr; CHECK(merge->IsMerge()); - const Edge* result = nullptr; for (const Edge* e : merge->in_edges()) { if (e->src()->IsNextIteration()) { - CHECK_EQ(result, nullptr) - << "Multiple backedges to " << merge->DebugString(); - result = e; + if (*result != nullptr) { + return CreateMultipleNextIterationInputsError(merge); + } + *result = e; } } - return result; + return Status::OK(); } // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step @@ -764,9 +791,12 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, return Status::OK(); } + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds)); + // We're visiting this merge for the first time and it is a acyclic merge. - Predicate* input_data_pred = predicate_factory_.MakeOrPredicate( - GetIncomingPreds(n, EdgeKind::kDataOnly)); + Predicate* input_data_pred = + predicate_factory_.MakeOrPredicate(input_preds); SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); return Status::OK(); @@ -777,7 +807,9 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, // of an unvisited backedge. Try to pattern match the predicate expression // for that backedge (which should be visited now) into an and recurrence // for the merge node. - if (const Edge* unique_backedge = FindUniqueBackedge(n)) { + const Edge* unique_backedge; + TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge)); + if (unique_backedge) { if (Predicate* step = DeduceStepPredicate( &predicate_factory_, it->second, predicate_map_[InputEdgeToTensorId(unique_backedge)])) { @@ -808,8 +840,8 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, std::vector* should_revisit) { // In addition to being alive or dead based on the inputs, a _Recv can also // acquire a dead signal from a _Send. - std::vector input_preds = - GetIncomingPreds(n, EdgeKind::kDataAndControl); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); input_preds.push_back(predicate_factory_.MakeSymbolPredicate( TensorId(n->name(), 0), /*must_be_true=*/false)); SetPredicate(n, {0, Graph::kControlSlot}, @@ -821,8 +853,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, Status DeadnessAnalysisImpl::HandleGeneric(Node* n, std::vector* should_revisit) { // Generally nodes are alive iff all their inputs are alive. - Predicate* pred = predicate_factory_.MakeAndPredicate( - GetIncomingPreds(n, EdgeKind::kDataAndControl)); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); + Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds); for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { SetPredicate(n, output_idx, pred, should_revisit); } -- cgit v1.2.3