| author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-09 10:13:28 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-09 10:15:46 -0700 |
| commit | aed12f35e29924e43f191d42fdcc6f9e025a3a3e (patch) | |
| tree | 2d4f0769a4b102e2cdf98f750ab083b65729cd49 | |
| parent | 57b491744fa685cffc27b0dc73647fa2f05c9b68 (diff) | |
Minimize broadcasts by rewriting a sub-tree of binary associative ops (Add, Mul).
PiperOrigin-RevId: 192145052
4 files changed, 568 insertions, 171 deletions
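The core idea of the rewrite below is to reorder a chain of binary associative ops (Add, Mul) so that inputs with the same symbolic shape are combined first and broadcasts against larger tensors happen as late and as rarely as possible. The following is a minimal standalone sketch of that rebalancing loop (modeled on `RewriteOptimizedNodesGroup` in `MinimizeBroadcasts` further down); the types, the size-based cost model, and the string output are simplified illustrations, not the Grappler API:

```cpp
// Standalone sketch of the tree-rebalancing idea: sort leaves by size,
// pair the smallest ones first, and attach the largest leftover operand
// at the root so broadcasts against large tensors are minimized.
#include <algorithm>
#include <deque>
#include <functional>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

struct Leaf {
  std::string name;
  std::vector<int> shape;  // stand-in for TensorShapeProto
};

long long NumElements(const std::vector<int>& shape) {
  return std::accumulate(shape.begin(), shape.end(), 1LL,
                         std::multiplies<long long>());
}

// Builds a nested expression string mirroring the order in which the
// optimizer would wire the rewritten binary op nodes together.
std::string BuildBalancedTree(std::vector<Leaf> leaves, const std::string& op) {
  std::deque<Leaf> ops(leaves.begin(), leaves.end());
  std::stable_sort(ops.begin(), ops.end(), [](const Leaf& a, const Leaf& b) {
    return NumElements(a.shape) < NumElements(b.shape);
  });

  // With an odd number of inputs, keep the largest one for the root node.
  std::deque<Leaf> leftover;
  if (ops.size() % 2 != 0) {
    leftover.push_back(ops.back());
    ops.pop_back();
  }

  while (ops.size() > 1) {
    Leaf lhs = ops.front(); ops.pop_front();
    Leaf rhs = ops.front(); ops.pop_front();
    Leaf combined{op + "(" + lhs.name + ", " + rhs.name + ")",
                  NumElements(lhs.shape) >= NumElements(rhs.shape) ? lhs.shape
                                                                   : rhs.shape};
    // Pushing to the back builds a wide tree; going to the front (one level
    // up) is preferred only when the next two pending inputs are smaller,
    // to avoid propagating a broadcast from the leaves toward the root.
    if (ops.size() >= 2 &&
        NumElements(ops[0].shape) < NumElements(ops[1].shape)) {
      ops.push_front(combined);
    } else {
      ops.push_back(combined);
    }
  }

  if (!leftover.empty()) {
    return op + "(" + ops.front().name + ", " + leftover.front().name + ")";
  }
  return ops.front().name;
}

int main() {
  // [a, c] are vectors, [b] is a matrix: (a * b) * c  ->  (a * c) * b,
  // matching the MinimizeBroadcasts_SimpleSwap test added in this change.
  std::vector<Leaf> leaves = {{"a", {32}}, {"b", {32, 32}}, {"c", {32}}};
  std::cout << BuildBalancedTree(leaves, "Mul") << "\n";
}
```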
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index da8d677737..fa0f7c1c6e 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -279,6 +279,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { ctx_ext_(ctx_ext) {} virtual ~ArithmeticOptimizerStage() = default; + protected: // Simplification graph rewrite can create additional nodes that are inputs // to final simplified node, they can be also added to the arithmetic // optimizer queue for further optimization. @@ -304,10 +305,176 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { } private: - // extened context required for ArithmeticOptimizer + // Extended context required for ArithmeticOptimizer. const ArithmeticOptimizerContext ctx_ext_; }; +// Subtype of ArithmeticOptimizerStage that does optimization by rewriting a +// group of nodes from the optimized graph. +// +// * AddOpsRewrite: +// Rewrite a group of Add/AddN with compact Add/AddN tree +// +// * MinimizeBroadcasts: +// Rewrite a group of binary associative ops, reordering +// inputs, to minimize the cost of broadcast +class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { + public: + explicit ArithmeticNodesGroupOptimizerStage( + const string& name, const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext ctx_ext) + : ArithmeticOptimizerStage(name, ctx, ctx_ext), optimized_nodes_{} {} + ~ArithmeticNodesGroupOptimizerStage() override = default; + + // Input name with a statically inferred shape from GraphProperties + struct InputAndShape { + InputAndShape(const string& input, const TensorShapeProto& shape) + : input(input), shape(shape) {} + string input; + TensorShapeProto shape; + }; + + // Subgraph (subtree) of nodes, that we want to optimize in "one shot" (e.g. + // all the Add nodes that we plan to rewrite with a single AddN). Subgraph is + // obtained by graph traversal, starting from a root node. + struct OptimizedNodesGroup { + NodeDef* root_node; + TensorShapeProto root_shape; + // Optimized nodes that will be updated or removed by rewrite + std::vector<NodeDef*> optimized_nodes; + // Inputs to optimized nodes + std::vector<InputAndShape> inputs; + }; + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); + + OptimizedNodesGroup group; + TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group)); + + if (!group.optimized_nodes.empty()) { + *simplified_node_name = RewriteOptimizedNodesGroup(group); + } + + return Status::OK(); + } + + protected: + // Modify the optimized graph after nodes group was successfully identified + virtual string RewriteOptimizedNodesGroup( + const OptimizedNodesGroup& group) = 0; + + // Check if input can become a part of current optimized nodes group. 
+ virtual bool IsAbsorbableByOptimizedNodesGroup( + const OptimizedNodesGroup& group, const string& input) const = 0; + + Status AbsorbInputByOptimizedNodesGroup(const string& input, + OptimizedNodesGroup* group) const { + NodeDef* node; + TF_RETURN_IF_ERROR(GetInputNode(input, &node)); + + if (IsAbsorbableByOptimizedNodesGroup(*group, input)) { + for (int i = 0; i < node->input_size(); ++i) { + const string& input_i = node->input(i); + if (!IsControlInput(input)) { + TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group)); + } + } + group->optimized_nodes.push_back(node); + } else { + // If node can't be absorbed, add it to OptimizedNodesGroup input + OpInfo::TensorProperties properties; + TF_RETURN_IF_ERROR(GetTensorProperties(input, &properties)); + group->inputs.emplace_back(input, properties.shape()); + } + return Status::OK(); + } + + Status CreateOptimizedNodesGroup(NodeDef* root_node, + OptimizedNodesGroup* group) const { + OpInfo::TensorProperties root_node_output_properties; + TF_RETURN_IF_ERROR( + GetTensorProperties(root_node->name(), &root_node_output_properties)); + + group->root_node = root_node; + group->root_shape = root_node_output_properties.shape(); + + group->optimized_nodes.reserve(root_node->input_size()); + for (int i = 0; i < root_node->input_size(); ++i) { + const string& input_i = root_node->input(i); + if (!IsControlInput(input_i)) { + TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group)); + } + } + + return Status::OK(); + } + + // Check if all inputs can be broadcasted to the same shape + // TODO(ezhulenev): move to GraphOptimizerStage? + bool HasAllInputsBroadcastableToShape( + const NodeDef& node, const OpInfo::TensorProperties& properties) const { + auto is_broadcastable = [this, &properties](const string& input) { + OpInfo::TensorProperties input_props; + Status has_input_properties = GetTensorProperties(input, &input_props); + return has_input_properties.ok() && + ShapesBroadcastable(properties, input_props); + }; + return std::all_of(node.input().begin(), node.input().end(), + is_broadcastable); + } + + // TODO(ezhulenev): move to GraphOptimizerStage? + bool IsDrivenByControlDependency(const NodeDef& node) const { + return std::any_of(node.input().begin(), node.input().end(), + IsControlInput); + } + + // TODO(ezhulenev): move to GraphOptimizerStage? 
+ bool DrivesControlDependency(const NodeDef& node) const { + int position; + for (const NodeDef* output : ctx_.node_map->GetOutputs(node.name())) { + for (int i = 0; i < output->input_size(); ++i) { + auto input = output->input(i); + string name = ParseNodeName(input, &position); + if (name == node.name() && /*control input*/ position < 0) { + return true; + } + } + } + return false; + } + + string ShapeSignature(const TensorShapeProto& shape) const { + string signature = strings::StrCat("rank:", shape.dim_size(), ":dim"); + for (int i = 0; i < shape.dim_size(); ++i) + strings::StrAppend(&signature, ":", shape.dim(i).size()); + return signature; + } + + void AddToOptimizedNodes(const NodeDef* node) { + optimized_nodes_.insert(node->name()); + } + + bool IsOnTheSameDevice(const OptimizedNodesGroup& group, + const NodeDef& node) const { + return group.root_node->device() == node.device(); + } + + bool IsInPreserveSet(const NodeDef& node) const { + return ctx_.nodes_to_preserve->find(node.name()) != + ctx_.nodes_to_preserve->end(); + } + + bool IsAlreadyOptimized(const NodeDef& node) const { + return optimized_nodes_.find(node.name()) != optimized_nodes_.end(); + } + + private: + // set of nodes already processed by this optimizer stage + std::unordered_set<string> optimized_nodes_; +}; + // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the // original inputs of absorbed nodes. // @@ -335,110 +502,33 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { // x y w Add_3 AddN(x, y, q, e) z // / \ // q e -class AddOpsRewriteStage : public ArithmeticOptimizerStage { +class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { public: explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx, const ArithmeticOptimizerContext& ctx_ext) - : ArithmeticOptimizerStage("AddOpsRewrite", ctx, ctx_ext), - rewritten_nodes_() {} - + : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {} ~AddOpsRewriteStage() override = default; // Check if a node can become a root of AddOpsGroup bool IsSupported(const NodeDef* node) const override { - // check basic preconditions - if (!IsRewritable(node)) { - return false; - } + if (!CanOptimize(node)) return false; // shape must be symbolically defined and all inputs compatible with it OpInfo::TensorProperties properties; Status has_properties = GetTensorProperties(node->name(), &properties); return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) && - HasAllInputsOfBroadcastableShape(*node, properties); - } - - Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - CHECK(IsSupported(node)); - AddOpsGroup group; - TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group)); - - if (!group.absorbed_nodes.empty()) { - *simplified_node_name = RewriteAddOpsGroup(group); - } - - return Status::OK(); - } - - private: - // Input name with a statically inferred shape from GraphProperties - struct InputAndShape { - InputAndShape(const string& input, const TensorShapeProto& shape) - : input(input), shape(shape) {} - string input; - TensorShapeProto shape; - }; - - // Holds together an add ops subgraph that we want to rewrite together. 
- // - // For the graph above the AddOpsGroup will be: - // root_node: AddN_1 - // absorbed_nodes: [Add_1, Add_2] - // input_nodes: [x, y, z, w, q, e] - struct AddOpsGroup { - const NodeDef* root_node; - TensorShapeProto root_shape; - // Add/AddN operations below the root level that were absorbed by this group - std::vector<NodeDef*> absorbed_nodes; - // Inputs of absorbed nodes that will be forwarded to optimized AddN ops - std::vector<InputAndShape> inputs; - }; - - // Check if all inputs can be broadcasted to the same shape - bool HasAllInputsOfBroadcastableShape( - const NodeDef& node, const OpInfo::TensorProperties& properties) const { - const AddOpsRewriteStage* self = this; - return std::all_of( - node.input().begin(), node.input().end(), - [self, &properties](const string& input) { - OpInfo::TensorProperties input_properties; - Status has_input_properties = - self->GetTensorProperties(input, &input_properties); - return has_input_properties.ok() && - ShapesBroadcastable(properties, input_properties); - }); - } - - // TODO(ezhulenev): use GraphRewriter? - bool IsDrivenByControlDependency(const NodeDef& node) const { - return std::any_of(node.input().begin(), node.input().end(), - IsControlInput); - } - - // TODO(ezhulenev): use GraphRewriter? - bool DrivesControlDependency(const NodeDef& node) const { - int position; - for (const NodeDef* output : ctx_.node_map->GetOutputs(node.name())) { - for (int i = 0; i < output->input_size(); ++i) { - auto input = output->input(i); - string name = ParseNodeName(input, &position); - if (name == node.name() && /*control input*/ position < 0) { - return true; - } - } - } - return false; + HasAllInputsBroadcastableToShape(*node, properties); } - // Check if a node can be absorbed by current AddOpsGroup - bool IsAbsorbableByAddOpsGroup(const string& name, const AddOpsGroup& group) { + protected: + // Check if a node can be absorbed by current OptimizedNodesGroup + bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group, + const string& input) const override { NodeDef* node; - Status node_status = GetInputNode(name, &node); - if (!node_status.ok()) { - return false; - } - // check basic preconditions - if (!IsRewritable(node)) { + Status node_status = GetInputNode(input, &node); + if (!node_status.ok() || !CanOptimize(node)) return false; + + if (!IsOnTheSameDevice(group, *node)) { return false; } // with a single output data consumer (presumably if we reach this node from @@ -447,102 +537,42 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { if (NumNonControlDataOutputs(*node, *ctx_.node_map) != 1) { return false; } - // must be on the same device as a root node - if (node->device() != group.root_node->device()) { - return false; - } // All input shapes must be broadcastable to the node shape OpInfo::TensorProperties properties; - Status has_properties = GetTensorProperties(name, &properties); + Status has_properties = GetTensorProperties(input, &properties); return has_properties.ok() && - HasAllInputsOfBroadcastableShape(*node, properties); + HasAllInputsBroadcastableToShape(*node, properties); } // Node requirements both for a root node and an absorbed node - bool IsRewritable(const NodeDef* node) const { - // only Add or AddN can be a root node + bool CanOptimize(const NodeDef* node) const { // TODO(ezhulenev): check if AccumulateNV2 can be supported too if (!IsAdd(*node) && !IsAddN(*node)) { return false; } - // it must not be in a preserve set - if (ctx_.nodes_to_preserve->find(node->name()) != - 
ctx_.nodes_to_preserve->end()) { - return false; - } - // it must not be a node created or absorbed by previous iteration - if (rewritten_nodes_.find(node->name()) != rewritten_nodes_.end()) { + if (IsInPreserveSet(*node) || IsAlreadyOptimized(*node)) { return false; } // it must not be created by this stage at any of previous optimization runs if (str_util::StrContains(node->name(), stage_name_)) { return false; } - // should not drive or be driven by control dependency // TODO(ezhulenev): relax this condition for root node return !(IsDrivenByControlDependency(*node) || DrivesControlDependency(*node)); } - // Create an AddOpsGroup with a root in a given node - Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) { - OpInfo::TensorProperties root_node_output_properties; - TF_RETURN_IF_ERROR( - GetTensorProperties(root_node->name(), &root_node_output_properties)); - - group->root_node = root_node; - group->root_shape = root_node_output_properties.shape(); - - group->absorbed_nodes.reserve(root_node->input_size()); - for (int i = 0; i < root_node->input_size(); ++i) { - const string& input_i = root_node->input(i); - if (!IsControlInput(input_i)) { - TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(input_i, group)); - } - } - - return Status::OK(); - } - - Status AbsorbInputByAddOpsGroup(const string& input, AddOpsGroup* group) { - NodeDef* node; - TF_RETURN_IF_ERROR(GetInputNode(input, &node)); - - if (IsAbsorbableByAddOpsGroup(input, *group)) { - group->absorbed_nodes.push_back(node); - for (int i = 0; i < node->input_size(); ++i) { - const string& input_i = node->input(i); - if (!IsControlInput(input)) { - TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(input_i, group)); - } - } - } else { - // If node can't be absorbed, add it to AddOpsGroup input - OpInfo::TensorProperties properties; - TF_RETURN_IF_ERROR(GetTensorProperties(input, &properties)); - group->inputs.emplace_back(input, properties.shape()); - } - return Status::OK(); - } - - // Rewrite an add ops group into a single AddN if all input shapes are + // Rewrite a group of add ops into a single AddN if all input shapes are // symbolically equal. If not, create AddN for equal shapes first, and then // build an Add tree, minimizing the cost of broadcasts. 
- string RewriteAddOpsGroup(const AddOpsGroup& group) { + string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override { // all new nodes will be placed under the scope of a root node auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name()); - auto shape_sig = [](const TensorShapeProto& shape) { - string name = strings::StrCat("r:", shape.dim_size(), ":d"); - for (int i = 0; i < shape.dim_size(); ++i) - strings::StrAppend(&name, ":", shape.dim(i).size()); - return name; - }; - // Find what shapes are present in the inputs of absorbed nodes std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs; for (const auto& input : group.inputs) { - shape_sig_to_inputs[shape_sig(input.shape)].push_back(input); + shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input); } // Collect all the shapes from representative elements @@ -556,8 +586,6 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { string node_name = OptimizedNodeName(root_scope_and_name); AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name, group.inputs); - // keep track of nodes that were created or absorbed as a part of rewrite - rewritten_nodes_.insert(node_name); return node_name; } @@ -586,7 +614,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { // Prepare leaf AddN nodes for inputs of equal shape for (int i = 0; i < shapes.size(); ++i) { const auto node_name = leaf_node_name(i); - const auto& inputs = shape_sig_to_inputs[shape_sig(shapes[i])]; + const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])]; add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name, inputs)); } @@ -637,7 +665,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { node->add_input(inputAndShape.input); } - rewritten_nodes_.insert(node_name); + AddToOptimizedNodes(node); return InputAndShape(node_name, shape); } @@ -661,13 +689,10 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { node->add_input(left.input); node->add_input(right.input); - rewritten_nodes_.insert(node_name); + AddToOptimizedNodes(node); return InputAndShape( node_name, TensorShapeProto()); // shape is not important at this point } - - // keep nodes that were added or absorbed as a part of AddOpsGroup rewrite - std::unordered_set<string> rewritten_nodes_; }; // Use the commutativity and (left- and right-) distributive property of @@ -693,7 +718,7 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - CHECK(IsSupported(node)); + TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); std::set<string> common_factors; std::vector<string> ctrl_deps; @@ -839,6 +864,201 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { std::unordered_set<string> rewritten_nodes_; }; +// Binary associative ops can be re-ordered to minimize the number of broadcasts +// and the size of a temporary tensors. 
+// +// Example: [a, c] - scalars, [b, d] - matrices +// @ - binary associative op (Add or Mul) +// @* - broadcast +// +// @ @* +// / \ / \ +// @* @* -> @ @ +// / \ / \ / \ / \ +// a b c d a c b d +class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { + public: + explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) { + } + ~MinimizeBroadcasts() override = default; + + bool IsSupported(const NodeDef* node) const override { + if (!IsBinaryAssociative(*node)) return false; + + // has a symbolically defined shape with broadcastable inputs + OpInfo::TensorProperties properties; + Status has_properties = GetTensorProperties(node->name(), &properties); + return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) && + HasAllInputsBroadcastableToShape(*node, properties); + } + + protected: + bool IsBinaryAssociative(const NodeDef& node) const { + return IsMul(node) || IsAdd(node); + } + + bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const { + return group.root_node->op() == node.op(); + } + + // Check if a node can be absorbed by current OptimizedNodesGroup + bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group, + const string& input) const override { + NodeDef* node; + Status node_status = GetInputNode(input, &node); + if (!node_status.ok()) return false; + + if (!IsSameOp(group, *node)) { + return false; + } + if (IsInPreserveSet(*node) || IsAlreadyOptimized(*node)) { + return false; + } + if (IsDrivenByControlDependency(*node) || DrivesControlDependency(*node)) { + return false; + } + if (!IsOnTheSameDevice(group, *node)) { + return false; + } + // Optimized nodes updated in place, and that would break the graph, if the + // node has multiple output consumers + if (NumNonControlOutputs(*node, *ctx_.node_map) != 1) { + return false; + } + // All input shapes must be broadcastable to the node shape + OpInfo::TensorProperties properties; + Status has_properties = GetTensorProperties(input, &properties); + return has_properties.ok() && + HasAllInputsBroadcastableToShape(*node, properties); + } + + std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) { + std::set<string> sigs; + for (const auto& ias : inputs) { + sigs.insert(ShapeSignature(ias.shape)); + } + return sigs.size(); + } + + string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override { + if (CountUniqueShapes(group.inputs) <= 1) { + // nothing to optimize when all shapes are the same + return group.root_node->name(); + } + + auto num_nodes = /*root*/ 1 + group.optimized_nodes.size(); + auto num_inputs = group.inputs.size(); + CHECK_EQ(num_nodes, num_inputs - 1) + << "Can't build a tree with " << num_inputs << " inputs, using " + << num_nodes << "binary op nodes."; + + std::deque<InputAndShape> add_ops(group.inputs.begin(), group.inputs.end()); + std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(), + group.optimized_nodes.end()); + + // sort inputs by it's shape from smallest to largest + std::stable_sort(add_ops.begin(), add_ops.end(), + [](const InputAndShape& lhs, const InputAndShape& rhs) { + return CompareSymbolicallyShapedTensorSizes(lhs.shape, + rhs.shape); + }); + + // If there is an odd number of inputs, last one is the largest, and we want + // to attach it to the root node, to build a well balanced tree. 
+ std::deque<InputAndShape> add_ops_leftover; + if (add_ops.size() % 2 != 0) { + add_ops_leftover.push_back(add_ops.back()); + add_ops.pop_back(); + } + + // At this point it's guaranteed that add_ops have even number of inputs. + do { + const InputAndShape lhs = add_ops.front(); + add_ops.pop_front(); + const InputAndShape rhs = add_ops.front(); + add_ops.pop_front(); + + NodeDef* node; + if (!optimized_nodes.empty()) { + // re-purpose optimized nodes to build a new tree + node = optimized_nodes.front(); + optimized_nodes.pop_front(); + } else { + // or use root node if none optimized nodes left + node = group.root_node; + } + InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node); + + // Pushing updated node to the back of a deque will create a wide and + // short tree, pushing to the front will create a tall tree. We prefer to + // get a wide tree, it minimizes the potential number of temporary tensors + // required to keep in memory, though sometimes we can go up to prevent + // propagating a brodcast from leaves to the root. Example: + // + // inputs: [s, s, s, M] (s - scalar, M - matrix) + // @* - op with broadcast + // + // (only push_back) @* (push_front first op) + // / \ + // @* @ M + // / \ / \ + // @ @* -> @ s + // / \ / \ / \ + // s s s M s s + if (add_ops.size() >= 2 && + CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape, + add_ops.at(1).shape)) { + add_ops.push_front(updated_node); + } else { + add_ops.push_back(updated_node); + } + } while (add_ops.size() > 1); + CHECK_EQ(1, add_ops.size()); + + // attach the largest tensor to the root op + if (!add_ops_leftover.empty()) { + const InputAndShape lhs = add_ops.front(); + add_ops.pop_front(); + const InputAndShape rhs = add_ops_leftover.front(); + InputAndShape updated_node = + UpdateInputs(lhs.input, rhs.input, group.root_node); + add_ops.push_back(updated_node); + } + + return add_ops.front().input; + } + + InputAndShape UpdateInputs(const string& input_0, const string& input_1, + NodeDef* node) { + string old_input_0 = node->input(0); + string old_input_1 = node->input(1); + + // Update inputs only if they changed + if (old_input_0 != input_0 || old_input_1 != input_1) { + node->set_input(0, input_0); + node->set_input(1, input_1); + // Invalidate node properties (shape) + ctx_.graph_properties->ClearOutputProperties(node->name()); + ctx_.graph_properties->ClearInputProperties(node->name()); + // Update the node map + ctx_.node_map->RemoveOutput(NodeName(old_input_0), node->name()); + ctx_.node_map->RemoveOutput(NodeName(old_input_1), node->name()); + ctx_.node_map->AddOutput(NodeName(input_0), node->name()); + ctx_.node_map->AddOutput(NodeName(input_1), node->name()); + // Add updated node to optimization queue + AddToOptimizationQueue(node); + } + + // Do not add updated node to any other group + AddToOptimizedNodes(node); + + TensorShapeProto shape; // shape is not important at this point + return InputAndShape(node->name(), shape); + } +}; + // Removes inverse transpose nodes class RemoveIdentityTranspose : public ArithmeticOptimizerStage { public: @@ -854,7 +1074,7 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { // TODO(rmlarsen): Forward control dependencies on the bypassed // transpose nodes. 
Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - CHECK(IsSupported(node)); + TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); NodeDef* input; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); @@ -943,7 +1163,7 @@ class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage { } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - CHECK(IsSupported(node)); + TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); // Bypass Bitcast whose source type and destination type are equal. if (GetSourceDataType(*node) == GetDestinationDataType(*node)) { @@ -981,7 +1201,8 @@ class RemoveRedundantCastStage : public ArithmeticOptimizerStage { bool IsSupported(const NodeDef* node) const override { return IsCast(*node); } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - CHECK(IsSupported(node)); + TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); + // Bypass Cast whose source type and destination type are equal. if (GetSourceDataType(*node) == GetDestinationDataType(*node)) { *simplified_node_name = node->input(0); @@ -1678,6 +1899,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext); if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes) pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext); + if (options_.minimize_broadcasts && can_use_shapes) + pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext); if (options_.remove_identity_transpose && can_use_shapes) pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext); if (options_.remove_redundant_bitcast) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 39b89dedba..c0fe8839ca 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -59,6 +59,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool enable_try_simplify_and_replace = true; bool combine_add_to_addn = false; bool hoist_common_factor_out_of_aggregation = true; + bool minimize_broadcasts = false; bool remove_identity_transpose = true; bool remove_redundant_bitcast = true; bool remove_redundant_cast = true; @@ -69,10 +70,10 @@ class ArithmeticOptimizer : public GraphOptimizer { static ArithmeticOptimizerOptions Default( RewriterConfig::Toggle opt_level) { ArithmeticOptimizerOptions options; - // TODO(ezhulenev): enable combine_add_to_addn by default after 1.8 - // release cut + // TODO(ezhulenev): enable by default after 1.8 release cut if (opt_level == RewriterConfig::AGGRESSIVE) { options.combine_add_to_addn = true; + options.minimize_broadcasts = true; } return options; } diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index e117341ba3..9677175d2e 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -93,6 +93,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.enable_try_simplify_and_replace = false; options.combine_add_to_addn = false; options.hoist_common_factor_out_of_aggregation = false; + options.minimize_broadcasts = false; options.remove_identity_transpose = false; options.remove_redundant_bitcast = false; options.remove_redundant_cast = false; @@ -113,6 +114,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { 
optimizer->options_.hoist_common_factor_out_of_aggregation = true; } + void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.minimize_broadcasts = true; + } + void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.remove_identity_transpose = true; @@ -1841,5 +1847,160 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { EXPECT_EQ(5, found); } +TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT); + auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT); + auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT); + + auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b); + auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c); + + auto outputs = ops::Identity(s.WithOpName("outputs"), mul2); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyMinimizeBroadcasts(&optimizer); + + OptimizeAndPrune(&optimizer, &item, &output); + + // We expect the following rewrite(s) to occur: + // + // * * + // / \ / \ + // * c --> * b + // / \ / \ + // a b a c + NodeMap node_map(&output); + + const NodeDef* mul1_node = node_map.GetNode("mul1"); + ASSERT_NE(mul1_node, nullptr); + EXPECT_EQ("a", mul1_node->input(0)); + EXPECT_EQ("c", mul1_node->input(1)); + + const NodeDef* mul2_node = node_map.GetNode("mul2"); + ASSERT_NE(mul2_node, nullptr); + EXPECT_EQ("mul1", mul2_node->input(0)); + EXPECT_EQ("b", mul2_node->input(1)); +} + +TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT); + auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT); + auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT); + auto d = ops::Variable(s.WithOpName("d"), {32}, DT_FLOAT); + auto e = ops::Variable(s.WithOpName("e"), {32}, DT_FLOAT); + + auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b); + auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c); + auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d); + auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e); + + auto outputs = ops::Identity(s.WithOpName("outputs"), mul4); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyMinimizeBroadcasts(&optimizer); + + OptimizeAndPrune(&optimizer, &item, &output); + + // We expect the following rewrite(s) to occur: Graph is "flattened" and + // largest shape pushed to the top. 
+ // + // * + // / \ + // * e * + // / \ / \ + // * d * b + // / \ / \ + // * c --> * * + // / \ / \ / \ + // a b a c d e + NodeMap node_map(&output); + + const NodeDef* mul1_node = node_map.GetNode("mul1"); + ASSERT_NE(mul1_node, nullptr); + EXPECT_EQ("a", mul1_node->input(0)); + EXPECT_EQ("c", mul1_node->input(1)); + + const NodeDef* mul2_node = node_map.GetNode("mul2"); + ASSERT_NE(mul2_node, nullptr); + EXPECT_EQ("d", mul2_node->input(0)); + EXPECT_EQ("e", mul2_node->input(1)); + + const NodeDef* mul3_node = node_map.GetNode("mul3"); + ASSERT_NE(mul3_node, nullptr); + EXPECT_EQ("mul1", mul3_node->input(0)); + EXPECT_EQ("mul2", mul3_node->input(1)); + + const NodeDef* mul4_node = node_map.GetNode("mul4"); + ASSERT_NE(mul4_node, nullptr); + EXPECT_EQ("mul3", mul4_node->input(0)); + EXPECT_EQ("b", mul4_node->input(1)); +} + +TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + // [a, b, c] - scalars, [d] - matrix + auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT); + auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT); + auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT); + auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT); + + auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b); + auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d); + auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2); + + auto outputs = ops::Identity(s.WithOpName("outputs"), mul3); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyMinimizeBroadcasts(&optimizer); + + OptimizeAndPrune(&optimizer, &item, &output); + + // We expect the following rewrite(s) to occur: + // + // * + // / \ + // * * D + // / \ / \ + // * * -> * c + // / \ / \ / \ + // a b c D a b + NodeMap node_map(&output); + + const NodeDef* mul1_node = node_map.GetNode("mul1"); + ASSERT_NE(mul1_node, nullptr); + EXPECT_EQ("a", mul1_node->input(0)); + EXPECT_EQ("b", mul1_node->input(1)); + + const NodeDef* mul2_node = node_map.GetNode("mul2"); + ASSERT_NE(mul2_node, nullptr); + EXPECT_EQ("mul1", mul2_node->input(0)); + EXPECT_EQ("c", mul2_node->input(1)); + + const NodeDef* mul3_node = node_map.GetNode("mul3"); + ASSERT_NE(mul3_node, nullptr); + EXPECT_EQ("D", mul3_node->input(0)); + EXPECT_EQ("mul2", mul3_node->input(1)); +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h index 7ed0474861..072f772946 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h @@ -134,6 +134,18 @@ class GraphOptimizerStage { // and remove template parameter. virtual Status TrySimplify(NodeDef* node, Result* result) = 0; + // Return InvalidArgumentError if node is not supported by the optimizer + // stage. + // TODO(ezhulenev): make this check part of non-virtual public API + // (TrySimplify), and make virtual implementation protected. + Status EnsureNodeIsSupported(const NodeDef* node) const { + return IsSupported(node) + ? Status::OK() + : errors::InvalidArgument( + "Node ", node->name(), " is not supported by optimizer ", + optimizer_name_, " and stage ", stage_name_); + } + // Get a name for a new node, created by this stage, based on one or multiple // nodes of an original graph. 
const string OptimizedNodeName(const NodeScopeAndName& node) const { |
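As the `arithmetic_optimizer.h` hunk above shows, the new stage is gated behind `minimize_broadcasts`, which defaults to `false` and is only switched on by `ArithmeticOptimizerOptions::Default` for `RewriterConfig::AGGRESSIVE`. A hedged sketch of how a caller might exercise the optimizer at that level; it assumes the opt-level constructor implied by the `Default(opt_level)` factory in the diff, and is not taken verbatim from this change:

```cpp
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace grappler {

// Sketch: MinimizeBroadcasts only runs when minimize_broadcasts is set,
// which Default() does for AGGRESSIVE (see the header hunk above). With the
// default-constructed optimizer the stage stays disabled, mirroring
// `bool minimize_broadcasts = false;`.
void RunWithBroadcastMinimization(const GrapplerItem& item, GraphDef* output) {
  // Assumed constructor taking an opt_level; adjust to the actual API.
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  TF_CHECK_OK(optimizer.Optimize(/*cluster=*/nullptr, item, output));
}

}  // namespace grappler
}  // namespace tensorflow
```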