diff options
field | value | |
---|---|---|
author | 2018-04-06 16:00:41 -0700 | |
committer | 2018-04-06 16:03:12 -0700 | |
commit | d8ec179569514c068284c84540826d077a30485d (patch) | |
tree | bbf637c102a9cc9c98fd334857879cf42e984798 | |
parent | d017e6f030c06d4803897a0321144254ad563165 (diff) | |
Refactor LoopOptimizer:
* Put loop-invariant node motion in its own class.
* Add granular control of which passes to run.
Swap the order of loop-invariant node motion (LINM) and stack push removal: LINM now runs first.
PiperOrigin-RevId: 191953537
3 files changed, 193 insertions, 144 deletions
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index a063dc3381..28ce2c7a55 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -16,18 +16,17 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/loop_optimizer.h" #include <algorithm> +#include <deque> #include <limits> #include <unordered_map> #include <unordered_set> #include <vector> -#include <deque> #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" @@ -46,74 +45,36 @@ namespace tensorflow { namespace grappler { namespace { -std::vector<int> GetStackPushNodesToConvert( - const SimpleGraphView& graph_view, - const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) { - VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name(); - const std::unordered_set<string> op_types_to_traverse( - {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch", - "Identity", "RefIdentity"}); - std::vector<int> nodes_to_convert; - std::set<int> fanout; - graph_view.DepthFirstSearch(op_types_to_traverse, stack_node_idx, &fanout); - for (int fanout_idx : fanout) { - const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx); - VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name(); - if (IsStackPushOp(fanout_node)) { - nodes_to_convert.push_back(fanout_idx); - } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) || - op_types_to_traverse.find(fanout_node.op()) != - 
op_types_to_traverse.end()) { - continue; - } else if (!IsStackPopOp(fanout_node) || - (!graph_view.outputs(fanout_idx).empty() || - nodes_to_preserve.find(fanout_node.name()) != - nodes_to_preserve.end())) { - // The node is either a stack pop with consumers or something unexpected - // so we leave the graph alone. - nodes_to_convert.clear(); - break; - } - } - return nodes_to_convert; -} +class LoopInvariantNodeMotionOptimizer { + public: + explicit LoopInvariantNodeMotionOptimizer(GraphDef* optimized_graph) + : optimized_graph_(optimized_graph) {} + virtual ~LoopInvariantNodeMotionOptimizer() = default; + Status Optimize(); -Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) { - const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve(); - const GraphDef& graph = item.graph; - *optimized_graph = graph; - NodeMap node_map(optimized_graph); - SimpleGraphView graph_view; - TF_RETURN_IF_ERROR(graph_view.Initialize(graph)); - for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) { - if (IsStackOp(graph.node(node_idx))) { - for (int push_node_idx : GetStackPushNodesToConvert( - graph_view, nodes_to_preserve, node_idx)) { - // We found push nodes without corresponding pops. Convert them to - // Identity passing the data through and add a control dependency from - // the op supplying the stack handle. 
- NodeDef* push_node = optimized_graph->mutable_node(push_node_idx); - VLOG(1) << "Converting " << push_node_idx << " : " - << push_node->DebugString(); - if (push_node->attr().count("swap_memory") != 0) { - push_node->mutable_attr()->erase("swap_memory"); - } - push_node->set_op("Identity"); - push_node->mutable_input()->SwapElements(0, 1); - const string ctrl_dep = ConstantFolding::AddControlDependency( - push_node->input(1), optimized_graph, &node_map); - push_node->set_input(1, ctrl_dep); - VLOG(1) << "After converting: " << push_node->DebugString(); - } - } - } - return Status::OK(); -} + private: + Status FindInvariantNodes(NodeDef* node); + Status RevertInvariantNodes(); + Status MoveInvariantNodes(const int frame_id); + Status HandleInvariantNode(NodeDef* node, const int num_outputs, + const int frame_id); + Status HandleConst(NodeDef* node, const int num_outputs, const int frame_id); + Status HandleInvariantEnter(NodeDef* node, const int num_outputs); -} // namespace + GraphDef* optimized_graph_; // Not owned. + std::unique_ptr<NodeMap> node_map_; + std::map<NodeDef*, int> invariant_nodes_; + std::set<int> empty_set_; + // TODO(rmlarsen): Use vector instead of map, since frames ids are dense. 
+ std::map<int, std::set<int>> frame_children_; + std::map<int, int> frame_parent_; + std::map<int, const NodeDef*> loop_cond_; + std::map<int, std::vector<NodeDef*>> invariant_enters_; + int new_enter_id_; +}; -Status LoopOptimizer::LINMHandleInvariantEnter(NodeDef* node, - const int num_outputs) { +Status LoopInvariantNodeMotionOptimizer::HandleInvariantEnter( + NodeDef* node, const int num_outputs) { auto consumers = node_map_->GetOutputs(node->name()); std::vector<string> enter_control_inputs; string enter_input; @@ -142,8 +103,9 @@ Status LoopOptimizer::LINMHandleInvariantEnter(NodeDef* node, return Status::OK(); } -Status LoopOptimizer::LINMHandleConst(NodeDef* node, - const int num_outputs, const int frame_id) { +Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node, + const int num_outputs, + const int frame_id) { NodeDef* const_node; if (num_outputs == 0) { // all successor nodes are invariant @@ -185,8 +147,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node, int parent_id = parent_it->second; auto loop_cond_it = loop_cond_.find(parent_id); if (loop_cond_it == loop_cond_.end()) { - return errors::InvalidArgument( - "Frame ", frame_id, " doesn't have a LoopCond node"); + return errors::InvalidArgument("Frame ", frame_id, + " doesn't have a LoopCond node"); } auto& loop_cond_name = loop_cond_it->second->name(); NodeDef* switch_node = nullptr; @@ -197,9 +159,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node, } } if (!switch_node) { - return errors::InvalidArgument( - "LoopCond node of Frame ", frame_id, - " doesn't connect to any Switch node"); + return errors::InvalidArgument("LoopCond node of Frame ", frame_id, + " doesn't connect to any Switch node"); } string switch_output = StrCat(switch_node->name(), ":1"); const string ctrl_dep = ConstantFolding::AddControlDependency( @@ -210,8 +171,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node, return Status::OK(); } -Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node, - 
const int num_outputs, const int frame_id) { +Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode( + NodeDef* node, const int num_outputs, const int frame_id) { // have to remove control inputs to the invariant node from the same frame // when moving this node out of this frame for (int i = 0; i < node->input_size(); ++i) { @@ -228,16 +189,14 @@ Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node, DataTypeVector output_types; OpRegistryInterface* op_registry = OpRegistry::Global(); const OpRegistrationData* op_reg_data = nullptr; - TF_RETURN_IF_ERROR( - op_registry->LookUp(node->op(), &op_reg_data)); - TF_RETURN_IF_ERROR( - InOutTypesForNode(*node, op_reg_data->op_def, - &input_types, &output_types)); + TF_RETURN_IF_ERROR(op_registry->LookUp(node->op(), &op_reg_data)); + TF_RETURN_IF_ERROR(InOutTypesForNode(*node, op_reg_data->op_def, &input_types, + &output_types)); auto consumers = node_map_->GetOutputs(node->name()); string fname = invariant_enters_[frame_id][0]->attr().at("frame_name").s(); - int piterations = invariant_enters_[frame_id][0] - ->attr().at("parallel_iterations").i(); + int piterations = + invariant_enters_[frame_id][0]->attr().at("parallel_iterations").i(); for (auto* consumer : consumers) { if (!invariant_nodes_.count(consumer)) { for (int i = 0; i < consumer->input_size(); ++i) { @@ -281,28 +240,27 @@ Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node, return Status::OK(); } -Status LoopOptimizer::MoveInvariantNodes(const int frame_id) { - for (auto iter = invariant_nodes_.begin(); - iter != invariant_nodes_.end(); ++iter) { +Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes( + const int frame_id) { + for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end(); + ++iter) { auto* invariant_node = iter->first; const int num_outputs = iter->second; if (IsEnter(*invariant_node)) { - TF_RETURN_IF_ERROR( - LINMHandleInvariantEnter(invariant_node, num_outputs)); + 
TF_RETURN_IF_ERROR(HandleInvariantEnter(invariant_node, num_outputs)); } else if (IsConstant(*invariant_node)) { - TF_RETURN_IF_ERROR( - LINMHandleConst(invariant_node, num_outputs, frame_id)); + TF_RETURN_IF_ERROR(HandleConst(invariant_node, num_outputs, frame_id)); } else { TF_RETURN_IF_ERROR( - LINMHandleInvariantNode(invariant_node, num_outputs, frame_id)); + HandleInvariantNode(invariant_node, num_outputs, frame_id)); } } return Status::OK(); } -Status LoopOptimizer::RevertInvariantNodes() { +Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() { std::deque<const NodeDef*> reverted_nodes; - for (auto iter=invariant_nodes_.begin(); iter != invariant_nodes_.end();) { + for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();) { bool erased = false; const auto* node = iter->first; if (!IsConstant(*node) && !IsEnter(*node) && iter->second > 0) { @@ -331,8 +289,8 @@ Status LoopOptimizer::RevertInvariantNodes() { auto* producer = node_map_->GetNode(input); auto iter = invariant_nodes_.find(producer); if (iter != invariant_nodes_.end()) { - if (IsControlInput(input) && - !IsConstant(*producer) && !IsEnter(*producer)) { + if (IsControlInput(input) && !IsConstant(*producer) && + !IsEnter(*producer)) { reverted_nodes.push_back(producer); invariant_nodes_.erase(iter); } else { @@ -357,12 +315,11 @@ Status LoopOptimizer::RevertInvariantNodes() { return Status::OK(); } -Status LoopOptimizer::FindInvariantNodes(NodeDef* node) { +Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(NodeDef* node) { auto consumers = node_map_->GetOutputs(node->name()); invariant_nodes_.insert(std::make_pair(node, consumers.size())); for (auto* consumer : consumers) { - if (invariant_nodes_.count(consumer) || - ModifiesFrameInfo(*consumer)) { + if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) { continue; } bool is_invariant = true; @@ -399,9 +356,14 @@ Status LoopOptimizer::FindInvariantNodes(NodeDef* node) { return Status::OK(); } 
-Status LoopOptimizer::LoopInvariantNodeMotion() { +Status LoopInvariantNodeMotionOptimizer::Optimize() { + node_map_.reset(new NodeMap(optimized_graph_)); + FrameMap frame_map; + int num_frames; + TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_, + &frame_map, &num_frames)); std::deque<int> worklist; - for (auto iter = frame_map_.begin(); iter != frame_map_.end(); ++iter) { + for (auto iter = frame_map.begin(); iter != frame_map.end(); ++iter) { auto* node = iter->first; auto& frame_ids = iter->second; if (frame_ids.size() >= 3) { @@ -467,19 +429,82 @@ Status LoopOptimizer::LoopInvariantNodeMotion() { return Status::OK(); } -Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) { +std::vector<int> GetStackPushNodesToConvert( + const SimpleGraphView& graph_view, + const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) { + VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name(); + const std::unordered_set<string> op_types_to_traverse( + {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch", + "Identity", "RefIdentity"}); + std::vector<int> nodes_to_convert; + std::set<int> fanout; + graph_view.DepthFirstSearch(op_types_to_traverse, stack_node_idx, &fanout); + for (int fanout_idx : fanout) { + const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx); + VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name(); + if (IsStackPushOp(fanout_node)) { + nodes_to_convert.push_back(fanout_idx); + } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) || + op_types_to_traverse.find(fanout_node.op()) != + op_types_to_traverse.end()) { + continue; + } else if (!IsStackPopOp(fanout_node) || + (!graph_view.outputs(fanout_idx).empty() || + nodes_to_preserve.find(fanout_node.name()) != + nodes_to_preserve.end())) { + // The node is either a stack pop with consumers or something unexpected + // so we leave the graph alone. 
+ nodes_to_convert.clear(); + break; + } + } + return nodes_to_convert; +} + +Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) { + const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve(); + const GraphDef& graph = item.graph; + *optimized_graph = graph; + NodeMap node_map(optimized_graph); + SimpleGraphView graph_view; + TF_RETURN_IF_ERROR(graph_view.Initialize(graph)); + for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) { + if (IsStackOp(graph.node(node_idx))) { + for (int push_node_idx : GetStackPushNodesToConvert( + graph_view, nodes_to_preserve, node_idx)) { + // We found push nodes without corresponding pops. Convert them to + // Identity passing the data through and add a control dependency from + // the op supplying the stack handle. + NodeDef* push_node = optimized_graph->mutable_node(push_node_idx); + VLOG(1) << "Converting " << push_node_idx << " : " + << push_node->DebugString(); + if (push_node->attr().count("swap_memory") != 0) { + push_node->mutable_attr()->erase("swap_memory"); + } + push_node->set_op("Identity"); + push_node->mutable_input()->SwapElements(0, 1); + const string ctrl_dep = ConstantFolding::AddControlDependency( + push_node->input(1), optimized_graph, &node_map); + push_node->set_input(1, ctrl_dep); + VLOG(1) << "After converting: " << push_node->DebugString(); + } + } + } + return Status::OK(); +} - TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph)); +} // namespace - if (opt_level_ == RewriterConfig::AGGRESSIVE) { - optimized_graph_ = optimized_graph; - // Set up helper data structures. 
- node_map_.reset(new NodeMap(optimized_graph_)); - int num_frames; - TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_, - &frame_map_, &num_frames)); - TF_RETURN_IF_ERROR(LoopInvariantNodeMotion()); +Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + *optimized_graph = item.graph; + // Set up helper data structures. + if (options_.enable_loop_invariant_node_motion) { + LoopInvariantNodeMotionOptimizer linm_optimizer(optimized_graph); + TF_RETURN_IF_ERROR(linm_optimizer.Optimize()); + } + if (options_.enable_stack_push_removal) { + TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph)); } return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h index c1b0321e4e..83c499bbe7 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.h +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -30,9 +30,13 @@ constexpr char kLoopOptimizer[] = "LoopOptimizer"; class LoopOptimizer : public GraphOptimizer { public: - LoopOptimizer() : opt_level_(RewriterConfig::ON) {} + LoopOptimizer() + : opt_level_(RewriterConfig::ON), + options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {} explicit LoopOptimizer(RewriterConfig::Toggle opt_level) - : opt_level_(opt_level) {} + : opt_level_(opt_level), + options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {} + ~LoopOptimizer() override {} string name() const override { return "loop_optimizer"; }; @@ -44,29 +48,24 @@ class LoopOptimizer : public GraphOptimizer { const GraphDef& optimized_graph, double result) override; private: - Status LoopInvariantNodeMotion(); - Status FindInvariantNodes(NodeDef* node); - Status RevertInvariantNodes(); - Status MoveInvariantNodes(const int frame_id); - Status LINMHandleInvariantNode(NodeDef* node, const int num_outputs, - const int frame_id); - Status LINMHandleConst(NodeDef* node, const int num_outputs, - 
const int frame_id); - Status LINMHandleInvariantEnter(NodeDef* node, const int num_outputs); - - std::map<NodeDef*, int> invariant_nodes_; - std::set<int> empty_set_; - std::map<int, std::set<int>> frame_children_; - std::map<int, int> frame_parent_; - std::map<int, const NodeDef*> loop_cond_; - std::map<int, std::vector<NodeDef*>> invariant_enters_; - int new_enter_id_; - RewriterConfig::Toggle opt_level_; + friend class LoopOptimizerTest; + + // Granular control for loop optimizer stages. + struct LoopOptimizerOptions { + bool enable_loop_invariant_node_motion = false; + bool enable_stack_push_removal = true; + + static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) { + LoopOptimizerOptions options; + if (opt_level == RewriterConfig::AGGRESSIVE) { + options.enable_loop_invariant_node_motion = true; + } + return options; + } + }; - std::unique_ptr<NodeMap> node_map_; - FrameMap frame_map_; - std::unique_ptr<GraphProperties> graph_properties_; - GraphDef* optimized_graph_; // Not owned. + RewriterConfig::Toggle opt_level_; + LoopOptimizerOptions options_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc index a0bd335197..10ec544424 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
namespace tensorflow { namespace grappler { -namespace { class LoopOptimizerTest : public GrapplerTest { protected: @@ -57,6 +56,23 @@ class LoopOptimizerTest : public GrapplerTest { attributes.emplace_back("T", type); AddNode(name, op, inputs, attributes, graph); } + + void DisableAllStages(LoopOptimizer* optimizer) { + LoopOptimizer::LoopOptimizerOptions options; + options.enable_loop_invariant_node_motion = false; + options.enable_stack_push_removal = false; + optimizer->options_ = options; + } + + void EnableOnlyLoopInvariantNodeMotion(LoopOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.enable_loop_invariant_node_motion = true; + } + + void EnableOnlyStackPushRemoval(LoopOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.enable_stack_push_removal = true; + } }; TEST_F(LoopOptimizerTest, Basic) { @@ -81,7 +97,8 @@ TEST_F(LoopOptimizerTest, Basic) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); + LoopOptimizer optimizer; + EnableOnlyLoopInvariantNodeMotion(&optimizer); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -128,7 +145,8 @@ TEST_F(LoopOptimizerTest, Const) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); + LoopOptimizer optimizer; + EnableOnlyLoopInvariantNodeMotion(&optimizer); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -175,7 +193,8 @@ TEST_F(LoopOptimizerTest, ControlOutput) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); + LoopOptimizer optimizer; + EnableOnlyLoopInvariantNodeMotion(&optimizer); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -235,7 +254,8 @@ TEST_F(LoopOptimizerTest, NestedLoop1) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); 
+ LoopOptimizer optimizer; + EnableOnlyLoopInvariantNodeMotion(&optimizer); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -302,7 +322,8 @@ TEST_F(LoopOptimizerTest, NestedLoop2) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); + LoopOptimizer optimizer; + EnableOnlyLoopInvariantNodeMotion(&optimizer); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -365,7 +386,8 @@ TEST_F(LoopOptimizerTest, NestedLoopConst1) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); + LoopOptimizer optimizer; + EnableOnlyLoopInvariantNodeMotion(&optimizer); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -429,7 +451,8 @@ TEST_F(LoopOptimizerTest, NestedLoopConst2) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); + LoopOptimizer optimizer; + EnableOnlyLoopInvariantNodeMotion(&optimizer); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -475,6 +498,7 @@ TEST_F(LoopOptimizerTest, NoOp) { CHECK(fake_input.NextItem(&item)); LoopOptimizer optimizer; + EnableOnlyStackPushRemoval(&optimizer); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -504,6 +528,7 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) { AddSimpleNode("stop", "StopGradient", {"stack3"}, &graph); LoopOptimizer optimizer; + EnableOnlyStackPushRemoval(&optimizer); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -534,6 +559,7 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) { item.fetch.push_back("pop4"); LoopOptimizer optimizer; + EnableOnlyStackPushRemoval(&optimizer); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ 
-563,6 +589,5 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) { } } -} // namespace } // namespace grappler } // namespace tensorflow |