aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-02 16:04:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 16:52:33 -0700
commit30927ec6b625121bae1b89b07f9faeaebaed321f (patch)
tree0f54ab601134eb818ae72eb032286034245cb218 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent49f2afe21e3cada8951205d00e877c873a33754c (diff)
Mark all nodes processed by AddOpsRewrite/MinBCast stages with a tag.
PiperOrigin-RevId: 195167597
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc77
1 files changed, 44 insertions, 33 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index bf59b25449..d6510ba681 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
@@ -49,6 +50,12 @@ namespace tensorflow {
namespace grappler {
namespace {
+// Mark nodes created or optimized by a stage with a tag.
+constexpr char kAddOpsRewriteTag[] =
+ "_grappler:ArithmeticOptimizer:AddOpsRewriteStage";
+constexpr char kMinimizeBroadcastsTag[] =
+ "_grappler:ArithmeticOptimizer:MinimizeBroadcasts";
+
// Extract values from a Const op to `values`. Returns true if succeeds.
template <typename T>
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
@@ -142,18 +149,6 @@ bool MaybeAddControlInput(const string& new_input, NodeDef* node,
return !already_exists;
}
-int CopyControlInputs(const NodeDef& from, NodeDef* to, GraphDef* graph,
- NodeMap* node_map) {
- int num_copied = 0;
- for (const string& input : from.input()) {
- if (IsControlInput(input) &&
- MaybeAddControlInput(input, to, graph, node_map)) {
- ++num_copied;
- }
- }
- return num_copied;
-}
-
void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
(*node->mutable_attr())[attr_name].set_type(dtype);
}
@@ -326,7 +321,7 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
explicit ArithmeticNodesGroupOptimizerStage(
const string& name, const GraphOptimizerContext& ctx,
const ArithmeticOptimizerContext ctx_ext)
- : ArithmeticOptimizerStage(name, ctx, ctx_ext), optimized_nodes_{} {}
+ : ArithmeticOptimizerStage(name, ctx, ctx_ext) {}
~ArithmeticNodesGroupOptimizerStage() override = default;
// Input name with a statically inferred shape from GraphProperties
@@ -465,13 +460,16 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
return signature;
}
- void AddToOptimizedNodes(const NodeDef* node) {
- optimized_nodes_.insert(node->name());
+ void MarkWithTag(const StringPiece tag, NodeDef* node) {
+ AddNodeAttr(tag, true, node);
}
- void AddAllMembersToOptimizedNodes(const OptimizedNodesGroup& group) {
- AddToOptimizedNodes(group.root_node);
- for (const NodeDef* opt : group.optimized_nodes) AddToOptimizedNodes(opt);
+ void MarkAllMembersWithTag(const OptimizedNodesGroup& group,
+ const StringPiece tag) const {
+ AddNodeAttr(tag, true, group.root_node);
+ for (NodeDef* optimized_node : group.optimized_nodes) {
+ AddNodeAttr(tag, true, optimized_node);
+ }
}
bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
@@ -479,13 +477,19 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
return group.root_node->device() == node.device();
}
- bool IsAlreadyOptimized(const NodeDef& node) const {
- return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
+ bool IsInPreserveSet(const NodeDef& node) const {
+ return ctx().nodes_to_preserve->find(node.name()) !=
+ ctx().nodes_to_preserve->end();
}
- private:
- // set of nodes already processed by this optimizer stage
- std::unordered_set<string> optimized_nodes_;
+ bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const {
+ return HasNodeAttr(node, tag);
+ }
+
+ bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1,
+ const StringPiece tag2) const {
+ return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2);
+ }
};
// Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
@@ -561,7 +565,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
if (!IsAdd(node) && !IsAddN(node)) {
return false;
}
- if (IsInPreserveSet(node) || IsAlreadyOptimized(node)) {
+ if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) {
return false;
}
// TODO(ezhulenev): relax this condition for root node
@@ -579,7 +583,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
<< " num_inputs=" << group.inputs.size();
// Do not optimize any of the nodes that are part of this group.
- AddAllMembersToOptimizedNodes(group);
+ MarkAllMembersWithTag(group, kAddOpsRewriteTag);
// All new nodes will be placed under the scope of a root node.
auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());
@@ -688,7 +692,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
node->add_input(inputAndShape.input);
}
- AddToOptimizedNodes(node);
+ MarkWithTag(kAddOpsRewriteTag, node);
return InputAndShape(node_name, shape);
}
@@ -705,14 +709,13 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
node->set_op("Add");
node->set_device(root_node.device());
(*node->mutable_attr())["T"].set_type(dtype);
+ node->add_input(left.input);
+ node->add_input(right.input);
ctx().node_map->AddOutput(left.input, node_name);
ctx().node_map->AddOutput(right.input, node_name);
- node->add_input(left.input);
- node->add_input(right.input);
-
- AddToOptimizedNodes(node);
+ MarkWithTag(kAddOpsRewriteTag, node);
return InputAndShape(
node_name, TensorShapeProto()); // shape is not important at this point
}
@@ -960,7 +963,9 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
bool IsSupported(const NodeDef* node) const override {
if (!IsBinaryAssociative(*node)) return false;
- if (IsAlreadyOptimized(*node)) return false;
+
+ if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag))
+ return false;
// has a symbolically defined shape with broadcastable inputs
OpInfo::TensorProperties properties;
@@ -984,7 +989,11 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
if (!IsSameOp(group, node)) {
return false;
}
- if (IsInPreserveSet(node) || IsAlreadyOptimized(node)) {
+ if (IsInPreserveSet(node)) {
+ return false;
+ }
+ // Nodes optimized by AddOpsRewrite already have optimal broadcasts.
+ if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) {
return false;
}
if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) {
@@ -1019,7 +1028,7 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
<< " num_optimized_nodes=" << group.optimized_nodes.size();
// Do not optimize any of the nodes that are part of this group.
- AddAllMembersToOptimizedNodes(group);
+ MarkAllMembersWithTag(group, kMinimizeBroadcastsTag);
if (CountUniqueShapes(group.inputs) <= 1) {
VLOG(3) << "Skip min-bcast group with single unique shape";
@@ -1905,6 +1914,8 @@ void ArithmeticOptimizer::DedupComputations() {
FeedsInPlaceOp(graph_view, *node)) {
continue;
}
+ VLOG(3) << "Remove duplicated node: node=" << node->name()
+ << " representative=" << rep->name();
const std::set<NodeDef*>& fanouts = node_map_->GetOutputs(node->name());
for (NodeDef* fanout : fanouts) {
for (int i = 0; i < fanout->input_size(); ++i) {