aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/loop_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/loop_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.cc381
1 files changed, 379 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index 131466430e..244653504d 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -15,19 +15,31 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
+#include <algorithm>
+#include <limits>
#include <unordered_map>
#include <unordered_set>
+#include <vector>
+#include <deque>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/tensor_coding.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+
+using tensorflow::strings::StrCat;
namespace tensorflow {
namespace grappler {
@@ -94,10 +106,375 @@ Status RemoveStackOps(const GraphDef& graph, GraphDef* optimized_graph) {
} // namespace
+Status LoopOptimizer::LINMHandleInvariantEnter(NodeDef* node,
+ const int num_outputs) {
+ auto consumers = node_map_->GetOutputs(node->name());
+ std::vector<string> enter_control_inputs;
+ string enter_input;
+ for (auto& input : node->input()) {
+ if (IsControlInput(input)) {
+ enter_control_inputs.push_back(input);
+ } else {
+ enter_input = input;
+ }
+ }
+ for (auto* consumer : consumers) {
+ if (invariant_nodes_.count(consumer)) {
+ for (int i = 0; i < consumer->input_size(); ++i) {
+ if (NodeName(consumer->input(i)) == node->name()) {
+ consumer->set_input(i, enter_input);
+ node_map_->AddOutput(NodeName(enter_input), consumer->name());
+ node_map_->RemoveOutput(node->name(), consumer->name());
+ }
+ }
+ for (auto& control_input : enter_control_inputs) {
+ consumer->add_input(control_input);
+ node_map_->AddOutput(NodeName(control_input), consumer->name());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status LoopOptimizer::LINMHandleConst(NodeDef* node,
+ const int num_outputs, const int frame_id) {
+ NodeDef* const_node;
+ if (num_outputs == 0) {
+ // all successor nodes are invariant
+ // Remove the control inputs from this frame to the const node,
+ // when moving it out of the frame (in parent frame)
+ const_node = node;
+ node_map_->RemoveInputs(node->name());
+ node->clear_input();
+ } else {
+ // some successor nodes are variant
+ // Have to keep the const node in the frame,
+ // so create a new one outside the frame (in parent frame)
+ const_node = optimized_graph_->add_node();
+ const_node->set_name(AddPrefixToNodeName(node->name(), kLoopOptimizer));
+ const_node->set_op("Const");
+ const_node->set_device(node->device());
+ *const_node->mutable_attr() = node->attr();
+ node_map_->AddNode(const_node->name(), const_node);
+ auto consumers = node_map_->GetOutputs(node->name());
+ for (auto* consumer : consumers) {
+ if (invariant_nodes_.count(consumer)) {
+ for (int i = 0; i < consumer->input_size(); ++i) {
+ if (NodeName(consumer->input(i)) == node->name()) {
+ if (IsControlInput(consumer->input(i))) {
+ *consumer->mutable_input(i) = AsControlDependency(*const_node);
+ } else {
+ *consumer->mutable_input(i) = const_node->name();
+ }
+ node_map_->AddOutput(const_node->name(), consumer->name());
+ node_map_->RemoveOutput(node->name(), consumer->name());
+ }
+ }
+ }
+ }
+ }
+ // add a control input from the parent frame
+ auto parent_it = frame_parent_.find(frame_id);
+ if (parent_it != frame_parent_.end()) {
+ int parent_id = parent_it->second;
+ auto loop_cond_it = loop_cond_.find(parent_id);
+ if (loop_cond_it == loop_cond_.end()) {
+ return errors::InvalidArgument(
+ "Frame ", frame_id, " doesn't have a LoopCond node");
+ }
+ auto& loop_cond_name = loop_cond_it->second->name();
+ NodeDef* switch_node = nullptr;
+ for (auto* node : node_map_->GetOutputs(loop_cond_name)) {
+ if (node->op() == "Switch") {
+ switch_node = node;
+ break;
+ }
+ }
+ if (!switch_node) {
+ return errors::InvalidArgument(
+ "LoopCond node of Frame ", frame_id,
+ " doesn't connect to any Switch node");
+ }
+ string switch_output = StrCat(switch_node->name(), ":1");
+ const string ctrl_dep = ConstantFolding::AddControlDependency(
+ switch_output, optimized_graph_, node_map_.get());
+ const_node->add_input(ctrl_dep);
+ node_map_->AddOutput(NodeName(ctrl_dep), const_node->name());
+ }
+ return Status::OK();
+}
+
+Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
+ const int num_outputs, const int frame_id) {
+ // have to remove control inputs to the invariant node from the same frame
+ // when moving this node out of this frame
+ for (int i = 0; i < node->input_size(); ++i) {
+ if (IsControlInput(node->input(i))) {
+ node->mutable_input()->SwapElements(i, node->input_size() - 1);
+ node->mutable_input()->RemoveLast();
+ }
+ }
+ if (num_outputs == 0) {
+ return Status::OK();
+ }
+
+ DataTypeVector input_types;
+ DataTypeVector output_types;
+ OpRegistryInterface* op_registry = OpRegistry::Global();
+ const OpRegistrationData* op_reg_data = nullptr;
+ TF_RETURN_IF_ERROR(
+ op_registry->LookUp(node->op(), &op_reg_data));
+ TF_RETURN_IF_ERROR(
+ InOutTypesForNode(*node, op_reg_data->op_def,
+ &input_types, &output_types));
+
+ auto consumers = node_map_->GetOutputs(node->name());
+ string fname = invariant_enters_[frame_id][0]->attr().at("frame_name").s();
+ int piterations = invariant_enters_[frame_id][0]
+ ->attr().at("parallel_iterations").i();
+ for (auto* consumer : consumers) {
+ if (!invariant_nodes_.count(consumer)) {
+ for (int i = 0; i < consumer->input_size(); ++i) {
+ int port;
+ string node_name = ParseNodeName(consumer->input(i), &port);
+ if (node_name != node->name()) {
+ continue;
+ }
+ if (port < 0) {
+ return errors::InvalidArgument(
+ "Invariant node should not have control outputs "
+ "to variant node");
+ }
+ DataType output_type = output_types[port];
+ NodeDef* new_enter = optimized_graph_->add_node();
+ new_enter->set_op("Enter");
+ new_enter->set_device(node->device());
+ new_enter->set_name(AddPrefixToNodeName(
+ StrCat(fname, "_enter_", new_enter_id_++), kLoopOptimizer));
+ AttrValue data_type;
+ data_type.set_type(output_type);
+ new_enter->mutable_attr()->insert({"T", data_type});
+ AttrValue frame_name;
+ frame_name.set_s(fname);
+ new_enter->mutable_attr()->insert({"frame_name", frame_name});
+ AttrValue is_const;
+ is_const.set_b(true);
+ new_enter->mutable_attr()->insert({"is_constant", is_const});
+ AttrValue parallel_iterations;
+ parallel_iterations.set_i(piterations);
+ new_enter->mutable_attr()->insert(
+ {"parallel_iterations", parallel_iterations});
+ new_enter->add_input(consumer->input(i));
+ *consumer->mutable_input(i) = new_enter->name();
+ node_map_->AddNode(new_enter->name(), new_enter);
+ node_map_->AddOutput(node->name(), new_enter->name());
+ node_map_->AddOutput(new_enter->name(), consumer->name());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status LoopOptimizer::MoveInvariantNodes(const int frame_id) {
+ for (auto iter = invariant_nodes_.begin();
+ iter != invariant_nodes_.end(); ++iter) {
+ auto* invariant_node = iter->first;
+ const int num_outputs = iter->second;
+ if (IsEnter(*invariant_node)) {
+ TF_RETURN_IF_ERROR(
+ LINMHandleInvariantEnter(invariant_node, num_outputs));
+ } else if (IsConstant(*invariant_node)) {
+ TF_RETURN_IF_ERROR(
+ LINMHandleConst(invariant_node, num_outputs, frame_id));
+ } else {
+ TF_RETURN_IF_ERROR(
+ LINMHandleInvariantNode(invariant_node, num_outputs, frame_id));
+ }
+ }
+ return Status::OK();
+}
+
+Status LoopOptimizer::RevertInvariantNodes() {
+ std::deque<const NodeDef*> reverted_nodes;
+ for (auto iter=invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
+ bool erased = false;
+ const auto* node = iter->first;
+ if (!IsConstant(*node) && !IsEnter(*node) && iter->second > 0) {
+ auto& consumers = node_map_->GetOutputs(node->name());
+ for (auto* consumer : consumers) {
+ if (!invariant_nodes_.count(consumer)) {
+ for (const auto& input : consumer->input()) {
+ if (IsControlInput(input) && NodeName(input) == node->name()) {
+ reverted_nodes.push_back(node);
+ invariant_nodes_.erase(iter++);
+ erased = true;
+ break;
+ }
+ }
+ if (erased) break;
+ }
+ }
+ }
+ if (!erased) ++iter;
+ }
+ while (!reverted_nodes.empty()) {
+ const auto* node = reverted_nodes.front();
+ reverted_nodes.pop_front();
+ std::set<NodeDef*> producers;
+ for (const auto& input : node->input()) {
+ auto* producer = node_map_->GetNode(input);
+ auto iter = invariant_nodes_.find(producer);
+ if (iter != invariant_nodes_.end()) {
+ if (IsControlInput(input) &&
+ !IsConstant(*producer) && !IsEnter(*producer)) {
+ reverted_nodes.push_back(producer);
+ invariant_nodes_.erase(iter);
+ } else {
+ producers.insert(producer);
+ }
+ }
+ }
+ for (auto* producer : producers) {
+ auto iter = invariant_nodes_.find(producer);
+ if (iter != invariant_nodes_.end()) {
+ ++iter->second;
+ }
+ }
+ for (auto* consumer : node_map_->GetOutputs(node->name())) {
+ auto iter = invariant_nodes_.find(consumer);
+ if (iter != invariant_nodes_.end()) {
+ reverted_nodes.push_back(consumer);
+ invariant_nodes_.erase(iter);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status LoopOptimizer::FindInvariantNodes(NodeDef* node) {
+ auto consumers = node_map_->GetOutputs(node->name());
+ invariant_nodes_.insert(std::make_pair(node, consumers.size()));
+ for (auto* consumer : consumers) {
+ if (invariant_nodes_.count(consumer) ||
+ ModifiesFrameInfo(*consumer)) {
+ continue;
+ }
+ bool is_invariant = true;
+ for (const auto& input : consumer->input()) {
+ if (!IsControlInput(input)) {
+ const auto& name = NodeName(input);
+ auto* producer = node_map_->GetNode(name);
+ if (!invariant_nodes_.count(producer)) {
+ if (IsConstant(*producer)) {
+ invariant_nodes_.insert(
+ std::make_pair(producer, node_map_->GetOutputs(name).size()));
+ } else {
+ is_invariant = false;
+ break;
+ }
+ }
+ }
+ }
+ if (is_invariant) {
+ std::set<NodeDef*> producers;
+ for (const auto& input : consumer->input()) {
+ auto* producer = node_map_->GetNode(input);
+ producers.insert(producer);
+ }
+ for (auto* producer : producers) {
+ auto iter = invariant_nodes_.find(producer);
+ if (iter != invariant_nodes_.end()) {
+ --iter->second;
+ }
+ }
+ TF_RETURN_IF_ERROR(FindInvariantNodes(consumer));
+ }
+ }
+ return Status::OK();
+}
+
+Status LoopOptimizer::LoopInvariantNodeMotion() {
+ std::deque<int> worklist;
+ for (auto iter = frame_map_.begin(); iter != frame_map_.end(); ++iter) {
+ auto* node = iter->first;
+ auto& frame_ids = iter->second;
+ if (frame_ids.size() >= 3) {
+ for (unsigned int i = 1; i < frame_ids.size() - 1; ++i) {
+ frame_parent_[frame_ids[i]] = frame_ids[i - 1];
+ frame_children_[frame_ids[i]].insert(frame_ids[i + 1]);
+ }
+ }
+ if (frame_ids.size() >= 2) {
+ frame_children_[frame_ids[0]].insert(frame_ids[1]);
+ frame_parent_[frame_ids.back()] = frame_ids[frame_ids.size() - 2];
+ }
+ if (!frame_ids.empty()) {
+ frame_children_.insert(std::make_pair(frame_ids.back(), empty_set_));
+ if (node->op() == "LoopCond") {
+ if (loop_cond_.count(frame_ids.back())) {
+ return errors::InvalidArgument(
+ "Loop ", frame_ids.back(),
+ " has more than one LoopCond node: ", node->name(), " and ",
+ loop_cond_[frame_ids.back()]->name());
+ }
+ loop_cond_[frame_ids.back()] = node;
+ }
+ if (IsEnter(*node) && node->attr().at("is_constant").b()) {
+ invariant_enters_[frame_ids.back()].push_back(
+ const_cast<NodeDef*>(node));
+ }
+ }
+ }
+
+ for (auto it = frame_children_.begin(); it != frame_children_.end(); ++it) {
+ if (it->second.empty()) {
+ worklist.push_back(it->first);
+ }
+ }
+
+ while (!worklist.empty()) {
+ int frame_id = worklist.front();
+ new_enter_id_ = 0;
+ worklist.pop_front();
+ auto parent_it = frame_parent_.find(frame_id);
+ if (parent_it != frame_parent_.end()) {
+ int parent_id = parent_it->second;
+ frame_children_[parent_id].erase(frame_id);
+ if (frame_children_[parent_id].empty()) {
+ worklist.push_back(parent_id);
+ }
+ }
+
+ if (invariant_enters_[frame_id].empty()) {
+ continue;
+ }
+ invariant_nodes_.clear();
+ for (auto* enter : invariant_enters_[frame_id]) {
+ TF_RETURN_IF_ERROR(FindInvariantNodes(enter));
+ }
+
+ // revert invariant nodes that have control outputs to variant nodes
+ TF_RETURN_IF_ERROR(RevertInvariantNodes());
+
+ TF_RETURN_IF_ERROR(MoveInvariantNodes(frame_id));
+ }
+ return Status::OK();
+}
+
Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
- Status status = RemoveStackOps(item.graph, optimized_graph);
- return status;
+ TF_RETURN_IF_ERROR(RemoveStackOps(item.graph, optimized_graph));
+
+ optimized_graph_ = optimized_graph;
+
+ // Set up helper data structures.
+ node_map_.reset(new NodeMap(optimized_graph_));
+ int num_frames;
+ TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
+ &frame_map_, &num_frames));
+
+ TF_RETURN_IF_ERROR(LoopInvariantNodeMotion());
+ return Status::OK();
}
void LoopOptimizer::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,