diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-02 16:04:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-02 16:52:33 -0700 |
commit | 30927ec6b625121bae1b89b07f9faeaebaed321f (patch) | |
tree | 0f54ab601134eb818ae72eb032286034245cb218 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | 49f2afe21e3cada8951205d00e877c873a33754c (diff) |
Mark all nodes processed by AddOpsRewrite/MinBCast stages with a tag.
PiperOrigin-RevId: 195167597
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 77 |
1 files changed, 44 insertions, 33 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index bf59b25449..d6510ba681 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" @@ -49,6 +50,12 @@ namespace tensorflow { namespace grappler { namespace { +// Mark nodes created or optimized by a stage with a tag. +constexpr char kAddOpsRewriteTag[] = + "_grappler:ArithmeticOptimizer:AddOpsRewriteStage"; +constexpr char kMinimizeBroadcastsTag[] = + "_grappler:ArithmeticOptimizer:MinimizeBroadcasts"; + // Extract values from a Const op to `values`. Returns true if succeeds. template <typename T> bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) { @@ -142,18 +149,6 @@ bool MaybeAddControlInput(const string& new_input, NodeDef* node, return !already_exists; } -int CopyControlInputs(const NodeDef& from, NodeDef* to, GraphDef* graph, - NodeMap* node_map) { - int num_copied = 0; - for (const string& input : from.input()) { - if (IsControlInput(input) && - MaybeAddControlInput(input, to, graph, node_map)) { - ++num_copied; - } - } - return num_copied; -} - void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) { (*node->mutable_attr())[attr_name].set_type(dtype); } @@ -326,7 +321,7 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { explicit ArithmeticNodesGroupOptimizerStage( const string& name, const GraphOptimizerContext& ctx, const ArithmeticOptimizerContext ctx_ext) - : ArithmeticOptimizerStage(name, ctx, ctx_ext), optimized_nodes_{} {} + : ArithmeticOptimizerStage(name, ctx, ctx_ext) {} ~ArithmeticNodesGroupOptimizerStage() override = default; // Input name with a statically inferred shape from GraphProperties @@ -465,13 +460,16 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { return signature; } - void AddToOptimizedNodes(const NodeDef* node) { - optimized_nodes_.insert(node->name()); + void MarkWithTag(const StringPiece tag, NodeDef* node) { + AddNodeAttr(tag, true, node); } - void AddAllMembersToOptimizedNodes(const OptimizedNodesGroup& group) { - AddToOptimizedNodes(group.root_node); - for (const NodeDef* opt : group.optimized_nodes) AddToOptimizedNodes(opt); + void MarkAllMembersWithTag(const OptimizedNodesGroup& group, + const StringPiece tag) const { + AddNodeAttr(tag, true, group.root_node); + for (NodeDef* optimized_node : group.optimized_nodes) { + AddNodeAttr(tag, true, optimized_node); + } } bool IsOnTheSameDevice(const OptimizedNodesGroup& group, @@ -479,13 +477,19 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { return group.root_node->device() == node.device(); } - bool IsAlreadyOptimized(const NodeDef& node) const { - return optimized_nodes_.find(node.name()) != optimized_nodes_.end(); + bool IsInPreserveSet(const NodeDef& node) const { + return ctx().nodes_to_preserve->find(node.name()) != + ctx().nodes_to_preserve->end(); } - private: - // set of nodes already processed by this optimizer stage - std::unordered_set<string> optimized_nodes_; + bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const { + return HasNodeAttr(node, tag); + } + + bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1, + const StringPiece tag2) const { + return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2); + } }; // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the @@ -561,7 +565,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { if (!IsAdd(node) && !IsAddN(node)) { return false; } - if (IsInPreserveSet(node) || IsAlreadyOptimized(node)) { + if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) { return false; } // TODO(ezhulenev): relax this condition for root node @@ -579,7 +583,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { << " num_inputs=" << group.inputs.size(); // Do not optimize any of the nodes that are part of this group. - AddAllMembersToOptimizedNodes(group); + MarkAllMembersWithTag(group, kAddOpsRewriteTag); // All new nodes will be placed under the scope of a root node. auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name()); @@ -688,7 +692,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { node->add_input(inputAndShape.input); } - AddToOptimizedNodes(node); + MarkWithTag(kAddOpsRewriteTag, node); return InputAndShape(node_name, shape); } @@ -705,14 +709,13 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { node->set_op("Add"); node->set_device(root_node.device()); (*node->mutable_attr())["T"].set_type(dtype); + node->add_input(left.input); + node->add_input(right.input); ctx().node_map->AddOutput(left.input, node_name); ctx().node_map->AddOutput(right.input, node_name); - node->add_input(left.input); - node->add_input(right.input); - - AddToOptimizedNodes(node); + MarkWithTag(kAddOpsRewriteTag, node); return InputAndShape( node_name, TensorShapeProto()); // shape is not important at this point } @@ -960,7 +963,9 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { bool IsSupported(const NodeDef* node) const override { if (!IsBinaryAssociative(*node)) return false; - if (IsAlreadyOptimized(*node)) return false; + + if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) + return false; // has a symbolically defined shape with broadcastable inputs OpInfo::TensorProperties properties; @@ -984,7 +989,11 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { if (!IsSameOp(group, node)) { return false; } - if (IsInPreserveSet(node) || IsAlreadyOptimized(node)) { + if (IsInPreserveSet(node)) { + return false; + } + // Nodes optimized by AddOpsRewrite already have optimal broadcasts. + if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) { return false; } if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) { @@ -1019,7 +1028,7 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { << " num_optimized_nodes=" << group.optimized_nodes.size(); // Do not optimize any of the nodes that are part of this group. - AddAllMembersToOptimizedNodes(group); + MarkAllMembersWithTag(group, kMinimizeBroadcastsTag); if (CountUniqueShapes(group.inputs) <= 1) { VLOG(3) << "Skip min-bcast group with single unique shape"; @@ -1905,6 +1914,8 @@ void ArithmeticOptimizer::DedupComputations() { FeedsInPlaceOp(graph_view, *node)) { continue; } + VLOG(3) << "Remove duplicated node: node=" << node->name() + << " representative=" << rep->name(); const std::set<NodeDef*>& fanouts = node_map_->GetOutputs(node->name()); for (NodeDef* fanout : fanouts) { for (int i = 0; i < fanout->input_size(); ++i) { |