aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc34
1 files changed, 16 insertions, 18 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index f2277a9b79..38af7170b5 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -185,6 +185,10 @@ bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
return false;
}
+bool SimplyReordersData(const NodeDef& node) {
+ return node.op() == "Transpose";
+}
+
// Follow a chain (through input(0)) of ops starting at `source->input(0)` as
// long as they
// 1. preserve the values of their first input,
@@ -703,6 +707,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
node_map->AddOutput(new_transpose->name(), new_cast->name());
new_nodes->push_back(new_transpose);
+ new_nodes->push_back(new_cast);
// Add frame dependencies that the original node might have had.
AddFrameControlDeps(node, {new_transpose, new_cast},
new_transpose->input(0), {new_transpose},
@@ -832,7 +837,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
}
- if (node->input_size() > 0 && IsAggregate(*node)) {
+ if (node->input_size() > 0 && IsAggregate(*node) &&
+ !node_map->GetOutputs(node->name()).empty()) {
// Discard aggregate nodes with a single input.
if (node->input_size() == 1) {
return node->input(0);
@@ -853,7 +859,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
break;
}
}
- if (all_equal && node_map->GetNode(node->name() + "_const") == nullptr) {
+ if (all_equal) {
// 1. Create constant node with value N.
const int N = node->input_size();
const auto type = GetDataTypeFromAttr(*node, "T");
@@ -879,6 +885,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
new_mul_node->set_device(node->device());
SetDataTypeToAttr(type, "T", new_mul_node);
node_map->AddNode(new_mul_node->name(), new_mul_node);
+ new_nodes->push_back(new_mul_node);
new_mul_node->add_input(new_const_node->name());
node_map->AddOutput(new_const_node->name(), new_mul_node->name());
new_mul_node->add_input(node->input(0));
@@ -895,7 +902,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// where all the inputs are Mul nodes. This pattern occurs frequently in
// regularization terms for the gradients during training.
if (node->input_size() > 1 && IsAggregate(*node) &&
- node_map->GetNode(node->name() + "_hoist") == nullptr) {
+ !node_map->GetOutputs(node->name()).empty()) {
// Determine the set of common factors if the input nodes are all Mul nodes.
std::set<string> common_factors;
int i = 0;
@@ -943,6 +950,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
new_mul_node->set_name(new_mul_node->name() + "_hoist");
new_mul_node->set_input(0, common_factor);
new_mul_node->set_input(1, new_add_node->name());
+ new_nodes->push_back(new_mul_node);
node_map->AddNode(new_mul_node->name(), new_mul_node);
}
}
@@ -1007,9 +1015,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
// Fold Conj into Transpose or ConjugateTranspose.
- if ((node->op() == "Conj" || node->op() == "Transpose" ||
- node->op() == "ConjugateTranspose") &&
- node_map->GetNode(node->name() + "_fused") == nullptr) {
+ if (node->op() == "Conj" || node->op() == "Transpose" ||
+ node->op() == "ConjugateTranspose") {
const NodeDef* input = node_map->GetNode(node->input(0));
const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
const NodeDef* conj_op = node->op() == "Conj" ? node : input;
@@ -1042,14 +1049,10 @@ namespace {
template <class T>
class SetVector {
public:
- // Returns false if value already existed in the set, true otherwise.
- bool PushBack(const T& value) {
- if (!set_.insert(value).second) {
- VLOG(2) << "Value " << value << " is already in the set.";
- return false;
- }
+ void PushBack(const T& value) {
+ CHECK(!Exists(value)) << "Value " << value << " is already in the set.";
+ set_.insert(value);
vector_.push_back(value);
- return true;
}
T PopBack() {
@@ -1090,11 +1093,6 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(
}
if (NodeName(simplified_tensor) != node->name()) {
- // Always consider simplified_tensor for further optimizations.
- const NodeDef* simplified_node = node_map.GetNode(simplified_tensor);
- if (simplified_node != nullptr) {
- nodes_to_simplify.PushBack(simplified_node);
- }
// When `node` is simplifed to another node rather than in-place, the
// consumers of `node` are already redirected to `simplified_tensor`.
// Re-push the consumers into `nodes_to_simplify` for further