Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc  108
1 file changed, 68 insertions(+), 40 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index adfae2e1a3..adef75f63e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
@@ -38,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"
@@ -254,6 +256,17 @@ NodeDef* GetTailOfValuePreservingChain(
is_value_preserving_non_branching);
}
+NodeDef* GetTailOfIdempotentChain(
+ const NodeDef& node, const NodeMap& node_map,
+ const std::unordered_set<string>& nodes_to_preserve) {
+ auto is_idempotent_non_branching = [&](const NodeDef& node) {
+ return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
+ IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1;
+ };
+ return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
+ is_idempotent_non_branching);
+}
+
// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
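The new GetTailOfIdempotentChain helper lets the transpose rewrite below look through a run of idempotent, single-consumer, non-preserved ops sitting between two transposes. A minimal standalone sketch of the chain walk it delegates to (simplified stand-in types, not tensorflow::NodeDef, NodeMap, or grappler's actual GetTailOfChain):

// Rough sketch only: follow input(0) upward while `pred` holds for the
// upstream node, and return the last node reached. The real predicate
// additionally requires IsIdempotent, a single non-control output, and
// that the node is not in nodes_to_preserve.
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

struct FakeNode {
  std::string name;
  std::vector<std::string> inputs;  // inputs[0] is the data input we follow
};

const FakeNode* TailOfChainSketch(
    const FakeNode& start,
    const std::unordered_map<std::string, const FakeNode*>& by_name,
    const std::function<bool(const FakeNode&)>& pred) {
  const FakeNode* tail = &start;
  while (!tail->inputs.empty()) {
    auto it = by_name.find(tail->inputs[0]);
    if (it == by_name.end() || !pred(*it->second)) break;
    tail = it->second;
  }
  return tail;
}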
@@ -270,7 +283,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
const ArithmeticOptimizerContext ctx_ext)
: GraphOptimizerStage("ArithmeticOptimizer", name, ctx),
ctx_ext_(ctx_ext) {}
- virtual ~ArithmeticOptimizerStage() = default;
+ ~ArithmeticOptimizerStage() override = default;
protected:
// Simplification graph rewrite can create additional nodes that are inputs
@@ -1149,21 +1162,27 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
public:
explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
- const ArithmeticOptimizerContext& ctx_ext)
- : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
+ const ArithmeticOptimizerContext& ctx_ext,
+ RewriterConfig::Toggle opt_level)
+ : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext),
+ opt_level_(opt_level) {}
~RemoveIdentityTranspose() override = default;
bool IsSupported(const NodeDef* node) const override {
return IsTranspose(*node) || IsConjugateTranspose(*node);
}
- // TODO(rmlarsen): Forward control dependencies on the bypassed
- // transpose nodes.
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
+ NodeDef* tail = node;
+ // TODO(rmlarsen): Enable in regular mode after May 15, 2018.
+ if (opt_level_ == RewriterConfig::AGGRESSIVE) {
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
+ }
+ NodeDef* first_transpose;
+ TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
- NodeDef* input;
- TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
NodeDef* node_perm;
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
if (!IsConstant(*node_perm)) {
@@ -1171,17 +1190,30 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
}
std::vector<int64> node_perm_values;
TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
- if (input->op() == node->op()) {
+ if (first_transpose->op() == node->op()) {
// Remove pairs of transposes that cancel each other.
- NodeDef* input_perm;
- TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm));
- if (!IsConstant(*input_perm)) {
+ NodeDef* first_transpose_perm;
+ TF_RETURN_IF_ERROR(
+ GetInputNode(first_transpose->input(1), &first_transpose_perm));
+ if (!IsConstant(*first_transpose_perm)) {
return Status::OK();
}
- std::vector<int64> input_perm_values;
- TF_RETURN_IF_ERROR(GetPermutation(*input_perm, &input_perm_values));
- if (AreInversePermutations(node_perm_values, input_perm_values)) {
- *simplified_node_name = input->input(0);
+ std::vector<int64> first_transpose_perm_values;
+ TF_RETURN_IF_ERROR(
+ GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
+ if (AreInversePermutations(node_perm_values,
+ first_transpose_perm_values)) {
+ if (tail == node) {
+ // Bypass adjacent pair.
+ *simplified_node_name = first_transpose->input(0);
+ } else {
+ // Bypass pair connected through chain.
+ tail->set_input(0, first_transpose->input(0));
+ ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
+ first_transpose->input(0));
+ ForwardControlDependencies(tail, {first_transpose});
+ *simplified_node_name = node->input(0);
+ }
}
} else {
// Remove simple identity transposes.
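The pair-cancellation path above (both the adjacent case and the chain case) hinges on AreInversePermutations applied to the two constant perm inputs. As a reminder of what that predicate means, here is a minimal sketch assuming both vectors are valid permutations of 0..n-1 (not the AreInversePermutations defined elsewhere in this file):

#include <cstdint>
#include <vector>

// Sketch: b maps i -> b[i]; a cancels b exactly when it maps b[i] back to i.
bool AreInversePermutationsSketch(const std::vector<int64_t>& a,
                                  const std::vector<int64_t>& b) {
  if (a.size() != b.size()) return false;
  for (size_t i = 0; i < b.size(); ++i) {
    if (a[b[i]] != static_cast<int64_t>(i)) return false;
  }
  return true;
}

For example, perm [1, 2, 0] followed by [2, 0, 1] composes to the identity, so Transpose(Transpose(x, [1, 2, 0]), [2, 0, 1]) can be bypassed.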
@@ -1231,6 +1263,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
}
return true;
}
+
+ RewriterConfig::Toggle opt_level_;
};
// Remove redundant Bitcasts.
@@ -1752,7 +1786,7 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
class UniqueNodes {
public:
NodeDef* FindOrAddRepresentative(NodeDef* node) {
- std::size_t sig = ComputeSignature(*node);
+ uint64 sig = ComputeSignature(*node);
std::vector<NodeDef*>& candidates = rep_[sig];
for (auto& candidate : candidates) {
if (SameNode(*candidate, *node)) {
@@ -1764,26 +1798,25 @@ class UniqueNodes {
}
private:
- std::size_t ComputeSignature(const NodeDef& node) const;
+ uint64 ComputeSignature(const NodeDef& node) const;
bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
- std::unordered_map<std::size_t, std::vector<NodeDef*>> rep_;
+ std::unordered_map<uint64, std::vector<NodeDef*>> rep_;
};
-std::size_t UniqueNodes::ComputeSignature(const NodeDef& node) const {
- std::size_t h = std::hash<string>{}(node.op());
- h ^= std::hash<string>{}(node.device());
+uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const {
+ uint64 h = Hash64(node.op());
+ h = Hash64Combine(Hash64(node.device()), h);
+
for (const auto& input : node.input()) {
int pos;
string node_name = ParseNodeName(input, &pos);
- h ^= std::hash<string>{}(node_name);
- h ^= static_cast<std::size_t>(pos);
+ h = Hash64CombineUnordered(Hash64(node_name), h);
+ h = Hash64CombineUnordered(std::hash<int>()(pos), h);
}
for (const auto& attr : node.attr()) {
- h ^= std::hash<string>{}(attr.first);
- string tmp;
- attr.second.AppendToString(&tmp);
- h ^= std::hash<string>{}(tmp);
+ h = Hash64CombineUnordered(Hash64(attr.first), h);
+ h = Hash64CombineUnordered(FastAttrValueHash(attr.second), h);
}
return h;
}
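ComputeSignature now mixes fields with Hash64Combine / Hash64CombineUnordered from tensorflow/core/lib/hash/hash.h instead of XOR-folding std::hash values. Two properties are worth noting: with plain XOR, identical contributions cancel (x ^ x == 0), and an unordered combiner keeps the signature independent of the iteration order of unordered containers such as NodeDef.attr(). A toy illustration with standalone stand-ins (not the real hash functions):

#include <cstdint>

// Order-sensitive mix, in the spirit of Hash64Combine (used above to fold
// the device hash into the op hash): swapping the arguments generally
// changes the result.
uint64_t CombineOrderedSketch(uint64_t a, uint64_t b) {
  return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
}

// Commutative and associative, in the spirit of Hash64CombineUnordered
// (used above for inputs and attributes): the result does not depend on
// visitation order. Exact equality is still decided by SameNode.
uint64_t CombineUnorderedSketch(uint64_t a, uint64_t b) {
  return a + b;
}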
@@ -1839,17 +1872,8 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
}
for (const auto& attr1 : node1.attr()) {
auto it = node2.attr().find(attr1.first);
- if (it == node2.attr().end()) {
- return false;
- }
- const auto& attr2 = *it;
- string val1;
- attr1.second.AppendToString(&val1);
- string val2;
- attr2.second.AppendToString(&val2);
- if (val1 != val2) {
- return false;
- }
+ if (it == node2.attr().end()) return false;
+ if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
}
return true;
@@ -2233,6 +2257,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
new_square_node->set_input(i - 1, new_square_node->input(i));
}
new_square_node->mutable_input()->RemoveLast();
+ for (const string& input : new_square_node->input()) {
+ node_map_->AddOutput(NodeName(input), new_square_node->name());
+ }
return new_square_node->name();
}
}
@@ -2398,7 +2425,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.minimize_broadcasts && can_use_shapes)
pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
if (options_.remove_identity_transpose && can_use_shapes)
- pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
+ pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext, opt_level_);
if (options_.remove_redundant_bitcast)
pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
if (options_.remove_redundant_cast)
@@ -2491,7 +2518,8 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
graph_properties_.reset(new GraphProperties(optimized_item));
- const Status status = graph_properties_->InferStatically(false);
+ const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
+ const Status status = graph_properties_->InferStatically(assume_valid_feeds);
const bool can_use_shapes = status.ok();
if (!can_use_shapes) {
VLOG(1) << "Shape inference failed." << status.error_message();