From 448596097801f24c14d5705fc5a8dc434c3ee1b8 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 21 Jun 2018 15:32:55 -0700 Subject: Consistently ignore the NextIteration->Merge edge; NFC I found the current behavior of ignoring the NextIteration->Merge when computing the pending counts but not ignoring it in UpdatePendingCountAndReady somewhat confusing when investigating b/110367384. PiperOrigin-RevId: 201598531 --- tensorflow/core/graph/graph_constructor.cc | 43 +++++++++++++++++------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 0967492d92..418a49b5db 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -227,6 +227,10 @@ class GraphConstructor { // already unique in the graph. string FindUniqueName(StringPiece original_name); + // Decrement pending count for users of `processed` and add the ones that now + // have all of their pending inputs satisfied to `ready_`. + void UpdatePendingCountAndReady(int processed); + // From constructor const Options opts_; const NodeDefSlice node_defs_; @@ -315,6 +319,25 @@ class GraphConstructor { std::vector back_edges_; }; +void GraphConstructor::UpdatePendingCountAndReady(int processed) { + // We didn't consider NextIteration->Merge edges when computing + // pending_counts_ so we should not have to consider it here either. + bool is_next_iteration = IsNextIteration(*node_defs_[processed]); + for (size_t i = 0; i < outputs_[processed].size(); ++i) { + const int output = outputs_[processed][i]; + bool is_next_iteration_to_merge_edge = + is_next_iteration && IsMerge(*node_defs_[output]); + if (!is_next_iteration_to_merge_edge) { + int* current_pending_count = &pending_count_[output]; + CHECK_GT(*current_pending_count, 0); + (*current_pending_count)--; + if (*current_pending_count == 0) { + ready_.insert(output); + } + } + } +} + // This could be expensive but we don't expect to call it often, if at all (only // if there are multiple nodes in g_ with the same name) bool NodeNameInValues(const std::map& input_map, @@ -881,22 +904,6 @@ Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def, return Status::OK(); } -namespace { - -void UpdatePendingCountAndReady( - const std::vector>& outputs, int o, - std::vector* pending_count, std::set* ready) { - for (size_t i = 0; i < outputs[o].size(); ++i) { - const int output = outputs[o][i]; - (*pending_count)[output]--; - if ((*pending_count)[output] == 0) { - ready->insert(output); - } - } -} - -} // anonymous namespace - Status GraphConstructor::Convert() { // Import functions before adding nodes, since imported nodes may refer to // functions @@ -938,7 +945,7 @@ Status GraphConstructor::Convert() { IsNodeFullyMapped(original_node_def, &is_node_mapped)); if (is_node_mapped) { // Skip this node after updating pending_count_ for outputs - UpdatePendingCountAndReady(outputs_, o, &pending_count_, &ready_); + UpdatePendingCountAndReady(o); continue; } } @@ -1031,7 +1038,7 @@ Status GraphConstructor::Convert() { TF_RETURN_IF_ERROR(ValidateShape(node)); // Update pending_count_ for outputs. - UpdatePendingCountAndReady(outputs_, o, &pending_count_, &ready_); + UpdatePendingCountAndReady(o); } if (processed < node_defs_.size()) { -- cgit v1.2.3