diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/loop_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/loop_optimizer.cc | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index f78036d78c..bd0d94b83f 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -413,7 +414,7 @@ Status LoopOptimizer::LoopInvariantNodeMotion() { frame_children_[frame_ids[0]].insert(frame_ids[1]); frame_parent_[frame_ids.back()] = frame_ids[frame_ids.size() - 2]; } - if (!frame_ids.empty()) { + if (frame_ids.size() >= 1) { frame_children_.insert(std::make_pair(frame_ids.back(), empty_set_)); if (node->op() == "LoopCond") { if (loop_cond_.count(frame_ids.back())) { @@ -432,7 +433,7 @@ Status LoopOptimizer::LoopInvariantNodeMotion() { } for (auto it = frame_children_.begin(); it != frame_children_.end(); ++it) { - if (it->second.empty()) { + if (it->second.size() == 0) { worklist.push_back(it->first); } } @@ -445,7 +446,7 @@ Status LoopOptimizer::LoopInvariantNodeMotion() { if (parent_it != frame_parent_.end()) { int parent_id = parent_it->second; frame_children_[parent_id].erase(frame_id); - if (frame_children_[parent_id].empty()) { + if (frame_children_[parent_id].size() == 0) { worklist.push_back(parent_id); } } @@ -468,6 +469,7 @@ Status LoopOptimizer::LoopInvariantNodeMotion() { Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { + TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph)); if (opt_level_ == RewriterConfig::AGGRESSIVE) { |