diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-06-29 14:05:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-29 14:08:43 -0700 |
commit | 1da7159ee561bacd0b5ec9f30061d790590af846 (patch) | |
tree | 3f828bc57e065e1a4a63d6ed5eb1c9950318d7ae /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | dcaa037571ab0933977f70574f4f78875155ae20 (diff) |
UnaryOpsComposition arithmetic optimizer.
PiperOrigin-RevId: 202703970
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 203 |
1 files changed, 182 insertions, 21 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 72ca3c3fa2..28072c2df3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -263,6 +263,27 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { ctx().nodes_to_preserve->end(); } + // TODO(ezhulenev): move to GraphOptimizerStage? + bool IsDrivenByControlDependency(const NodeDef& node) const { + return std::any_of(node.input().begin(), node.input().end(), + IsControlInput); + } + + // TODO(ezhulenev): move to GraphOptimizerStage? + bool DrivesControlDependency(const NodeDef& node) const { + int position; + for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) { + for (int i = 0; i < output->input_size(); ++i) { + auto input = output->input(i); + string name = ParseNodeName(input, &position); + if (name == node.name() && /*control input*/ position < 0) { + return true; + } + } + } + return false; + } + private: // Extended context required for ArithmeticOptimizer. const ArithmeticOptimizerContext ctx_ext_; @@ -393,27 +414,6 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { is_broadcastable); } - // TODO(ezhulenev): move to GraphOptimizerStage? - bool IsDrivenByControlDependency(const NodeDef& node) const { - return std::any_of(node.input().begin(), node.input().end(), - IsControlInput); - } - - // TODO(ezhulenev): move to GraphOptimizerStage? - bool DrivesControlDependency(const NodeDef& node) const { - int position; - for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) { - for (int i = 0; i < output->input_size(); ++i) { - auto input = output->input(i); - string name = ParseNodeName(input, &position); - if (name == node.name() && /*control input*/ position < 0) { - return true; - } - } - } - return false; - } - string ShapeSignature(const TensorShapeProto& shape) const { string signature = strings::StrCat("rank:", shape.dim_size(), ":dim"); for (int i = 0; i < shape.dim_size(); ++i) @@ -2719,6 +2719,165 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { } }; +// Replace a chain of type&shape preserving unary ops with a +// '_UnaryOpsComposition' node. +// TODO(ezhulenev): It should be a part of remapper optimizer because it doesn't +// have to do much with arithmetic (together with FoldMultiplyIntoConv stage?). +class UnaryOpsComposition : public ArithmeticOptimizerStage { + public: + explicit UnaryOpsComposition(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) { + // WARN: This should be consistent with unary_ops_composition.cc. + // clang-format off + supported_ops_ = {// Ops defined via Eigen scalar ops. + {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Acos", {DT_FLOAT, DT_DOUBLE}}, + {"Acosh", {DT_FLOAT, DT_DOUBLE}}, + {"Asin", {DT_FLOAT, DT_DOUBLE}}, + {"Asinh", {DT_FLOAT, DT_DOUBLE}}, + {"Atan", {DT_FLOAT, DT_DOUBLE}}, + {"Atanh", {DT_FLOAT, DT_DOUBLE}}, + {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Cosh", {DT_FLOAT, DT_DOUBLE}}, + {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Rint", {DT_FLOAT, DT_DOUBLE}}, + {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Sinh", {DT_FLOAT, DT_DOUBLE}}, + {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Tan", {DT_FLOAT, DT_DOUBLE}}, + {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + // Additional ops that are not part of the Eigen. + {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}}; + // clang-format on + } + ~UnaryOpsComposition() override = default; + + bool IsSupported(const NodeDef* node) const override { + return CanOptimize(*node); + } + + Status TrySimplify(NodeDef* root, string* simplified_node_name) override { + DataType dtype = root->attr().at("T").type(); + + // Keep a trace of all supported input nodes that can be fused together. + std::vector<string> op_nodes = {root->name()}; + std::vector<string> op_names = {root->op()}; + + // Check if we should follow input(0) while building an op composition. + const auto predicate_fn = [&](const NodeDef& input) { + if (input.name() == root->name()) return true; + + bool follow_input_node = + dtype == GetDataTypeFromAttr(input, "T") && + NumNonControlDataOutputs(input, *ctx().node_map) == 1 && + CanOptimize(input); + + if (follow_input_node) { + op_nodes.push_back(input.name()); + op_names.push_back(input.op()); + } + + return follow_input_node; + }; + + NodeDef* last_op = GetTailOfChain( + *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn); + + // We were not able to find a chain that can be replaced. + if (op_names.size() == 1) return Status::OK(); + + // Do not add fused nodes to any other chain. + std::for_each(op_nodes.begin(), op_nodes.end(), + [this](const string& name) { AddToFusedNodes(name); }); + + // Reverse the trace to get correct composition computation order. + std::reverse(op_names.begin(), op_names.end()); + + VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=[" + << str_util::Join(op_names, ", ") << "]"; + + NodeDef* composition_node = ctx().optimized_graph->add_node(); + composition_node->set_name( + strings::StrCat(root->name(), "/unary_ops_composition")); + composition_node->set_op("_UnaryOpsComposition"); + composition_node->add_input(last_op->input(0)); + composition_node->set_device(root->device()); + + auto attr = composition_node->mutable_attr(); + SetAttrValue(dtype, &(*attr)["T"]); + SetAttrValue(op_names, &(*attr)["op_names"]); + + ctx().node_map->AddNode(composition_node->name(), composition_node); + ctx().node_map->AddOutput(NodeName(last_op->input(0)), + composition_node->name()); + + *simplified_node_name = composition_node->name(); + + return Status::OK(); + } + + private: + bool CanOptimize(const NodeDef& node) const { + DataType dtype = GetDataTypeFromAttr(node, "T"); + if (!IsSupported(node.op(), dtype)) { + return false; + } + if (IsInPreserveSet(node)) { + return false; + } + if (!NodeIsOnCpu(node)) { + return false; + } + if (NodeIsAlreadyFused(node)) { + return false; + } + return !(IsDrivenByControlDependency(node) || + DrivesControlDependency(node)); + } + + // UnaryOpsComposition is defined only for CPU. + bool NodeIsOnCpu(const NodeDef& node) const { + using str_util::StartsWith; + + string task; + string device; + + return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) && + StartsWith(device, DEVICE_CPU); + } + + bool NodeIsAlreadyFused(const NodeDef& node) const { + return fused_nodes_.count(node.name()) > 0; + } + + void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); } + + // Check if an op is supported by the _UnaryOpsComposition for the given type. + bool IsSupported(const string& op_name, DataType dtype) const { + const auto it = supported_ops_.find(op_name); + return it != supported_ops_.end() && it->second.count(dtype) > 0; + } + + std::unordered_map<string, std::set<DataType>> supported_ops_; + std::unordered_set<string> fused_nodes_; +}; + } // namespace class UniqueNodes { @@ -3001,6 +3160,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext); if (options_.convert_expm1) pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext); + if (options_.unary_ops_composition) + pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); |