aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-10-04 10:17:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 10:27:54 -0700
commit8622f05a62948d8966be8962a6a33e0a8b5a116d (patch)
treeca5ddeea53fbdba4e1152c28e288fe1185cf7309 /tensorflow/compiler
parent100714d9e5eb723525eb54142769f9bd8eec5edd (diff)
Don't CHECK-fail on malformed graphs in deadness analysis
Instead return a friendlier failed Status from the following two methods which used to CHECK-fail before: GetIncomingPreds, FindUniqueBackedge. While at it, also rename GetIncomingPreds to GetInputPreds to be consistent with the variable names. PiperOrigin-RevId: 215758757
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc77
1 files changed, 55 insertions, 22 deletions
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index e0b9932d80..b7ae7fbeb3 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/hash/hash.h"
@@ -579,7 +580,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
private:
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
- std::vector<Predicate*> GetIncomingPreds(Node* n, EdgeKind edge_kind);
+ Status GetInputPreds(Node* n, EdgeKind edge_kind,
+ std::vector<Predicate*>* result);
// Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th
// bit of `should_revisit` if `pred` is different from the current predicate
@@ -625,9 +627,10 @@ TensorId InputEdgeToTensorId(const Edge* e) {
return TensorId(e->src()->name(), e->src_output());
}
-std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
- Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) {
- std::vector<Predicate*> incoming_preds;
+Status DeadnessAnalysisImpl::GetInputPreds(
+ Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind,
+ std::vector<Predicate*>* result) {
+ result->clear();
for (const Edge* in_edge : n->in_edges()) {
bool should_process =
edge_kind == EdgeKind::kDataAndControl ||
@@ -636,17 +639,27 @@ std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
if (should_process) {
auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
- CHECK(it != predicate_map_.end()) << n->name();
- incoming_preds.push_back(it->second);
+ if (it == predicate_map_.end()) {
+ GraphCycles graph_cycles;
+ TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles));
+
+ // If we didn't return with an error above then the graph is probably
+ // fine and we have a bug in deadness analysis.
+ return errors::Internal("Could not find input ", in_edge->DebugString(),
+ " to ", n->name(),
+ " when visiting the graph in post-order. Most "
+ "likely indicates a bug in deadness analysis.");
+ }
+ result->push_back(it->second);
}
}
- return incoming_preds;
+ return Status::OK();
}
Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
std::vector<bool>* should_revisit) {
- std::vector<Predicate*> input_preds =
- GetIncomingPreds(n, EdgeKind::kDataAndControl);
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
const Edge* pred_edge;
TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
Predicate* true_switch = predicate_factory_.MakeSymbolPredicate(
@@ -675,17 +688,31 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
}
namespace {
-const Edge* FindUniqueBackedge(Node* merge) {
+Status CreateMultipleNextIterationInputsError(Node* merge) {
+ std::vector<string> backedges;
+ for (const Edge* backedge : merge->in_edges()) {
+ if (backedge->src()->IsNextIteration()) {
+ backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src())));
+ }
+ }
+ return errors::InvalidArgument(
+ "Multiple NextIteration inputs to merge node ", SummarizeNode(*merge),
+ ": \n", absl::StrJoin(backedges, "\n"),
+ "\nMerge nodes can have at most one incoming NextIteration edge.");
+}
+
+Status FindUniqueBackedge(Node* merge, const Edge** result) {
+ *result = nullptr;
CHECK(merge->IsMerge());
- const Edge* result = nullptr;
for (const Edge* e : merge->in_edges()) {
if (e->src()->IsNextIteration()) {
- CHECK_EQ(result, nullptr)
- << "Multiple backedges to " << merge->DebugString();
- result = e;
+ if (*result != nullptr) {
+ return CreateMultipleNextIterationInputsError(merge);
+ }
+ *result = e;
}
}
- return result;
+ return Status::OK();
}
// If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step
@@ -764,9 +791,12 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
return Status::OK();
}
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
+
// We're visiting this merge for the first time and it is a acyclic merge.
- Predicate* input_data_pred = predicate_factory_.MakeOrPredicate(
- GetIncomingPreds(n, EdgeKind::kDataOnly));
+ Predicate* input_data_pred =
+ predicate_factory_.MakeOrPredicate(input_preds);
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
should_revisit);
return Status::OK();
@@ -777,7 +807,9 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
// of an unvisited backedge. Try to pattern match the predicate expression
// for that backedge (which should be visited now) into an and recurrence
// for the merge node.
- if (const Edge* unique_backedge = FindUniqueBackedge(n)) {
+ const Edge* unique_backedge;
+ TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge));
+ if (unique_backedge) {
if (Predicate* step = DeduceStepPredicate(
&predicate_factory_, it->second,
predicate_map_[InputEdgeToTensorId(unique_backedge)])) {
@@ -808,8 +840,8 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n,
std::vector<bool>* should_revisit) {
// In addition to being alive or dead based on the inputs, a _Recv can also
// acquire a dead signal from a _Send.
- std::vector<Predicate*> input_preds =
- GetIncomingPreds(n, EdgeKind::kDataAndControl);
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
input_preds.push_back(predicate_factory_.MakeSymbolPredicate(
TensorId(n->name(), 0), /*must_be_true=*/false));
SetPredicate(n, {0, Graph::kControlSlot},
@@ -821,8 +853,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n,
Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
std::vector<bool>* should_revisit) {
// Generally nodes are alive iff all their inputs are alive.
- Predicate* pred = predicate_factory_.MakeAndPredicate(
- GetIncomingPreds(n, EdgeKind::kDataAndControl));
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
+ Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds);
for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
SetPredicate(n, output_idx, pred, should_revisit);
}