about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc 219
1 files changed, 190 insertions, 29 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d8c5d09c4d..3ab2211694 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -227,6 +227,27 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
ctx().nodes_to_preserve->end();
}
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool IsDrivenByControlDependency(const NodeDef& node) const {
+ return std::any_of(node.input().begin(), node.input().end(),
+ IsControlInput);
+ }
+
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool DrivesControlDependency(const NodeDef& node) const {
+ int position;
+ for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
+ for (int i = 0; i < output->input_size(); ++i) {
+ auto input = output->input(i);
+ string name = ParseNodeName(input, &position);
+ if (name == node.name() && /*control input*/ position < 0) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
private:
// Extended context required for ArithmeticOptimizer.
const ArithmeticOptimizerContext ctx_ext_;
@@ -357,27 +378,6 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
is_broadcastable);
}
- // TODO(ezhulenev): move to GraphOptimizerStage?
- bool IsDrivenByControlDependency(const NodeDef& node) const {
- return std::any_of(node.input().begin(), node.input().end(),
- IsControlInput);
- }
-
- // TODO(ezhulenev): move to GraphOptimizerStage?
- bool DrivesControlDependency(const NodeDef& node) const {
- int position;
- for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
- for (int i = 0; i < output->input_size(); ++i) {
- auto input = output->input(i);
- string name = ParseNodeName(input, &position);
- if (name == node.name() && /*control input*/ position < 0) {
- return true;
- }
- }
- }
- return false;
- }
-
string ShapeSignature(const TensorShapeProto& shape) const {
string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
for (int i = 0; i < shape.dim_size(); ++i)
@@ -2648,6 +2648,172 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
}
};
+// Replace a chain of type- and shape-preserving unary ops with a single
+// '_UnaryOpsComposition' node.
+// TODO(ezhulenev): This should be part of the remapper optimizer, since it has
+// little to do with arithmetic (together with FoldMultiplyIntoConv stage?).
+class UnaryOpsComposition : public ArithmeticOptimizerStage {
+ public:
+ explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
+ // WARNING: This list must be kept consistent with unary_ops_composition.cc.
+ // clang-format off
+ supported_ops_ = {// Ops defined via Eigen scalar ops.
+ {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Acos", {DT_FLOAT, DT_DOUBLE}},
+ {"Acosh", {DT_FLOAT, DT_DOUBLE}},
+ {"Asin", {DT_FLOAT, DT_DOUBLE}},
+ {"Asinh", {DT_FLOAT, DT_DOUBLE}},
+ {"Atan", {DT_FLOAT, DT_DOUBLE}},
+ {"Atanh", {DT_FLOAT, DT_DOUBLE}},
+ {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Cosh", {DT_FLOAT, DT_DOUBLE}},
+ {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Rint", {DT_FLOAT, DT_DOUBLE}},
+ {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sinh", {DT_FLOAT, DT_DOUBLE}},
+ {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Tan", {DT_FLOAT, DT_DOUBLE}},
+ {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ // Additional ops that are not part of Eigen.
+ {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
+ // clang-format on
+ }
+ ~UnaryOpsComposition() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return CanOptimize(*node) &&
+ // Check that this node is not already the root of a fused chain. If
+ // graph optimization runs twice without pruning in between, fused_nodes_
+ // will not have this information, so look up the fused node by name.
+ !ctx().node_map->NodeExists(OptimizedNodeName(*node));
+ }
+
+ Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
+ DataType dtype = root->attr().at("T").type();
+
+ // Keep a trace of all supported input nodes that can be fused together.
+ std::vector<string> op_nodes = {root->name()};
+ std::vector<string> op_names = {root->op()};
+
+ // Check if we should follow input(0) while building an op composition.
+ const auto predicate_fn = [&](const NodeDef& input) {
+ if (input.name() == root->name()) return true;
+
+ bool follow_input_node =
+ dtype == GetDataTypeFromAttr(input, "T") &&
+ NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
+ CanOptimize(input);
+
+ if (follow_input_node) {
+ op_nodes.push_back(input.name());
+ op_names.push_back(input.op());
+ }
+
+ return follow_input_node;
+ };
+
+ NodeDef* last_op = GetTailOfChain(
+ *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);
+
+ // We were not able to find a chain that can be replaced.
+ if (op_names.size() == 1) return Status::OK();
+
+ // Do not add fused nodes to any other chain.
+ std::for_each(op_nodes.begin(), op_nodes.end(),
+ [this](const string& name) { AddToFusedNodes(name); });
+
+ // Reverse the trace to get correct composition computation order.
+ std::reverse(op_names.begin(), op_names.end());
+
+ VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
+ << str_util::Join(op_names, ", ") << "]";
+
+ NodeDef* composition_node = ctx().optimized_graph->add_node();
+ composition_node->set_name(OptimizedNodeName(*root));
+ composition_node->set_op("_UnaryOpsComposition");
+ composition_node->add_input(last_op->input(0));
+ composition_node->set_device(root->device());
+
+ auto attr = composition_node->mutable_attr();
+ SetAttrValue(dtype, &(*attr)["T"]);
+ SetAttrValue(op_names, &(*attr)["op_names"]);
+
+ ctx().node_map->AddNode(composition_node->name(), composition_node);
+ ctx().node_map->AddOutput(NodeName(last_op->input(0)),
+ composition_node->name());
+
+ *simplified_node_name = composition_node->name();
+
+ return Status::OK();
+ }
+
+ private:
+ bool CanOptimize(const NodeDef& node) const {
+ DataType dtype = GetDataTypeFromAttr(node, "T");
+ if (!IsSupported(node.op(), dtype)) {
+ return false;
+ }
+ if (IsInPreserveSet(node)) {
+ return false;
+ }
+ if (!NodeIsOnCpu(node)) {
+ return false;
+ }
+ if (NodeIsAlreadyFused(node)) {
+ return false;
+ }
+ return !(IsDrivenByControlDependency(node) ||
+ DrivesControlDependency(node));
+ }
+
+ // UnaryOpsComposition is defined only for CPU.
+ bool NodeIsOnCpu(const NodeDef& node) const {
+ using str_util::StartsWith;
+
+ string task;
+ string device;
+
+ return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
+ StartsWith(device, DEVICE_CPU);
+ }
+
+ bool NodeIsAlreadyFused(const NodeDef& node) const {
+ return fused_nodes_.count(node.name()) > 0;
+ }
+
+ string OptimizedNodeName(const NodeDef& node) const {
+ return strings::StrCat(node.name(), "/unary_ops_composition");
+ }
+
+ void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }
+
+ // Check if an op is supported by the _UnaryOpsComposition for the given type.
+ bool IsSupported(const string& op_name, DataType dtype) const {
+ const auto it = supported_ops_.find(op_name);
+ return it != supported_ops_.end() && it->second.count(dtype) > 0;
+ }
+
+ std::unordered_map<string, std::set<DataType>> supported_ops_;
+ std::unordered_set<string> fused_nodes_;
+};
+
} // namespace
class UniqueNodes {
@@ -2841,14 +3007,7 @@ void ArithmeticOptimizer::DedupComputations() {
// Delete duplicates
if (fetch_nodes_known_ && !duplicates.empty()) {
- int last = optimized_graph_->node_size() - 1;
- for (auto it = duplicates.rbegin(); it != duplicates.rend(); ++it) {
- int index = *it;
- optimized_graph_->mutable_node()->SwapElements(index, last);
- last--;
- }
- optimized_graph_->mutable_node()->DeleteSubrange(last + 1,
- duplicates.size());
+ EraseNodesFromGraph(duplicates, optimized_graph_);
// Rebuild the NodeMap which was invalidated by the node swapping above.
node_map_.reset(new NodeMap(optimized_graph_));
}
@@ -2928,6 +3087,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
if (options_.optimize_max_or_min_of_monotonic)
pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
+ if (options_.unary_ops_composition)
+ pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");