diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 108 |
1 files changed, 68 insertions, 40 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index adfae2e1a3..adef75f63e 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -23,6 +23,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" @@ -38,6 +39,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/tensor_coding.h" @@ -254,6 +256,17 @@ NodeDef* GetTailOfValuePreservingChain( is_value_preserving_non_branching); } +NodeDef* GetTailOfIdempotentChain( + const NodeDef& node, const NodeMap& node_map, + const std::unordered_set<string>& nodes_to_preserve) { + auto is_idempotent_non_branching = [&](const NodeDef& node) { + return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() && + IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1; + }; + return GetTailOfChain(node, node_map, /*follow_control_input=*/false, + is_idempotent_non_branching); +} + // Graph optimizer context extension specific to ArithmeticOptimizer. struct ArithmeticOptimizerContext { explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify) @@ -270,7 +283,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { const ArithmeticOptimizerContext ctx_ext) : GraphOptimizerStage("ArithmeticOptimizer", name, ctx), ctx_ext_(ctx_ext) {} - virtual ~ArithmeticOptimizerStage() = default; + ~ArithmeticOptimizerStage() override = default; protected: // Simplification graph rewrite can create additional nodes that are inputs @@ -1149,21 +1162,27 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { class RemoveIdentityTranspose : public ArithmeticOptimizerStage { public: explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx, - const ArithmeticOptimizerContext& ctx_ext) - : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {} + const ArithmeticOptimizerContext& ctx_ext, + RewriterConfig::Toggle opt_level) + : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext), + opt_level_(opt_level) {} ~RemoveIdentityTranspose() override = default; bool IsSupported(const NodeDef* node) const override { return IsTranspose(*node) || IsConjugateTranspose(*node); } - // TODO(rmlarsen): Forward control dependencies on the bypassed - // transpose nodes. Status TrySimplify(NodeDef* node, string* simplified_node_name) override { TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); + NodeDef* tail = node; + // TODO(rmlarsen): Enable in regular mode after May 15, 2018. + if (opt_level_ == RewriterConfig::AGGRESSIVE) { + tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, + *ctx().nodes_to_preserve); + } + NodeDef* first_transpose; + TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose)); - NodeDef* input; - TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); NodeDef* node_perm; TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm)); if (!IsConstant(*node_perm)) { @@ -1171,17 +1190,30 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { } std::vector<int64> node_perm_values; TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values)); - if (input->op() == node->op()) { + if (first_transpose->op() == node->op()) { // Remove pairs of transposes that cancel each other. - NodeDef* input_perm; - TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm)); - if (!IsConstant(*input_perm)) { + NodeDef* first_transpose_perm; + TF_RETURN_IF_ERROR( + GetInputNode(first_transpose->input(1), &first_transpose_perm)); + if (!IsConstant(*first_transpose_perm)) { return Status::OK(); } - std::vector<int64> input_perm_values; - TF_RETURN_IF_ERROR(GetPermutation(*input_perm, &input_perm_values)); - if (AreInversePermutations(node_perm_values, input_perm_values)) { - *simplified_node_name = input->input(0); + std::vector<int64> first_transpose_perm_values; + TF_RETURN_IF_ERROR( + GetPermutation(*first_transpose_perm, &first_transpose_perm_values)); + if (AreInversePermutations(node_perm_values, + first_transpose_perm_values)) { + if (tail == node) { + // Bypass adjacent pair. + *simplified_node_name = first_transpose->input(0); + } else { + // Bypass pair connected through chain. + tail->set_input(0, first_transpose->input(0)); + ctx().node_map->UpdateInput(tail->name(), first_transpose->name(), + first_transpose->input(0)); + ForwardControlDependencies(tail, {first_transpose}); + *simplified_node_name = node->input(0); + } } } else { // Remove simple identity transposes. @@ -1231,6 +1263,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { } return true; } + + RewriterConfig::Toggle opt_level_; }; // Remove redundant Bitcasts. @@ -1752,7 +1786,7 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage { class UniqueNodes { public: NodeDef* FindOrAddRepresentative(NodeDef* node) { - std::size_t sig = ComputeSignature(*node); + uint64 sig = ComputeSignature(*node); std::vector<NodeDef*>& candidates = rep_[sig]; for (auto& candidate : candidates) { if (SameNode(*candidate, *node)) { @@ -1764,26 +1798,25 @@ class UniqueNodes { } private: - std::size_t ComputeSignature(const NodeDef& node) const; + uint64 ComputeSignature(const NodeDef& node) const; bool SameNode(const NodeDef& node1, const NodeDef& node2) const; - std::unordered_map<std::size_t, std::vector<NodeDef*>> rep_; + std::unordered_map<uint64, std::vector<NodeDef*>> rep_; }; -std::size_t UniqueNodes::ComputeSignature(const NodeDef& node) const { - std::size_t h = std::hash<string>{}(node.op()); - h ^= std::hash<string>{}(node.device()); +uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const { + uint64 h = Hash64(node.op()); + h = Hash64Combine(Hash64(node.device()), h); + for (const auto& input : node.input()) { int pos; string node_name = ParseNodeName(input, &pos); - h ^= std::hash<string>{}(node_name); - h ^= static_cast<std::size_t>(pos); + h = Hash64CombineUnordered(Hash64(node_name), h); + h = Hash64CombineUnordered(std::hash<int>()(pos), h); } for (const auto& attr : node.attr()) { - h ^= std::hash<string>{}(attr.first); - string tmp; - attr.second.AppendToString(&tmp); - h ^= std::hash<string>{}(tmp); + h = Hash64CombineUnordered(Hash64(attr.first), h); + h = Hash64CombineUnordered(FastAttrValueHash(attr.second), h); } return h; } @@ -1839,17 +1872,8 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const { } for (const auto& attr1 : node1.attr()) { auto it = node2.attr().find(attr1.first); - if (it == node2.attr().end()) { - return false; - } - const auto& attr2 = *it; - string val1; - attr1.second.AppendToString(&val1); - string val2; - attr2.second.AppendToString(&val2); - if (val1 != val2) { - return false; - } + if (it == node2.attr().end()) return false; + if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false; } return true; @@ -2233,6 +2257,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( new_square_node->set_input(i - 1, new_square_node->input(i)); } new_square_node->mutable_input()->RemoveLast(); + for (const string& input : new_square_node->input()) { + node_map_->AddOutput(NodeName(input), new_square_node->name()); + } return new_square_node->name(); } } @@ -2398,7 +2425,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { if (options_.minimize_broadcasts && can_use_shapes) pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext); if (options_.remove_identity_transpose && can_use_shapes) - pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext); + pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext, opt_level_); if (options_.remove_redundant_bitcast) pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext); if (options_.remove_redundant_cast) @@ -2491,7 +2518,8 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_)); graph_properties_.reset(new GraphProperties(optimized_item)); - const Status status = graph_properties_->InferStatically(false); + const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE; + const Status status = graph_properties_->InferStatically(assume_valid_feeds); const bool can_use_shapes = status.ok(); if (!can_use_shapes) { VLOG(1) << "Shape inference failed." << status.error_message(); |