aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-09 10:13:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-09 10:15:46 -0700
commitaed12f35e29924e43f191d42fdcc6f9e025a3a3e (patch)
tree2d4f0769a4b102e2cdf98f750ab083b65729cd49
parent57b491744fa685cffc27b0dc73647fa2f05c9b68 (diff)
Minimize broadcasts by rewriting a sub-tree of binary associative ops (Add, Mul).
PiperOrigin-RevId: 192145052
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc561
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h5
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc161
-rw-r--r--tensorflow/core/grappler/optimizers/graph_optimizer_stage.h12
4 files changed, 568 insertions, 171 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index da8d677737..fa0f7c1c6e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -279,6 +279,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
ctx_ext_(ctx_ext) {}
virtual ~ArithmeticOptimizerStage() = default;
+ protected:
// Simplification graph rewrite can create additional nodes that are inputs
// to final simplified node, they can be also added to the arithmetic
// optimizer queue for further optimization.
@@ -304,10 +305,176 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
}
private:
- // extened context required for ArithmeticOptimizer
+ // Extended context required for ArithmeticOptimizer.
const ArithmeticOptimizerContext ctx_ext_;
};
+// Subtype of ArithmeticOptimizerStage that does optimization by rewriting a
+// group of nodes from the optimized graph.
+//
+// * AddOpsRewrite:
+// Rewrite a group of Add/AddN with compact Add/AddN tree
+//
+// * MinimizeBroadcasts:
+// Rewrite a group of binary associative ops, reordering
+// inputs, to minimize the cost of broadcast
+class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ArithmeticNodesGroupOptimizerStage(
+ const string& name, const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext ctx_ext)
+ : ArithmeticOptimizerStage(name, ctx, ctx_ext), optimized_nodes_{} {}
+ ~ArithmeticNodesGroupOptimizerStage() override = default;
+
+ // Input name with a statically inferred shape from GraphProperties
+ struct InputAndShape {
+ InputAndShape(const string& input, const TensorShapeProto& shape)
+ : input(input), shape(shape) {}
+ string input;
+ TensorShapeProto shape;
+ };
+
+  // Subgraph (subtree) of nodes that we want to optimize in "one shot" (e.g.
+ // all the Add nodes that we plan to rewrite with a single AddN). Subgraph is
+ // obtained by graph traversal, starting from a root node.
+ struct OptimizedNodesGroup {
+ NodeDef* root_node;
+ TensorShapeProto root_shape;
+ // Optimized nodes that will be updated or removed by rewrite
+ std::vector<NodeDef*> optimized_nodes;
+ // Inputs to optimized nodes
+ std::vector<InputAndShape> inputs;
+ };
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
+
+ OptimizedNodesGroup group;
+ TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group));
+
+ if (!group.optimized_nodes.empty()) {
+ *simplified_node_name = RewriteOptimizedNodesGroup(group);
+ }
+
+ return Status::OK();
+ }
+
+ protected:
+ // Modify the optimized graph after nodes group was successfully identified
+ virtual string RewriteOptimizedNodesGroup(
+ const OptimizedNodesGroup& group) = 0;
+
+ // Check if input can become a part of current optimized nodes group.
+ virtual bool IsAbsorbableByOptimizedNodesGroup(
+ const OptimizedNodesGroup& group, const string& input) const = 0;
+
+ Status AbsorbInputByOptimizedNodesGroup(const string& input,
+ OptimizedNodesGroup* group) const {
+ NodeDef* node;
+ TF_RETURN_IF_ERROR(GetInputNode(input, &node));
+
+ if (IsAbsorbableByOptimizedNodesGroup(*group, input)) {
+ for (int i = 0; i < node->input_size(); ++i) {
+ const string& input_i = node->input(i);
+ if (!IsControlInput(input)) {
+ TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
+ }
+ }
+ group->optimized_nodes.push_back(node);
+ } else {
+ // If node can't be absorbed, add it to OptimizedNodesGroup input
+ OpInfo::TensorProperties properties;
+ TF_RETURN_IF_ERROR(GetTensorProperties(input, &properties));
+ group->inputs.emplace_back(input, properties.shape());
+ }
+ return Status::OK();
+ }
+
+ Status CreateOptimizedNodesGroup(NodeDef* root_node,
+ OptimizedNodesGroup* group) const {
+ OpInfo::TensorProperties root_node_output_properties;
+ TF_RETURN_IF_ERROR(
+ GetTensorProperties(root_node->name(), &root_node_output_properties));
+
+ group->root_node = root_node;
+ group->root_shape = root_node_output_properties.shape();
+
+ group->optimized_nodes.reserve(root_node->input_size());
+ for (int i = 0; i < root_node->input_size(); ++i) {
+ const string& input_i = root_node->input(i);
+ if (!IsControlInput(input_i)) {
+ TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
+ }
+ }
+
+ return Status::OK();
+ }
+
+ // Check if all inputs can be broadcasted to the same shape
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool HasAllInputsBroadcastableToShape(
+ const NodeDef& node, const OpInfo::TensorProperties& properties) const {
+ auto is_broadcastable = [this, &properties](const string& input) {
+ OpInfo::TensorProperties input_props;
+ Status has_input_properties = GetTensorProperties(input, &input_props);
+ return has_input_properties.ok() &&
+ ShapesBroadcastable(properties, input_props);
+ };
+ return std::all_of(node.input().begin(), node.input().end(),
+ is_broadcastable);
+ }
+
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool IsDrivenByControlDependency(const NodeDef& node) const {
+ return std::any_of(node.input().begin(), node.input().end(),
+ IsControlInput);
+ }
+
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool DrivesControlDependency(const NodeDef& node) const {
+ int position;
+ for (const NodeDef* output : ctx_.node_map->GetOutputs(node.name())) {
+ for (int i = 0; i < output->input_size(); ++i) {
+ auto input = output->input(i);
+ string name = ParseNodeName(input, &position);
+ if (name == node.name() && /*control input*/ position < 0) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ string ShapeSignature(const TensorShapeProto& shape) const {
+ string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
+ for (int i = 0; i < shape.dim_size(); ++i)
+ strings::StrAppend(&signature, ":", shape.dim(i).size());
+ return signature;
+ }
+
+ void AddToOptimizedNodes(const NodeDef* node) {
+ optimized_nodes_.insert(node->name());
+ }
+
+ bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
+ const NodeDef& node) const {
+ return group.root_node->device() == node.device();
+ }
+
+ bool IsInPreserveSet(const NodeDef& node) const {
+ return ctx_.nodes_to_preserve->find(node.name()) !=
+ ctx_.nodes_to_preserve->end();
+ }
+
+ bool IsAlreadyOptimized(const NodeDef& node) const {
+ return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
+ }
+
+ private:
+ // set of nodes already processed by this optimizer stage
+ std::unordered_set<string> optimized_nodes_;
+};
+
// Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
// original inputs of absorbed nodes.
//
@@ -335,110 +502,33 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
// x y w Add_3 AddN(x, y, q, e) z
// / \
// q e
-class AddOpsRewriteStage : public ArithmeticOptimizerStage {
+class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
public:
explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
const ArithmeticOptimizerContext& ctx_ext)
- : ArithmeticOptimizerStage("AddOpsRewrite", ctx, ctx_ext),
- rewritten_nodes_() {}
-
+ : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {}
~AddOpsRewriteStage() override = default;
// Check if a node can become a root of AddOpsGroup
bool IsSupported(const NodeDef* node) const override {
- // check basic preconditions
- if (!IsRewritable(node)) {
- return false;
- }
+ if (!CanOptimize(node)) return false;
// shape must be symbolically defined and all inputs compatible with it
OpInfo::TensorProperties properties;
Status has_properties = GetTensorProperties(node->name(), &properties);
return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
- HasAllInputsOfBroadcastableShape(*node, properties);
- }
-
- Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- CHECK(IsSupported(node));
- AddOpsGroup group;
- TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
-
- if (!group.absorbed_nodes.empty()) {
- *simplified_node_name = RewriteAddOpsGroup(group);
- }
-
- return Status::OK();
- }
-
- private:
- // Input name with a statically inferred shape from GraphProperties
- struct InputAndShape {
- InputAndShape(const string& input, const TensorShapeProto& shape)
- : input(input), shape(shape) {}
- string input;
- TensorShapeProto shape;
- };
-
- // Holds together an add ops subgraph that we want to rewrite together.
- //
- // For the graph above the AddOpsGroup will be:
- // root_node: AddN_1
- // absorbed_nodes: [Add_1, Add_2]
- // input_nodes: [x, y, z, w, q, e]
- struct AddOpsGroup {
- const NodeDef* root_node;
- TensorShapeProto root_shape;
- // Add/AddN operations below the root level that were absorbed by this group
- std::vector<NodeDef*> absorbed_nodes;
- // Inputs of absorbed nodes that will be forwarded to optimized AddN ops
- std::vector<InputAndShape> inputs;
- };
-
- // Check if all inputs can be broadcasted to the same shape
- bool HasAllInputsOfBroadcastableShape(
- const NodeDef& node, const OpInfo::TensorProperties& properties) const {
- const AddOpsRewriteStage* self = this;
- return std::all_of(
- node.input().begin(), node.input().end(),
- [self, &properties](const string& input) {
- OpInfo::TensorProperties input_properties;
- Status has_input_properties =
- self->GetTensorProperties(input, &input_properties);
- return has_input_properties.ok() &&
- ShapesBroadcastable(properties, input_properties);
- });
- }
-
- // TODO(ezhulenev): use GraphRewriter?
- bool IsDrivenByControlDependency(const NodeDef& node) const {
- return std::any_of(node.input().begin(), node.input().end(),
- IsControlInput);
- }
-
- // TODO(ezhulenev): use GraphRewriter?
- bool DrivesControlDependency(const NodeDef& node) const {
- int position;
- for (const NodeDef* output : ctx_.node_map->GetOutputs(node.name())) {
- for (int i = 0; i < output->input_size(); ++i) {
- auto input = output->input(i);
- string name = ParseNodeName(input, &position);
- if (name == node.name() && /*control input*/ position < 0) {
- return true;
- }
- }
- }
- return false;
+ HasAllInputsBroadcastableToShape(*node, properties);
}
- // Check if a node can be absorbed by current AddOpsGroup
- bool IsAbsorbableByAddOpsGroup(const string& name, const AddOpsGroup& group) {
+ protected:
+ // Check if a node can be absorbed by current OptimizedNodesGroup
+ bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
+ const string& input) const override {
NodeDef* node;
- Status node_status = GetInputNode(name, &node);
- if (!node_status.ok()) {
- return false;
- }
- // check basic preconditions
- if (!IsRewritable(node)) {
+ Status node_status = GetInputNode(input, &node);
+ if (!node_status.ok() || !CanOptimize(node)) return false;
+
+ if (!IsOnTheSameDevice(group, *node)) {
return false;
}
// with a single output data consumer (presumably if we reach this node from
@@ -447,102 +537,42 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
if (NumNonControlDataOutputs(*node, *ctx_.node_map) != 1) {
return false;
}
- // must be on the same device as a root node
- if (node->device() != group.root_node->device()) {
- return false;
- }
// All input shapes must be broadcastable to the node shape
OpInfo::TensorProperties properties;
- Status has_properties = GetTensorProperties(name, &properties);
+ Status has_properties = GetTensorProperties(input, &properties);
return has_properties.ok() &&
- HasAllInputsOfBroadcastableShape(*node, properties);
+ HasAllInputsBroadcastableToShape(*node, properties);
}
// Node requirements both for a root node and an absorbed node
- bool IsRewritable(const NodeDef* node) const {
- // only Add or AddN can be a root node
+ bool CanOptimize(const NodeDef* node) const {
// TODO(ezhulenev): check if AccumulateNV2 can be supported too
if (!IsAdd(*node) && !IsAddN(*node)) {
return false;
}
- // it must not be in a preserve set
- if (ctx_.nodes_to_preserve->find(node->name()) !=
- ctx_.nodes_to_preserve->end()) {
- return false;
- }
- // it must not be a node created or absorbed by previous iteration
- if (rewritten_nodes_.find(node->name()) != rewritten_nodes_.end()) {
+ if (IsInPreserveSet(*node) || IsAlreadyOptimized(*node)) {
return false;
}
// it must not be created by this stage at any of previous optimization runs
if (str_util::StrContains(node->name(), stage_name_)) {
return false;
}
- // should not drive or be driven by control dependency
// TODO(ezhulenev): relax this condition for root node
return !(IsDrivenByControlDependency(*node) ||
DrivesControlDependency(*node));
}
- // Create an AddOpsGroup with a root in a given node
- Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
- OpInfo::TensorProperties root_node_output_properties;
- TF_RETURN_IF_ERROR(
- GetTensorProperties(root_node->name(), &root_node_output_properties));
-
- group->root_node = root_node;
- group->root_shape = root_node_output_properties.shape();
-
- group->absorbed_nodes.reserve(root_node->input_size());
- for (int i = 0; i < root_node->input_size(); ++i) {
- const string& input_i = root_node->input(i);
- if (!IsControlInput(input_i)) {
- TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(input_i, group));
- }
- }
-
- return Status::OK();
- }
-
- Status AbsorbInputByAddOpsGroup(const string& input, AddOpsGroup* group) {
- NodeDef* node;
- TF_RETURN_IF_ERROR(GetInputNode(input, &node));
-
- if (IsAbsorbableByAddOpsGroup(input, *group)) {
- group->absorbed_nodes.push_back(node);
- for (int i = 0; i < node->input_size(); ++i) {
- const string& input_i = node->input(i);
- if (!IsControlInput(input)) {
- TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(input_i, group));
- }
- }
- } else {
- // If node can't be absorbed, add it to AddOpsGroup input
- OpInfo::TensorProperties properties;
- TF_RETURN_IF_ERROR(GetTensorProperties(input, &properties));
- group->inputs.emplace_back(input, properties.shape());
- }
- return Status::OK();
- }
-
- // Rewrite an add ops group into a single AddN if all input shapes are
+ // Rewrite a group of add ops into a single AddN if all input shapes are
// symbolically equal. If not, create AddN for equal shapes first, and then
// build an Add tree, minimizing the cost of broadcasts.
- string RewriteAddOpsGroup(const AddOpsGroup& group) {
+ string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
// all new nodes will be placed under the scope of a root node
auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());
- auto shape_sig = [](const TensorShapeProto& shape) {
- string name = strings::StrCat("r:", shape.dim_size(), ":d");
- for (int i = 0; i < shape.dim_size(); ++i)
- strings::StrAppend(&name, ":", shape.dim(i).size());
- return name;
- };
-
// Find what shapes are present in the inputs of absorbed nodes
std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
for (const auto& input : group.inputs) {
- shape_sig_to_inputs[shape_sig(input.shape)].push_back(input);
+ shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
}
// Collect all the shapes from representative elements
@@ -556,8 +586,6 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
string node_name = OptimizedNodeName(root_scope_and_name);
AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
group.inputs);
- // keep track of nodes that were created or absorbed as a part of rewrite
- rewritten_nodes_.insert(node_name);
return node_name;
}
@@ -586,7 +614,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
// Prepare leaf AddN nodes for inputs of equal shape
for (int i = 0; i < shapes.size(); ++i) {
const auto node_name = leaf_node_name(i);
- const auto& inputs = shape_sig_to_inputs[shape_sig(shapes[i])];
+ const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])];
add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
node_name, inputs));
}
@@ -637,7 +665,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
node->add_input(inputAndShape.input);
}
- rewritten_nodes_.insert(node_name);
+ AddToOptimizedNodes(node);
return InputAndShape(node_name, shape);
}
@@ -661,13 +689,10 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
node->add_input(left.input);
node->add_input(right.input);
- rewritten_nodes_.insert(node_name);
+ AddToOptimizedNodes(node);
return InputAndShape(
node_name, TensorShapeProto()); // shape is not important at this point
}
-
- // keep nodes that were added or absorbed as a part of AddOpsGroup rewrite
- std::unordered_set<string> rewritten_nodes_;
};
// Use the commutativity and (left- and right-) distributive property of
@@ -693,7 +718,7 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- CHECK(IsSupported(node));
+ TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
std::set<string> common_factors;
std::vector<string> ctrl_deps;
@@ -839,6 +864,201 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
std::unordered_set<string> rewritten_nodes_;
};
+// Binary associative ops can be re-ordered to minimize the number of broadcasts
+// and the size of temporary tensors.
+//
+// Example: [a, c] - scalars, [b, d] - matrices
+// @ - binary associative op (Add or Mul)
+// @* - broadcast
+//
+// @ @*
+// / \ / \
+// @* @* -> @ @
+// / \ / \ / \ / \
+// a b c d a c b d
+class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
+ public:
+ explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) {
+ }
+ ~MinimizeBroadcasts() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ if (!IsBinaryAssociative(*node)) return false;
+
+ // has a symbolically defined shape with broadcastable inputs
+ OpInfo::TensorProperties properties;
+ Status has_properties = GetTensorProperties(node->name(), &properties);
+ return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
+ HasAllInputsBroadcastableToShape(*node, properties);
+ }
+
+ protected:
+ bool IsBinaryAssociative(const NodeDef& node) const {
+ return IsMul(node) || IsAdd(node);
+ }
+
+ bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const {
+ return group.root_node->op() == node.op();
+ }
+
+ // Check if a node can be absorbed by current OptimizedNodesGroup
+ bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
+ const string& input) const override {
+ NodeDef* node;
+ Status node_status = GetInputNode(input, &node);
+ if (!node_status.ok()) return false;
+
+ if (!IsSameOp(group, *node)) {
+ return false;
+ }
+ if (IsInPreserveSet(*node) || IsAlreadyOptimized(*node)) {
+ return false;
+ }
+ if (IsDrivenByControlDependency(*node) || DrivesControlDependency(*node)) {
+ return false;
+ }
+ if (!IsOnTheSameDevice(group, *node)) {
+ return false;
+ }
+    // Optimized nodes are updated in place, and that would break the graph
+    // if the node has multiple output consumers
+ if (NumNonControlOutputs(*node, *ctx_.node_map) != 1) {
+ return false;
+ }
+ // All input shapes must be broadcastable to the node shape
+ OpInfo::TensorProperties properties;
+ Status has_properties = GetTensorProperties(input, &properties);
+ return has_properties.ok() &&
+ HasAllInputsBroadcastableToShape(*node, properties);
+ }
+
+ std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
+ std::set<string> sigs;
+ for (const auto& ias : inputs) {
+ sigs.insert(ShapeSignature(ias.shape));
+ }
+ return sigs.size();
+ }
+
+ string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
+ if (CountUniqueShapes(group.inputs) <= 1) {
+ // nothing to optimize when all shapes are the same
+ return group.root_node->name();
+ }
+
+ auto num_nodes = /*root*/ 1 + group.optimized_nodes.size();
+ auto num_inputs = group.inputs.size();
+ CHECK_EQ(num_nodes, num_inputs - 1)
+ << "Can't build a tree with " << num_inputs << " inputs, using "
+ << num_nodes << "binary op nodes.";
+
+ std::deque<InputAndShape> add_ops(group.inputs.begin(), group.inputs.end());
+ std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(),
+ group.optimized_nodes.end());
+
+    // sort inputs by their shapes from smallest to largest
+ std::stable_sort(add_ops.begin(), add_ops.end(),
+ [](const InputAndShape& lhs, const InputAndShape& rhs) {
+ return CompareSymbolicallyShapedTensorSizes(lhs.shape,
+ rhs.shape);
+ });
+
+ // If there is an odd number of inputs, last one is the largest, and we want
+ // to attach it to the root node, to build a well balanced tree.
+ std::deque<InputAndShape> add_ops_leftover;
+ if (add_ops.size() % 2 != 0) {
+ add_ops_leftover.push_back(add_ops.back());
+ add_ops.pop_back();
+ }
+
+    // At this point it's guaranteed that add_ops has an even number of inputs.
+ do {
+ const InputAndShape lhs = add_ops.front();
+ add_ops.pop_front();
+ const InputAndShape rhs = add_ops.front();
+ add_ops.pop_front();
+
+ NodeDef* node;
+ if (!optimized_nodes.empty()) {
+ // re-purpose optimized nodes to build a new tree
+ node = optimized_nodes.front();
+ optimized_nodes.pop_front();
+ } else {
+        // or use the root node if no optimized nodes are left
+ node = group.root_node;
+ }
+ InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node);
+
+ // Pushing updated node to the back of a deque will create a wide and
+ // short tree, pushing to the front will create a tall tree. We prefer to
+ // get a wide tree, it minimizes the potential number of temporary tensors
+ // required to keep in memory, though sometimes we can go up to prevent
+      // propagating a broadcast from leaves to the root. Example:
+ //
+ // inputs: [s, s, s, M] (s - scalar, M - matrix)
+ // @* - op with broadcast
+ //
+ // (only push_back) @* (push_front first op)
+ // / \
+ // @* @ M
+ // / \ / \
+ // @ @* -> @ s
+ // / \ / \ / \
+ // s s s M s s
+ if (add_ops.size() >= 2 &&
+ CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape,
+ add_ops.at(1).shape)) {
+ add_ops.push_front(updated_node);
+ } else {
+ add_ops.push_back(updated_node);
+ }
+ } while (add_ops.size() > 1);
+ CHECK_EQ(1, add_ops.size());
+
+ // attach the largest tensor to the root op
+ if (!add_ops_leftover.empty()) {
+ const InputAndShape lhs = add_ops.front();
+ add_ops.pop_front();
+ const InputAndShape rhs = add_ops_leftover.front();
+ InputAndShape updated_node =
+ UpdateInputs(lhs.input, rhs.input, group.root_node);
+ add_ops.push_back(updated_node);
+ }
+
+ return add_ops.front().input;
+ }
+
+ InputAndShape UpdateInputs(const string& input_0, const string& input_1,
+ NodeDef* node) {
+ string old_input_0 = node->input(0);
+ string old_input_1 = node->input(1);
+
+ // Update inputs only if they changed
+ if (old_input_0 != input_0 || old_input_1 != input_1) {
+ node->set_input(0, input_0);
+ node->set_input(1, input_1);
+ // Invalidate node properties (shape)
+ ctx_.graph_properties->ClearOutputProperties(node->name());
+ ctx_.graph_properties->ClearInputProperties(node->name());
+ // Update the node map
+ ctx_.node_map->RemoveOutput(NodeName(old_input_0), node->name());
+ ctx_.node_map->RemoveOutput(NodeName(old_input_1), node->name());
+ ctx_.node_map->AddOutput(NodeName(input_0), node->name());
+ ctx_.node_map->AddOutput(NodeName(input_1), node->name());
+ // Add updated node to optimization queue
+ AddToOptimizationQueue(node);
+ }
+
+ // Do not add updated node to any other group
+ AddToOptimizedNodes(node);
+
+ TensorShapeProto shape; // shape is not important at this point
+ return InputAndShape(node->name(), shape);
+ }
+};
+
// Removes inverse transpose nodes
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
public:
@@ -854,7 +1074,7 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
// TODO(rmlarsen): Forward control dependencies on the bypassed
// transpose nodes.
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- CHECK(IsSupported(node));
+ TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
@@ -943,7 +1163,7 @@ class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- CHECK(IsSupported(node));
+ TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
// Bypass Bitcast whose source type and destination type are equal.
if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
@@ -981,7 +1201,8 @@ class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
bool IsSupported(const NodeDef* node) const override { return IsCast(*node); }
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- CHECK(IsSupported(node));
+ TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
+
// Bypass Cast whose source type and destination type are equal.
if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
*simplified_node_name = node->input(0);
@@ -1678,6 +1899,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes)
pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
+ if (options_.minimize_broadcasts && can_use_shapes)
+ pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
if (options_.remove_identity_transpose && can_use_shapes)
pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
if (options_.remove_redundant_bitcast)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 39b89dedba..c0fe8839ca 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -59,6 +59,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool enable_try_simplify_and_replace = true;
bool combine_add_to_addn = false;
bool hoist_common_factor_out_of_aggregation = true;
+ bool minimize_broadcasts = false;
bool remove_identity_transpose = true;
bool remove_redundant_bitcast = true;
bool remove_redundant_cast = true;
@@ -69,10 +70,10 @@ class ArithmeticOptimizer : public GraphOptimizer {
static ArithmeticOptimizerOptions Default(
RewriterConfig::Toggle opt_level) {
ArithmeticOptimizerOptions options;
- // TODO(ezhulenev): enable combine_add_to_addn by default after 1.8
- // release cut
+ // TODO(ezhulenev): enable by default after 1.8 release cut
if (opt_level == RewriterConfig::AGGRESSIVE) {
options.combine_add_to_addn = true;
+ options.minimize_broadcasts = true;
}
return options;
}
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index e117341ba3..9677175d2e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -93,6 +93,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.enable_try_simplify_and_replace = false;
options.combine_add_to_addn = false;
options.hoist_common_factor_out_of_aggregation = false;
+ options.minimize_broadcasts = false;
options.remove_identity_transpose = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
@@ -113,6 +114,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.hoist_common_factor_out_of_aggregation = true;
}
+ void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.minimize_broadcasts = true;
+ }
+
void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_identity_transpose = true;
@@ -1841,5 +1847,160 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
EXPECT_EQ(5, found);
}
+TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
+
+ auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
+ auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
+
+ auto outputs = ops::Identity(s.WithOpName("outputs"), mul2);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyMinimizeBroadcasts(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ // * *
+ // / \ / \
+ // * c --> * b
+ // / \ / \
+ // a b a c
+ NodeMap node_map(&output);
+
+ const NodeDef* mul1_node = node_map.GetNode("mul1");
+ ASSERT_NE(mul1_node, nullptr);
+ EXPECT_EQ("a", mul1_node->input(0));
+ EXPECT_EQ("c", mul1_node->input(1));
+
+ const NodeDef* mul2_node = node_map.GetNode("mul2");
+ ASSERT_NE(mul2_node, nullptr);
+ EXPECT_EQ("mul1", mul2_node->input(0));
+ EXPECT_EQ("b", mul2_node->input(1));
+}
+
+TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
+ auto d = ops::Variable(s.WithOpName("d"), {32}, DT_FLOAT);
+ auto e = ops::Variable(s.WithOpName("e"), {32}, DT_FLOAT);
+
+ auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
+ auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
+ auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d);
+ auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e);
+
+ auto outputs = ops::Identity(s.WithOpName("outputs"), mul4);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyMinimizeBroadcasts(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur: Graph is "flattened" and
+ // largest shape pushed to the top.
+ //
+ // *
+ // / \
+ // * e *
+ // / \ / \
+ // * d * b
+ // / \ / \
+ // * c --> * *
+ // / \ / \ / \
+ // a b a c d e
+ NodeMap node_map(&output);
+
+ const NodeDef* mul1_node = node_map.GetNode("mul1");
+ ASSERT_NE(mul1_node, nullptr);
+ EXPECT_EQ("a", mul1_node->input(0));
+ EXPECT_EQ("c", mul1_node->input(1));
+
+ const NodeDef* mul2_node = node_map.GetNode("mul2");
+ ASSERT_NE(mul2_node, nullptr);
+ EXPECT_EQ("d", mul2_node->input(0));
+ EXPECT_EQ("e", mul2_node->input(1));
+
+ const NodeDef* mul3_node = node_map.GetNode("mul3");
+ ASSERT_NE(mul3_node, nullptr);
+ EXPECT_EQ("mul1", mul3_node->input(0));
+ EXPECT_EQ("mul2", mul3_node->input(1));
+
+ const NodeDef* mul4_node = node_map.GetNode("mul4");
+ ASSERT_NE(mul4_node, nullptr);
+ EXPECT_EQ("mul3", mul4_node->input(0));
+ EXPECT_EQ("b", mul4_node->input(1));
+}
+
+TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ // [a, b, c] - scalars, [d] - matrix
+ auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
+ auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT);
+
+ auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
+ auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d);
+ auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);
+
+ auto outputs = ops::Identity(s.WithOpName("outputs"), mul3);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyMinimizeBroadcasts(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ // *
+ // / \
+ // * * D
+ // / \ / \
+ // * * -> * c
+ // / \ / \ / \
+ // a b c D a b
+ NodeMap node_map(&output);
+
+ const NodeDef* mul1_node = node_map.GetNode("mul1");
+ ASSERT_NE(mul1_node, nullptr);
+ EXPECT_EQ("a", mul1_node->input(0));
+ EXPECT_EQ("b", mul1_node->input(1));
+
+ const NodeDef* mul2_node = node_map.GetNode("mul2");
+ ASSERT_NE(mul2_node, nullptr);
+ EXPECT_EQ("mul1", mul2_node->input(0));
+ EXPECT_EQ("c", mul2_node->input(1));
+
+ const NodeDef* mul3_node = node_map.GetNode("mul3");
+ ASSERT_NE(mul3_node, nullptr);
+ EXPECT_EQ("D", mul3_node->input(0));
+ EXPECT_EQ("mul2", mul3_node->input(1));
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index 7ed0474861..072f772946 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -134,6 +134,18 @@ class GraphOptimizerStage {
// and remove template parameter.
virtual Status TrySimplify(NodeDef* node, Result* result) = 0;
+ // Return InvalidArgumentError if node is not supported by the optimizer
+ // stage.
+ // TODO(ezhulenev): make this check part of non-virtual public API
+ // (TrySimplify), and make virtual implementation protected.
+ Status EnsureNodeIsSupported(const NodeDef* node) const {
+ return IsSupported(node)
+ ? Status::OK()
+ : errors::InvalidArgument(
+ "Node ", node->name(), " is not supported by optimizer ",
+ optimizer_name_, " and stage ", stage_name_);
+ }
+
// Get a name for a new node, created by this stage, based on one or multiple
// nodes of an original graph.
const string OptimizedNodeName(const NodeScopeAndName& node) const {