author     Eugene Zhulenev <ezhulenev@google.com>            2018-06-29 14:05:59 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-29 14:08:43 -0700
commit     1da7159ee561bacd0b5ec9f30061d790590af846 (patch)
tree       3f828bc57e065e1a4a63d6ed5eb1c9950318d7ae
parent     dcaa037571ab0933977f70574f4f78875155ae20 (diff)
UnaryOpsComposition arithmetic optimizer.
PiperOrigin-RevId: 202703970
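In short, the new optimizer stage finds chains of element-wise unary ops (e.g. `Relu(Log(Sqrt(x)))`) and replaces them with a single `_UnaryOpsComposition` node whose `op_names` attribute lists the original ops in execution order. A minimal sketch of the kind of graph this targets, using the same TF C++ client API as the test in this commit (node names mirror the test; whether the fusion actually fires at runtime depends on how the session is configured to run grappler and on CPU placement):

```cpp
#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Before optimization: three separate unary kernels, each paying its own
  // dispatch cost and materializing an intermediate tensor.
  auto x = tensorflow::ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto sqrt = tensorflow::ops::Sqrt(s.WithOpName("sqrt"), x);
  auto log = tensorflow::ops::Log(s.WithOpName("log"), sqrt);
  auto relu = tensorflow::ops::Relu(s.WithOpName("relu"), log);

  // After this stage runs, the executed graph contains a single node:
  //   name:  relu/unary_ops_composition
  //   op:    _UnaryOpsComposition
  //   input: x
  //   attrs: T = DT_FLOAT, op_names = ["Sqrt", "Log", "Relu"]
  std::vector<tensorflow::Tensor> out;
  tensorflow::ClientSession session(s);
  TF_CHECK_OK(session.Run({relu}, &out));
  LOG(INFO) << out[0].DebugString();
  return 0;
}
```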
-rw-r--r--  tensorflow/core/BUILD                                             |   1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc       | 203
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h        |   1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc  |  66
-rw-r--r--  tensorflow/core/kernels/BUILD                                     |   8

5 files changed, 258 insertions, 21 deletions
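The core of the change is the chain walk in `TrySimplify` below: starting from a root node, it follows `input(0)` upstream while a predicate accepts each node (same dtype, exactly one non-control consumer, a supported op on CPU), collects the op names, and emits one fused node. A simplified standalone sketch of that walk, using a toy node type rather than TF's `NodeDef`/`NodeMap` (all names here are illustrative, not the TF implementation):

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in for NodeDef: each node knows its op and its first input.
struct ToyNode {
  std::string op;
  std::string input0;  // empty when the node has no data input
};

int main() {
  // x -> Sqrt -> Log -> Relu, keyed by node name.
  std::unordered_map<std::string, ToyNode> graph = {
      {"x", {"Const", ""}},
      {"sqrt", {"Sqrt", "x"}},
      {"log", {"Log", "sqrt"}},
      {"relu", {"Relu", "log"}}};

  // The predicate stands in for the CanOptimize + same-dtype +
  // single-consumer checks; here it simply accepts any non-Const node.
  auto accept = [](const ToyNode& n) { return n.op != "Const"; };

  // Walk input(0) from the root while the predicate accepts, recording ops.
  std::vector<std::string> op_names;
  std::string cur = "relu";
  while (accept(graph.at(cur))) {
    op_names.push_back(graph.at(cur).op);
    cur = graph.at(cur).input0;
  }

  // TrySimplify collects names root-first and then reverses them, so the
  // fused node lists ops in execution order: Sqrt, Log, Relu.
  std::reverse(op_names.begin(), op_names.end());
  for (const std::string& op : op_names) std::cout << op << "\n";
  return 0;
}
```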
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 44f1d8ecf5..61bf566779 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1262,6 +1262,7 @@ cc_library(
         "//tensorflow/core/kernels:fake_quant_ops",
         "//tensorflow/core/kernels:function_ops",
         "//tensorflow/core/kernels:functional_ops",
+        "//tensorflow/core/kernels:grappler",
         "//tensorflow/core/kernels:histogram_op",
         "//tensorflow/core/kernels:image",
         "//tensorflow/core/kernels:io",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 72ca3c3fa2..28072c2df3 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -263,6 +263,27 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
            ctx().nodes_to_preserve->end();
   }
 
+  // TODO(ezhulenev): move to GraphOptimizerStage?
+  bool IsDrivenByControlDependency(const NodeDef& node) const {
+    return std::any_of(node.input().begin(), node.input().end(),
+                       IsControlInput);
+  }
+
+  // TODO(ezhulenev): move to GraphOptimizerStage?
+  bool DrivesControlDependency(const NodeDef& node) const {
+    int position;
+    for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
+      for (int i = 0; i < output->input_size(); ++i) {
+        auto input = output->input(i);
+        string name = ParseNodeName(input, &position);
+        if (name == node.name() && /*control input*/ position < 0) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
  private:
   // Extended context required for ArithmeticOptimizer.
   const ArithmeticOptimizerContext ctx_ext_;
@@ -393,27 +414,6 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
                               is_broadcastable);
   }
 
-  // TODO(ezhulenev): move to GraphOptimizerStage?
-  bool IsDrivenByControlDependency(const NodeDef& node) const {
-    return std::any_of(node.input().begin(), node.input().end(),
-                       IsControlInput);
-  }
-
-  // TODO(ezhulenev): move to GraphOptimizerStage?
-  bool DrivesControlDependency(const NodeDef& node) const {
-    int position;
-    for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
-      for (int i = 0; i < output->input_size(); ++i) {
-        auto input = output->input(i);
-        string name = ParseNodeName(input, &position);
-        if (name == node.name() && /*control input*/ position < 0) {
-          return true;
-        }
-      }
-    }
-    return false;
-  }
-
   string ShapeSignature(const TensorShapeProto& shape) const {
     string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
     for (int i = 0; i < shape.dim_size(); ++i)
@@ -2719,6 +2719,165 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
   }
 };
 
+// Replace a chain of type- and shape-preserving unary ops with a single
+// '_UnaryOpsComposition' node.
+// TODO(ezhulenev): This should be part of the remapper optimizer because it
+// doesn't have much to do with arithmetic (together with the
+// FoldMultiplyIntoConv stage?).
+class UnaryOpsComposition : public ArithmeticOptimizerStage {
+ public:
+  explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
+                               const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
+    // WARN: This should be consistent with unary_ops_composition.cc.
+    // clang-format off
+    supported_ops_ = {// Ops defined via Eigen scalar ops.
+ {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Acos", {DT_FLOAT, DT_DOUBLE}}, + {"Acosh", {DT_FLOAT, DT_DOUBLE}}, + {"Asin", {DT_FLOAT, DT_DOUBLE}}, + {"Asinh", {DT_FLOAT, DT_DOUBLE}}, + {"Atan", {DT_FLOAT, DT_DOUBLE}}, + {"Atanh", {DT_FLOAT, DT_DOUBLE}}, + {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Cosh", {DT_FLOAT, DT_DOUBLE}}, + {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Rint", {DT_FLOAT, DT_DOUBLE}}, + {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Sinh", {DT_FLOAT, DT_DOUBLE}}, + {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Tan", {DT_FLOAT, DT_DOUBLE}}, + {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + // Additional ops that are not part of the Eigen. + {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}}; + // clang-format on + } + ~UnaryOpsComposition() override = default; + + bool IsSupported(const NodeDef* node) const override { + return CanOptimize(*node); + } + + Status TrySimplify(NodeDef* root, string* simplified_node_name) override { + DataType dtype = root->attr().at("T").type(); + + // Keep a trace of all supported input nodes that can be fused together. + std::vector<string> op_nodes = {root->name()}; + std::vector<string> op_names = {root->op()}; + + // Check if we should follow input(0) while building an op composition. + const auto predicate_fn = [&](const NodeDef& input) { + if (input.name() == root->name()) return true; + + bool follow_input_node = + dtype == GetDataTypeFromAttr(input, "T") && + NumNonControlDataOutputs(input, *ctx().node_map) == 1 && + CanOptimize(input); + + if (follow_input_node) { + op_nodes.push_back(input.name()); + op_names.push_back(input.op()); + } + + return follow_input_node; + }; + + NodeDef* last_op = GetTailOfChain( + *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn); + + // We were not able to find a chain that can be replaced. + if (op_names.size() == 1) return Status::OK(); + + // Do not add fused nodes to any other chain. + std::for_each(op_nodes.begin(), op_nodes.end(), + [this](const string& name) { AddToFusedNodes(name); }); + + // Reverse the trace to get correct composition computation order. 
+    std::reverse(op_names.begin(), op_names.end());
+
+    VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
+            << str_util::Join(op_names, ", ") << "]";
+
+    NodeDef* composition_node = ctx().optimized_graph->add_node();
+    composition_node->set_name(
+        strings::StrCat(root->name(), "/unary_ops_composition"));
+    composition_node->set_op("_UnaryOpsComposition");
+    composition_node->add_input(last_op->input(0));
+    composition_node->set_device(root->device());
+
+    auto attr = composition_node->mutable_attr();
+    SetAttrValue(dtype, &(*attr)["T"]);
+    SetAttrValue(op_names, &(*attr)["op_names"]);
+
+    ctx().node_map->AddNode(composition_node->name(), composition_node);
+    ctx().node_map->AddOutput(NodeName(last_op->input(0)),
+                              composition_node->name());
+
+    *simplified_node_name = composition_node->name();
+
+    return Status::OK();
+  }
+
+ private:
+  bool CanOptimize(const NodeDef& node) const {
+    DataType dtype = GetDataTypeFromAttr(node, "T");
+    if (!IsSupported(node.op(), dtype)) {
+      return false;
+    }
+    if (IsInPreserveSet(node)) {
+      return false;
+    }
+    if (!NodeIsOnCpu(node)) {
+      return false;
+    }
+    if (NodeIsAlreadyFused(node)) {
+      return false;
+    }
+    return !(IsDrivenByControlDependency(node) ||
+             DrivesControlDependency(node));
+  }
+
+  // UnaryOpsComposition is defined only for CPU.
+  bool NodeIsOnCpu(const NodeDef& node) const {
+    using str_util::StartsWith;
+
+    string task;
+    string device;
+
+    return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
+           StartsWith(device, DEVICE_CPU);
+  }
+
+  bool NodeIsAlreadyFused(const NodeDef& node) const {
+    return fused_nodes_.count(node.name()) > 0;
+  }
+
+  void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }
+
+  // Check if an op is supported by _UnaryOpsComposition for the given type.
+  bool IsSupported(const string& op_name, DataType dtype) const {
+    const auto it = supported_ops_.find(op_name);
+    return it != supported_ops_.end() && it->second.count(dtype) > 0;
+  }
+
+  std::unordered_map<string, std::set<DataType>> supported_ops_;
+  std::unordered_set<string> fused_nodes_;
+};
+
 }  // namespace
 
 class UniqueNodes {
@@ -3001,6 +3160,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
     pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
   if (options_.convert_expm1)
     pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
+  if (options_.unary_ops_composition)
+    pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
 
   VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
           << str_util::Join(pipeline.StageNames(), ", ");
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 45a5f65b81..551c3652bf 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -78,6 +78,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool convert_pow = true;
     bool convert_log1p = true;
     bool convert_expm1 = true;
+    bool unary_ops_composition = true;
 
     // Choose which arithmetic optimizer stages will be enabled for a given
     // optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 3f6c04a5b5..54fdc01adb 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -141,6 +141,9 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     options.dedup_computations = false;
     options.combine_add_to_addn = false;
     options.convert_sqrt_div_to_rsqrt_mul = false;
+    options.convert_pow = false;
+    options.convert_log1p = false;
+    options.optimize_max_or_min_of_monotonic = false;
     options.fold_conjugate_into_transpose = false;
     options.fold_multiply_into_conv = false;
     options.fold_transpose_into_matmul = false;
@@ -158,6 +161,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     options.reorder_cast_and_transpose = false;
     options.replace_mul_with_square = false;
     options.simplify_aggregation = false;
+    options.unary_ops_composition = false;
     optimizer->options_ = options;
   }
@@ -279,6 +283,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     DisableAllStages(optimizer);
     optimizer->options_.convert_expm1 = true;
   }
+
+  void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.unary_ops_composition = true;
+  }
 };
 
 TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -3201,5 +3210,62 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
   EXPECT_EQ(2, required_node_count);
 }
 
+TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+  Output log = ops::Log(s.WithOpName("log"), sqrt);
+  Output relu = ops::Relu(s.WithOpName("relu"), log);
+  Output final_out = ops::Identity(s.WithOpName("final_out"), relu);
+
+  GrapplerItem item;
+  item.fetch = {"final_out"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  // Place all nodes on CPU.
+  for (int i = 0; i < item.graph.node_size(); ++i) {
+    item.graph.mutable_node(i)->set_device("/device:CPU:0");
+  }
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(1, tensors_expected.size());
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyUnaryOpsComposition(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &output);
+
+  EXPECT_EQ(3, output.node_size());
+
+  // Check that Sqrt/Log/Relu were replaced with a single op.
+  int required_node_count = 0;
+  for (int i = 0; i < output.node_size(); ++i) {
+    const NodeDef& node = output.node(i);
+    if (node.name() == "final_out") {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("relu/unary_ops_composition", node.input(0));
+      ++required_node_count;
+    } else if (node.name() == "relu/unary_ops_composition") {
+      EXPECT_EQ("_UnaryOpsComposition", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+
+      auto op_names = node.attr().at("op_names").list().s();
+      EXPECT_EQ(3, op_names.size());
+      EXPECT_EQ("Sqrt", op_names[0]);
+      EXPECT_EQ("Log", op_names[1]);
+      EXPECT_EQ("Relu", op_names[2]);
+      ++required_node_count;
+    }
+  }
+  EXPECT_EQ(2, required_node_count);
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index dee34d434b..466f601471 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3385,6 +3385,14 @@ cc_library(
     ],
 )
 
+# Kernels for the nodes intended to be added to the graph by the Grappler
+# optimizers.
+cc_library(
+    name = "grappler",
+    deps = [
+        ":unary_ops_composition",
+    ],
+)
+
 NN_DEPS = [
     ":bounds_check",
     ":conv_2d",
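The kernel behind the fused node lives in unary_ops_composition.cc, which this diff depends on but does not show. Conceptually it applies the listed functors in order to each element; the sketch below models that behavior with a hypothetical `ApplyComposition` helper (the real kernel fuses the functors so that no intermediate tensors are materialized, which is the point of the optimization):

```cpp
#include <cmath>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Conceptual model of what _UnaryOpsComposition computes: apply each op in
// `op_names`, in order, to every element. ApplyComposition is illustrative,
// not the actual kernel implementation.
float ApplyComposition(const std::vector<std::string>& op_names, float x) {
  static const std::map<std::string, std::function<float(float)>> kOps = {
      {"Sqrt", [](float v) { return std::sqrt(v); }},
      {"Log", [](float v) { return std::log(v); }},
      {"Relu", [](float v) { return v > 0.0f ? v : 0.0f; }},
  };
  for (const std::string& name : op_names) x = kOps.at(name)(x);
  return x;
}

int main() {
  // Matches the test above: Relu(Log(Sqrt(2.0))) == Log(Sqrt(2.0)) ~= 0.3466.
  std::cout << ApplyComposition({"Sqrt", "Log", "Relu"}, 2.0f) << "\n";
  return 0;
}
```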