aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-06 16:00:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 16:03:12 -0700
commitd8ec179569514c068284c84540826d077a30485d (patch)
treebbf637c102a9cc9c98fd334857879cf42e984798
parentd017e6f030c06d4803897a0321144254ad563165 (diff)
Refactor LoopOptimizer:
* Put loop-invariant node motion in its own class. * Add granular control of which passes to run. Swap order of LINM and stack push removal. PiperOrigin-RevId: 191953537
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.cc247
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.h47
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer_test.cc43
3 files changed, 193 insertions, 144 deletions
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index a063dc3381..28ce2c7a55 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -16,18 +16,17 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include <algorithm>
+#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
-#include <deque>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
@@ -46,74 +45,36 @@ namespace tensorflow {
namespace grappler {
namespace {
-std::vector<int> GetStackPushNodesToConvert(
- const SimpleGraphView& graph_view,
- const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
- VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();
- const std::unordered_set<string> op_types_to_traverse(
- {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
- "Identity", "RefIdentity"});
- std::vector<int> nodes_to_convert;
- std::set<int> fanout;
- graph_view.DepthFirstSearch(op_types_to_traverse, stack_node_idx, &fanout);
- for (int fanout_idx : fanout) {
- const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
- VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
- if (IsStackPushOp(fanout_node)) {
- nodes_to_convert.push_back(fanout_idx);
- } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
- op_types_to_traverse.find(fanout_node.op()) !=
- op_types_to_traverse.end()) {
- continue;
- } else if (!IsStackPopOp(fanout_node) ||
- (!graph_view.outputs(fanout_idx).empty() ||
- nodes_to_preserve.find(fanout_node.name()) !=
- nodes_to_preserve.end())) {
- // The node is either a stack pop with consumers or something unexpected
- // so we leave the graph alone.
- nodes_to_convert.clear();
- break;
- }
- }
- return nodes_to_convert;
-}
+// Runs the loop-invariant node motion pass over a single GraphDef.
+// The pass state that used to live on LoopOptimizer is held here so that it is
+// scoped to one Optimize() run on one graph.
+class LoopInvariantNodeMotionOptimizer {
+ public:
+ // `optimized_graph` is stored as a raw pointer and is not owned; it must
+ // outlive this object. Optimize() builds a NodeMap over it.
+ explicit LoopInvariantNodeMotionOptimizer(GraphDef* optimized_graph)
+ : optimized_graph_(optimized_graph) {}
+ virtual ~LoopInvariantNodeMotionOptimizer() = default;
+ // Entry point: builds the NodeMap and frame map for the graph, then processes
+ // the frames. Returns the first error reported by any helper.
+ Status Optimize();
-Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) {
- const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
- const GraphDef& graph = item.graph;
- *optimized_graph = graph;
- NodeMap node_map(optimized_graph);
- SimpleGraphView graph_view;
- TF_RETURN_IF_ERROR(graph_view.Initialize(graph));
- for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) {
- if (IsStackOp(graph.node(node_idx))) {
- for (int push_node_idx : GetStackPushNodesToConvert(
- graph_view, nodes_to_preserve, node_idx)) {
- // We found push nodes without corresponding pops. Convert them to
- // Identity passing the data through and add a control dependency from
- // the op supplying the stack handle.
- NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
- VLOG(1) << "Converting " << push_node_idx << " : "
- << push_node->DebugString();
- if (push_node->attr().count("swap_memory") != 0) {
- push_node->mutable_attr()->erase("swap_memory");
- }
- push_node->set_op("Identity");
- push_node->mutable_input()->SwapElements(0, 1);
- const string ctrl_dep = ConstantFolding::AddControlDependency(
- push_node->input(1), optimized_graph, &node_map);
- push_node->set_input(1, ctrl_dep);
- VLOG(1) << "After converting: " << push_node->DebugString();
- }
- }
- }
- return Status::OK();
-}
+ private:
+ // Records `node` in invariant_nodes_ together with its consumer count, then
+ // examines its consumers.
+ Status FindInvariantNodes(NodeDef* node);
+ // Removes entries from invariant_nodes_ that cannot be moved after all.
+ Status RevertInvariantNodes();
+ // Dispatches every entry of invariant_nodes_ to HandleInvariantEnter,
+ // HandleConst or HandleInvariantNode depending on the node kind.
+ Status MoveInvariantNodes(const int frame_id);
+ // Handles an invariant node that is neither an Enter nor a constant.
+ Status HandleInvariantNode(NodeDef* node, const int num_outputs,
+ const int frame_id);
+ // Handles an invariant constant node.
+ Status HandleConst(NodeDef* node, const int num_outputs, const int frame_id);
+ // Handles an invariant Enter node.
+ Status HandleInvariantEnter(NodeDef* node, const int num_outputs);
-} // namespace
+ GraphDef* optimized_graph_; // Not owned.
+ // Rebuilt from optimized_graph_ at the start of Optimize().
+ std::unique_ptr<NodeMap> node_map_;
+ // Candidate invariant nodes mapped to a counter that is seeded with the
+ // node's consumer count and passed to the handlers as `num_outputs`.
+ std::map<NodeDef*, int> invariant_nodes_;
+ // NOTE(review): usage is not visible in this excerpt.
+ std::set<int> empty_set_;
+ // TODO(rmlarsen): Use vector instead of map, since frames ids are dense.
+ // Presumably frame id -> ids of its child frames; confirm against Optimize().
+ std::map<int, std::set<int>> frame_children_;
+ // Presumably frame id -> id of its parent frame; confirm against Optimize().
+ std::map<int, int> frame_parent_;
+ // Frame id -> the LoopCond node of that frame.
+ std::map<int, const NodeDef*> loop_cond_;
+ // Frame id -> nodes whose "frame_name" and "parallel_iterations" attributes
+ // are read when moving nodes out of that frame (presumably its invariant
+ // Enter nodes).
+ std::map<int, std::vector<NodeDef*>> invariant_enters_;
+ // NOTE(review): not initialized by the constructor; confirm it is assigned
+ // before its first use.
+ int new_enter_id_;
+};
-Status LoopOptimizer::LINMHandleInvariantEnter(NodeDef* node,
- const int num_outputs) {
+Status LoopInvariantNodeMotionOptimizer::HandleInvariantEnter(
+ NodeDef* node, const int num_outputs) {
auto consumers = node_map_->GetOutputs(node->name());
std::vector<string> enter_control_inputs;
string enter_input;
@@ -142,8 +103,9 @@ Status LoopOptimizer::LINMHandleInvariantEnter(NodeDef* node,
return Status::OK();
}
-Status LoopOptimizer::LINMHandleConst(NodeDef* node,
- const int num_outputs, const int frame_id) {
+Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node,
+ const int num_outputs,
+ const int frame_id) {
NodeDef* const_node;
if (num_outputs == 0) {
// all successor nodes are invariant
@@ -185,8 +147,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node,
int parent_id = parent_it->second;
auto loop_cond_it = loop_cond_.find(parent_id);
if (loop_cond_it == loop_cond_.end()) {
- return errors::InvalidArgument(
- "Frame ", frame_id, " doesn't have a LoopCond node");
+ return errors::InvalidArgument("Frame ", frame_id,
+ " doesn't have a LoopCond node");
}
auto& loop_cond_name = loop_cond_it->second->name();
NodeDef* switch_node = nullptr;
@@ -197,9 +159,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node,
}
}
if (!switch_node) {
- return errors::InvalidArgument(
- "LoopCond node of Frame ", frame_id,
- " doesn't connect to any Switch node");
+ return errors::InvalidArgument("LoopCond node of Frame ", frame_id,
+ " doesn't connect to any Switch node");
}
string switch_output = StrCat(switch_node->name(), ":1");
const string ctrl_dep = ConstantFolding::AddControlDependency(
@@ -210,8 +171,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node,
return Status::OK();
}
-Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
- const int num_outputs, const int frame_id) {
+Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode(
+ NodeDef* node, const int num_outputs, const int frame_id) {
// have to remove control inputs to the invariant node from the same frame
// when moving this node out of this frame
for (int i = 0; i < node->input_size(); ++i) {
@@ -228,16 +189,14 @@ Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
DataTypeVector output_types;
OpRegistryInterface* op_registry = OpRegistry::Global();
const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(
- op_registry->LookUp(node->op(), &op_reg_data));
- TF_RETURN_IF_ERROR(
- InOutTypesForNode(*node, op_reg_data->op_def,
- &input_types, &output_types));
+ TF_RETURN_IF_ERROR(op_registry->LookUp(node->op(), &op_reg_data));
+ TF_RETURN_IF_ERROR(InOutTypesForNode(*node, op_reg_data->op_def, &input_types,
+ &output_types));
auto consumers = node_map_->GetOutputs(node->name());
string fname = invariant_enters_[frame_id][0]->attr().at("frame_name").s();
- int piterations = invariant_enters_[frame_id][0]
- ->attr().at("parallel_iterations").i();
+ int piterations =
+ invariant_enters_[frame_id][0]->attr().at("parallel_iterations").i();
for (auto* consumer : consumers) {
if (!invariant_nodes_.count(consumer)) {
for (int i = 0; i < consumer->input_size(); ++i) {
@@ -281,28 +240,27 @@ Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
return Status::OK();
}
-Status LoopOptimizer::MoveInvariantNodes(const int frame_id) {
- for (auto iter = invariant_nodes_.begin();
- iter != invariant_nodes_.end(); ++iter) {
+Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes(
+ const int frame_id) {
+ for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();
+ ++iter) {
auto* invariant_node = iter->first;
const int num_outputs = iter->second;
if (IsEnter(*invariant_node)) {
- TF_RETURN_IF_ERROR(
- LINMHandleInvariantEnter(invariant_node, num_outputs));
+ TF_RETURN_IF_ERROR(HandleInvariantEnter(invariant_node, num_outputs));
} else if (IsConstant(*invariant_node)) {
- TF_RETURN_IF_ERROR(
- LINMHandleConst(invariant_node, num_outputs, frame_id));
+ TF_RETURN_IF_ERROR(HandleConst(invariant_node, num_outputs, frame_id));
} else {
TF_RETURN_IF_ERROR(
- LINMHandleInvariantNode(invariant_node, num_outputs, frame_id));
+ HandleInvariantNode(invariant_node, num_outputs, frame_id));
}
}
return Status::OK();
}
-Status LoopOptimizer::RevertInvariantNodes() {
+Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() {
std::deque<const NodeDef*> reverted_nodes;
- for (auto iter=invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
+ for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
bool erased = false;
const auto* node = iter->first;
if (!IsConstant(*node) && !IsEnter(*node) && iter->second > 0) {
@@ -331,8 +289,8 @@ Status LoopOptimizer::RevertInvariantNodes() {
auto* producer = node_map_->GetNode(input);
auto iter = invariant_nodes_.find(producer);
if (iter != invariant_nodes_.end()) {
- if (IsControlInput(input) &&
- !IsConstant(*producer) && !IsEnter(*producer)) {
+ if (IsControlInput(input) && !IsConstant(*producer) &&
+ !IsEnter(*producer)) {
reverted_nodes.push_back(producer);
invariant_nodes_.erase(iter);
} else {
@@ -357,12 +315,11 @@ Status LoopOptimizer::RevertInvariantNodes() {
return Status::OK();
}
-Status LoopOptimizer::FindInvariantNodes(NodeDef* node) {
+Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(NodeDef* node) {
auto consumers = node_map_->GetOutputs(node->name());
invariant_nodes_.insert(std::make_pair(node, consumers.size()));
for (auto* consumer : consumers) {
- if (invariant_nodes_.count(consumer) ||
- ModifiesFrameInfo(*consumer)) {
+ if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) {
continue;
}
bool is_invariant = true;
@@ -399,9 +356,14 @@ Status LoopOptimizer::FindInvariantNodes(NodeDef* node) {
return Status::OK();
}
-Status LoopOptimizer::LoopInvariantNodeMotion() {
+Status LoopInvariantNodeMotionOptimizer::Optimize() {
+ node_map_.reset(new NodeMap(optimized_graph_));
+ FrameMap frame_map;
+ int num_frames;
+ TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
+ &frame_map, &num_frames));
std::deque<int> worklist;
- for (auto iter = frame_map_.begin(); iter != frame_map_.end(); ++iter) {
+ for (auto iter = frame_map.begin(); iter != frame_map.end(); ++iter) {
auto* node = iter->first;
auto& frame_ids = iter->second;
if (frame_ids.size() >= 3) {
@@ -467,19 +429,82 @@ Status LoopOptimizer::LoopInvariantNodeMotion() {
return Status::OK();
}
-Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* optimized_graph) {
+// Returns the indices of the StackPush nodes reachable from the stack node at
+// `stack_node_idx` that may be converted, or an empty vector when the fanout
+// of the stack contains anything that makes the conversion unsafe.
+//
+// graph_view: view used to traverse the fanout and to look up nodes by index.
+// nodes_to_preserve: names of nodes that must not be treated as removable.
+// stack_node_idx: index (within graph_view's graph) of the Stack node.
+std::vector<int> GetStackPushNodesToConvert(
+ const SimpleGraphView& graph_view,
+ const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
+ VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();
+ // Op types passed to the depth-first search and skipped in the scan below.
+ const std::unordered_set<string> op_types_to_traverse(
+ {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
+ "Identity", "RefIdentity"});
+ std::vector<int> nodes_to_convert;
+ std::set<int> fanout;
+ graph_view.DepthFirstSearch(op_types_to_traverse, stack_node_idx, &fanout);
+ for (int fanout_idx : fanout) {
+ const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
+ VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
+ if (IsStackPushOp(fanout_node)) {
+ nodes_to_convert.push_back(fanout_idx);
+ } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
+ op_types_to_traverse.find(fanout_node.op()) !=
+ op_types_to_traverse.end()) {
+ continue;
+ } else if (!IsStackPopOp(fanout_node) ||
+ (!graph_view.outputs(fanout_idx).empty() ||
+ nodes_to_preserve.find(fanout_node.name()) !=
+ nodes_to_preserve.end())) {
+ // The node is either a stack pop with consumers or something unexpected
+ // so we leave the graph alone.
+ // A stack pop that must be preserved also lands here. Clearing the result
+ // makes the caller convert nothing for this stack.
+ nodes_to_convert.clear();
+ break;
+ }
+ }
+ return nodes_to_convert;
+}
+
+// Converts StackPush nodes that have no matching, consumed StackPop into
+// Identity nodes, rewriting `optimized_graph` in place.
+//
+// item: supplies the names of the nodes that must be preserved. Its graph is
+//   only used to seed `optimized_graph` when the latter is still empty.
+// optimized_graph: the graph to rewrite. It may already contain the result of
+//   an earlier stage (e.g. loop-invariant node motion); that result is kept.
+//
+// Returns the error from SimpleGraphView::Initialize, if any, otherwise OK.
+Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) {
+  const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
+  // Do not overwrite a graph that an earlier stage has already rewritten:
+  // assigning item.graph unconditionally would silently discard that work.
+  // Only seed the output when the caller handed us an empty graph.
+  if (optimized_graph->node_size() == 0) {
+    *optimized_graph = item.graph;
+  }
+  // Analyse a frozen snapshot of the graph as it stands on entry, so that the
+  // view is not affected by the edits made below. Node indices in the snapshot
+  // are used to address `optimized_graph`; like the previous implementation,
+  // this assumes the rewrite never reorders or removes existing nodes.
+  const GraphDef snapshot = *optimized_graph;
+  NodeMap node_map(optimized_graph);
+  SimpleGraphView graph_view;
+  TF_RETURN_IF_ERROR(graph_view.Initialize(snapshot));
+  for (int node_idx = 0; node_idx < snapshot.node_size(); ++node_idx) {
+    if (!IsStackOp(snapshot.node(node_idx))) {
+      continue;
+    }
+    for (const int push_node_idx : GetStackPushNodesToConvert(
+             graph_view, nodes_to_preserve, node_idx)) {
+      // We found push nodes without corresponding pops. Convert them to
+      // Identity passing the data through and add a control dependency from
+      // the op supplying the stack handle.
+      NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
+      VLOG(1) << "Converting " << push_node_idx << " : "
+              << push_node->DebugString();
+      if (push_node->attr().count("swap_memory") != 0) {
+        push_node->mutable_attr()->erase("swap_memory");
+      }
+      push_node->set_op("Identity");
+      // Put the pushed value first (the Identity's data input) and turn the
+      // former stack-handle input into a control dependency.
+      push_node->mutable_input()->SwapElements(0, 1);
+      const string ctrl_dep = ConstantFolding::AddControlDependency(
+          push_node->input(1), optimized_graph, &node_map);
+      push_node->set_input(1, ctrl_dep);
+      VLOG(1) << "After converting: " << push_node->DebugString();
+    }
+  }
+  return Status::OK();
+}
- TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph));
+} // namespace
- if (opt_level_ == RewriterConfig::AGGRESSIVE) {
- optimized_graph_ = optimized_graph;
- // Set up helper data structures.
- node_map_.reset(new NodeMap(optimized_graph_));
- int num_frames;
- TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
- &frame_map_, &num_frames));
- TF_RETURN_IF_ERROR(LoopInvariantNodeMotion());
+Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+ // Set up helper data structures.
+ if (options_.enable_loop_invariant_node_motion) {
+ LoopInvariantNodeMotionOptimizer linm_optimizer(optimized_graph);
+ TF_RETURN_IF_ERROR(linm_optimizer.Optimize());
+ }
+ if (options_.enable_stack_push_removal) {
+ TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph));
}
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h
index c1b0321e4e..83c499bbe7 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h
@@ -30,9 +30,13 @@ constexpr char kLoopOptimizer[] = "LoopOptimizer";
class LoopOptimizer : public GraphOptimizer {
public:
- LoopOptimizer() : opt_level_(RewriterConfig::ON) {}
+ LoopOptimizer()
+ : opt_level_(RewriterConfig::ON),
+ options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {}
explicit LoopOptimizer(RewriterConfig::Toggle opt_level)
- : opt_level_(opt_level) {}
+ : opt_level_(opt_level),
+ options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {}
+
~LoopOptimizer() override {}
string name() const override { return "loop_optimizer"; };
@@ -44,29 +48,24 @@ class LoopOptimizer : public GraphOptimizer {
const GraphDef& optimized_graph, double result) override;
private:
- Status LoopInvariantNodeMotion();
- Status FindInvariantNodes(NodeDef* node);
- Status RevertInvariantNodes();
- Status MoveInvariantNodes(const int frame_id);
- Status LINMHandleInvariantNode(NodeDef* node, const int num_outputs,
- const int frame_id);
- Status LINMHandleConst(NodeDef* node, const int num_outputs,
- const int frame_id);
- Status LINMHandleInvariantEnter(NodeDef* node, const int num_outputs);
-
- std::map<NodeDef*, int> invariant_nodes_;
- std::set<int> empty_set_;
- std::map<int, std::set<int>> frame_children_;
- std::map<int, int> frame_parent_;
- std::map<int, const NodeDef*> loop_cond_;
- std::map<int, std::vector<NodeDef*>> invariant_enters_;
- int new_enter_id_;
- RewriterConfig::Toggle opt_level_;
+ friend class LoopOptimizerTest;
+
+ // Granular control for loop optimizer stages.
+ // Each flag turns one stage of LoopOptimizer::Optimize on or off.
+ struct LoopOptimizerOptions {
+ // Loop-invariant node motion stage; off unless explicitly enabled.
+ bool enable_loop_invariant_node_motion = false;
+ // Stack push removal stage; on by default.
+ bool enable_stack_push_removal = true;
+
+ // Returns the options for `opt_level`: the defaults above, with
+ // loop-invariant node motion additionally enabled for AGGRESSIVE.
+ static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) {
+ LoopOptimizerOptions options;
+ if (opt_level == RewriterConfig::AGGRESSIVE) {
+ options.enable_loop_invariant_node_motion = true;
+ }
+ return options;
+ }
+ };
- std::unique_ptr<NodeMap> node_map_;
- FrameMap frame_map_;
- std::unique_ptr<GraphProperties> graph_properties_;
- GraphDef* optimized_graph_; // Not owned.
+ RewriterConfig::Toggle opt_level_;
+ LoopOptimizerOptions options_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
index a0bd335197..10ec544424 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
@@ -25,7 +25,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace {
class LoopOptimizerTest : public GrapplerTest {
protected:
@@ -57,6 +56,23 @@ class LoopOptimizerTest : public GrapplerTest {
attributes.emplace_back("T", type);
AddNode(name, op, inputs, attributes, graph);
}
+
+ // Replaces the optimizer's options with a set in which every stage is off.
+ // Writes the private `options_` member of LoopOptimizer directly.
+ void DisableAllStages(LoopOptimizer* optimizer) {
+ LoopOptimizer::LoopOptimizerOptions options;
+ options.enable_loop_invariant_node_motion = false;
+ options.enable_stack_push_removal = false;
+ optimizer->options_ = options;
+ }
+
+ // Turns every stage off, then re-enables loop-invariant node motion only.
+ void EnableOnlyLoopInvariantNodeMotion(LoopOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.enable_loop_invariant_node_motion = true;
+ }
+
+ // Turns every stage off, then re-enables stack push removal only.
+ void EnableOnlyStackPushRemoval(LoopOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.enable_stack_push_removal = true;
+ }
};
TEST_F(LoopOptimizerTest, Basic) {
@@ -81,7 +97,8 @@ TEST_F(LoopOptimizerTest, Basic) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -128,7 +145,8 @@ TEST_F(LoopOptimizerTest, Const) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -175,7 +193,8 @@ TEST_F(LoopOptimizerTest, ControlOutput) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -235,7 +254,8 @@ TEST_F(LoopOptimizerTest, NestedLoop1) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -302,7 +322,8 @@ TEST_F(LoopOptimizerTest, NestedLoop2) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -365,7 +386,8 @@ TEST_F(LoopOptimizerTest, NestedLoopConst1) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -429,7 +451,8 @@ TEST_F(LoopOptimizerTest, NestedLoopConst2) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -475,6 +498,7 @@ TEST_F(LoopOptimizerTest, NoOp) {
CHECK(fake_input.NextItem(&item));
LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -504,6 +528,7 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
AddSimpleNode("stop", "StopGradient", {"stack3"}, &graph);
LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -534,6 +559,7 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
item.fetch.push_back("pop4");
LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -563,6 +589,5 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
}
}
-} // namespace
} // namespace grappler
} // namespace tensorflow