diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 34 |
1 files changed, 16 insertions, 18 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index f2277a9b79..38af7170b5 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -185,6 +185,10 @@ bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node, return false; } +bool SimplyReordersData(const NodeDef& node) { + return node.op() == "Transpose"; +} + // Follow a chain (through input(0)) of ops starting at `source->input(0)` as // long as they // 1. preserve the values of their first input, @@ -703,6 +707,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( node_map->AddOutput(new_transpose->name(), new_cast->name()); new_nodes->push_back(new_transpose); + new_nodes->push_back(new_cast); // Add frame dependencies that the original node might have had. AddFrameControlDeps(node, {new_transpose, new_cast}, new_transpose->input(0), {new_transpose}, @@ -832,7 +837,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } } - if (node->input_size() > 0 && IsAggregate(*node)) { + if (node->input_size() > 0 && IsAggregate(*node) && + !node_map->GetOutputs(node->name()).empty()) { // Discard aggregate nodes with a single input. if (node->input_size() == 1) { return node->input(0); @@ -853,7 +859,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( break; } } - if (all_equal && node_map->GetNode(node->name() + "_const") == nullptr) { + if (all_equal) { // 1. Create constant node with value N. const int N = node->input_size(); const auto type = GetDataTypeFromAttr(*node, "T"); @@ -879,6 +885,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( new_mul_node->set_device(node->device()); SetDataTypeToAttr(type, "T", new_mul_node); node_map->AddNode(new_mul_node->name(), new_mul_node); + new_nodes->push_back(new_mul_node); new_mul_node->add_input(new_const_node->name()); node_map->AddOutput(new_const_node->name(), new_mul_node->name()); new_mul_node->add_input(node->input(0)); @@ -895,7 +902,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // where all the inputs are Mul nodes. This pattern occurs frequently in // regularization terms for the gradients during training. if (node->input_size() > 1 && IsAggregate(*node) && - node_map->GetNode(node->name() + "_hoist") == nullptr) { + !node_map->GetOutputs(node->name()).empty()) { // Determine the set of common factors if the input nodes are all Mul nodes. std::set<string> common_factors; int i = 0; @@ -943,6 +950,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( new_mul_node->set_name(new_mul_node->name() + "_hoist"); new_mul_node->set_input(0, common_factor); new_mul_node->set_input(1, new_add_node->name()); + new_nodes->push_back(new_mul_node); node_map->AddNode(new_mul_node->name(), new_mul_node); } } @@ -1007,9 +1015,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } // Fold Conj into Transpose or ConjugateTranspose. - if ((node->op() == "Conj" || node->op() == "Transpose" || - node->op() == "ConjugateTranspose") && - node_map->GetNode(node->name() + "_fused") == nullptr) { + if (node->op() == "Conj" || node->op() == "Transpose" || + node->op() == "ConjugateTranspose") { const NodeDef* input = node_map->GetNode(node->input(0)); const NodeDef* transpose_op = node->op() == "Conj" ? input : node; const NodeDef* conj_op = node->op() == "Conj" ? node : input; @@ -1042,14 +1049,10 @@ namespace { template <class T> class SetVector { public: - // Returns false if value already existed in the set, true otherwise. - bool PushBack(const T& value) { - if (!set_.insert(value).second) { - VLOG(2) << "Value " << value << " is already in the set."; - return false; - } + void PushBack(const T& value) { + CHECK(!Exists(value)) << "Value " << value << " is already in the set."; + set_.insert(value); vector_.push_back(value); - return true; } T PopBack() { @@ -1090,11 +1093,6 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps( } if (NodeName(simplified_tensor) != node->name()) { - // Always consider simplified_tensor for further optimizations. - const NodeDef* simplified_node = node_map.GetNode(simplified_tensor); - if (simplified_node != nullptr) { - nodes_to_simplify.PushBack(simplified_node); - } // When `node` is simplifed to another node rather than in-place, the // consumers of `node` are already redirected to `simplified_tensor`. // Re-push the consumers into `nodes_to_simplify` for further |