diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 1312 |
1 files changed, 947 insertions, 365 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index ca3f84a81d..97862d1ed0 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -101,38 +101,6 @@ bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) { return false; } -template <typename T> -bool IsInnerMatrixTranspose(const std::vector<T>& perm) { - const T n = perm.size(); - if (n < 2) { - return false; - } - for (T i = 0; i < n - 2; ++i) { - if (perm[i] != i) { - return false; - } - } - return perm[n - 1] == n - 2 && perm[n - 2] == n - 1; -} - -bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node, - const NodeMap* node_map) { - if (transpose_node.op() != "Transpose" && - transpose_node.op() != "ConjugateTranspose") { - return false; - } - const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1)); - std::vector<int> perm32; - if (ValuesFromConstNode(*perm_node, &perm32)) { - return IsInnerMatrixTranspose(perm32); - } - std::vector<int64> perm64; - if (ValuesFromConstNode(*perm_node, &perm64)) { - return IsInnerMatrixTranspose(perm64); - } - return false; -} - bool MaybeAddControlInput(const string& new_input, NodeDef* node, GraphDef* graph, NodeMap* node_map) { bool already_exists = false; @@ -155,12 +123,6 @@ void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) { (*node->mutable_attr())[attr_name].set_type(dtype); } -void FlipBooleanAttr(const string& attr_name, NodeDef* node) { - const bool old_value = - !node->attr().count(attr_name) ? 
false : node->attr().at(attr_name).b(); - (*node->mutable_attr())[attr_name].set_b(!old_value); -} - string SourceDataTypeAttrName(const NodeDef& node) { if (node.op() == "Bitcast") { return "T"; @@ -265,6 +227,27 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { ctx().nodes_to_preserve->end(); } + // TODO(ezhulenev): move to GraphOptimizerStage? + bool IsDrivenByControlDependency(const NodeDef& node) const { + return std::any_of(node.input().begin(), node.input().end(), + IsControlInput); + } + + // TODO(ezhulenev): move to GraphOptimizerStage? + bool DrivesControlDependency(const NodeDef& node) const { + int position; + for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) { + for (int i = 0; i < output->input_size(); ++i) { + auto input = output->input(i); + string name = ParseNodeName(input, &position); + if (name == node.name() && /*control input*/ position < 0) { + return true; + } + } + } + return false; + } + private: // Extended context required for ArithmeticOptimizer. const ArithmeticOptimizerContext ctx_ext_; @@ -395,27 +378,6 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { is_broadcastable); } - // TODO(ezhulenev): move to GraphOptimizerStage? - bool IsDrivenByControlDependency(const NodeDef& node) const { - return std::any_of(node.input().begin(), node.input().end(), - IsControlInput); - } - - // TODO(ezhulenev): move to GraphOptimizerStage? 
- bool DrivesControlDependency(const NodeDef& node) const { - int position; - for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) { - for (int i = 0; i < output->input_size(); ++i) { - auto input = output->input(i); - string name = ParseNodeName(input, &position); - if (name == node.name() && /*control input*/ position < 0) { - return true; - } - } - } - return false; - } - string ShapeSignature(const TensorShapeProto& shape) const { string signature = strings::StrCat("rank:", shape.dim_size(), ":dim"); for (int i = 0; i < shape.dim_size(); ++i) @@ -1122,8 +1084,11 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); NodeDef* tail = node; - tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, - *ctx().nodes_to_preserve); + // TODO(rmlarsen): Enable after debugging breakage in Bayesflow. + if (ctx().opt_level == RewriterConfig::AGGRESSIVE) { + tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, + *ctx().nodes_to_preserve); + } NodeDef* first_transpose; TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose)); @@ -1757,19 +1722,15 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage { ~RemoveIdempotentStage() override = default; bool IsSupported(const NodeDef* node) const override { - return IsIdempotent(*node) && !IsInPreserveSet(*node); + return node->input_size() == 1 && IsIdempotent(*node) && + !IsInPreserveSet(*node); } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { NodeDef* input; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); - auto root_scope_and_name = ParseNodeScopeAndName(node->name()); - const string new_name = OptimizedNodeName(root_scope_and_name); - if (input->op() == node->op() && input->device() == node->device() && - IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) { - NodeDef* new_input_node = 
AddCopyNode(new_name, input); - ForwardControlDependencies(new_input_node, {node}); - *simplified_node_name = new_input_node->name(); + if (input->op() == node->op() && input->device() == node->device()) { + *simplified_node_name = node->input(0); } return Status::OK(); } @@ -1958,6 +1919,901 @@ class ReorderCastAndTranspose : public ArithmeticOptimizerStage { bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); } }; +// Fold a multiply of a scalar into the following convolution. This folding +// can jump across nodes that merely reorders data (such as reshape and +// transpose). For example, we can optimize +// +// +// Conv2D Conv2D +// / \ / \ +// Transpose weights* -> Transpose Mul +// | | / \ +// Mul | weights scale +// / \ | +// input scale** input +// +// *) weights must be a const +// **) scale must be a const scalar +// +// When `weights` and `scale` are constant, `Mul` in the optimized graph can be +// constant-folded, also weights tend to be smaller than the activations. +// +// TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and +// Conv?DBackpropInput. +class FoldMultiplyIntoConv : public ArithmeticOptimizerStage { + public: + explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {} + ~FoldMultiplyIntoConv() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsConv2D(*node) || IsConv3D(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { +#define TF_RETURN_IF_TRUE(...) \ + if ((__VA_ARGS__)) return Status::OK() + + NodeDef* conv = node; + + NodeDef* weights; + TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights)); + + // Fold the multiply to conv only when the weights are constant, so the + // multiply can be constant-folded. 
+ // + // TODO(jingyue): When the weights aren't constant, this should also help + // performance a bit and memory usage a lot, since the weights tend to be + // smaller than the activations. + TF_RETURN_IF_TRUE(!IsConstant(*weights)); + + // Verify that this node was not already optimized. + const string scaled_weights_node_name = + OptimizedNodeName(ParseNodeScopeAndName(weights->name()), + strings::StrCat("scaled", "_", conv->name())); + + TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name)); + + // Find the tail of value preserving chain entering the Conv node. + NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map, + *ctx().nodes_to_preserve); + + NodeDef* source; + TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source)); + + // Check that value preserving chain is the only consumer of the Mul output. + TF_RETURN_IF_TRUE(!IsMul(*source)); + TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1); + + const NodeDef* mul = source; + + // TODO(jingyue): handle the case where `scale` is 0-th operand. + NodeDef* scale; // scalar multiplier for the input tensor + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(mul->input(1), &scale)); + TF_RETURN_IF_ERROR(GetInputNode(mul->input(0), &input)); + + // Check that 'scale * weight' can be const folded. + TF_RETURN_IF_TRUE(!IsConstant(*scale)); + TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() != + weights->attr().at("dtype").type()); + + // Check that `scale` is a scalar. + const TensorProto& scale_tensor = scale->attr().at("value").tensor(); + bool scale_is_a_scalar = scale_tensor.has_tensor_shape() && + scale_tensor.tensor_shape().dim_size() == 0; + TF_RETURN_IF_TRUE(!scale_is_a_scalar); + + // At this point all preconditions are met, and we safely do the rewrite. + VLOG(3) << "Fold multiply into conv: conv=" << conv->name() + << " mul=" << mul->name() << " weights=" << weights->name(); + + // Create new node `scaled_weights`. 
+ NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name); + scaled_weights->set_op("Mul"); + scaled_weights->set_device(weights->device()); + (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype"); + AddToOptimizationQueue(scaled_weights); + + // Link in its inputs. + scaled_weights->add_input(conv->input(1)); + ctx().node_map->AddOutput(weights->name(), scaled_weights->name()); + scaled_weights->add_input(mul->input(1)); + ctx().node_map->AddOutput(scale->name(), scaled_weights->name()); + ForwardControlDependencies(scaled_weights, {source}); + + // Update `conv`'s weights to `scaled_weights`. + conv->set_input(1, scaled_weights->name()); + ctx().node_map->UpdateInput(conv->name(), weights->name(), + scaled_weights->name()); + AddToOptimizationQueue(conv); + + // Update `tail` node to bypass `mul` because it's folded to the weights. + tail->set_input(0, mul->input(0)); + ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name()); + AddToOptimizationQueue(tail); + *simplified_node_name = conv->name(); + + return Status::OK(); +#undef TF_RETURN_IF_TRUE + } +}; + +// Fold Transpose into matrix multiplication. 
+class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { + public: + explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {} + ~FoldTransposeIntoMatMul() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsMatMul(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name()); + const string optimized_node_name = OptimizedNodeName(matmul); + if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK(); + + NodeDef* a; + NodeDef* b; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a)); + TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b)); + + bool is_complex = false; + if (node->op() != "SparseMatMul") { + const DataType type = GetDataTypeFromAttr(*node, "T"); + is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128); + } + + const std::set<string> foldable_transpose_ops = + !is_complex ? std::set<string>{"ConjugateTranspose", "Transpose"} + : (node->op() == "BatchMatMul" + ? std::set<string>{"ConjugateTranspose"} + : std::set<string>{"Transpose"}); + + const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 && + IsInnerMatrixTransposeNode(*a, ctx().node_map); + const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 && + IsInnerMatrixTransposeNode(*b, ctx().node_map); + if (!a_is_foldable && !b_is_foldable) return Status::OK(); + + NodeDef* new_op = AddCopyNode(optimized_node_name, node); + + if (a_is_foldable) { + const string attr_a = + node->op() == "BatchMatMul" ? "adj_x" : "transpose_a"; + FlipBooleanAttr(attr_a, new_op); + new_op->set_input(0, a->input(0)); + ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0)); + } + + if (b_is_foldable) { + const string attr_b = + node->op() == "BatchMatMul" ? 
"adj_y" : "transpose_b"; + FlipBooleanAttr(attr_b, new_op); + new_op->set_input(1, b->input(0)); + ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0)); + } + + std::vector<const NodeDef*> deps_to_forward = {node}; + if (a_is_foldable) deps_to_forward.push_back(a); + if (b_is_foldable) deps_to_forward.push_back(b); + ForwardControlDependencies(new_op, deps_to_forward); + + return Status::OK(); + } + + private: + void FlipBooleanAttr(const string& attr_name, NodeDef* node) { + const bool old_value = + !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b(); + (*node->mutable_attr())[attr_name].set_b(!old_value); + } + + template <typename T> + bool IsInnerMatrixTranspose(const std::vector<T>& perm) { + const T n = perm.size(); + if (n < 2) { + return false; + } + for (T i = 0; i < n - 2; ++i) { + if (perm[i] != i) { + return false; + } + } + return perm[n - 1] == n - 2 && perm[n - 2] == n - 1; + } + + bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node, + const NodeMap* node_map) { + if (transpose_node.op() != "Transpose" && + transpose_node.op() != "ConjugateTranspose") { + return false; + } + const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1)); + std::vector<int> perm32; + if (ValuesFromConstNode(*perm_node, &perm32)) { + return IsInnerMatrixTranspose(perm32); + } + std::vector<int64> perm64; + if (ValuesFromConstNode(*perm_node, &perm64)) { + return IsInnerMatrixTranspose(perm64); + } + return false; + } +}; + +// Fold Transpose into matrix multiplication. 
+class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage { + public: + explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {} + ~FoldConjugateIntoTranspose() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name()); + const string optimized_node_name = OptimizedNodeName(matmul); + if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK(); + + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + + const NodeDef* transpose_op = node->op() == "Conj" ? input : node; + const NodeDef* conj_op = node->op() == "Conj" ? node : input; + + if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) && + IsConj(*conj_op)) { + NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op); + + // Flip the type of transpose op to absorb the conjugation. + new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose" + : "Transpose"); + new_op->set_input(0, input->input(0)); + ctx().node_map->UpdateInput(new_op->name(), node->name(), + input->input(0)); + ForwardControlDependencies(new_op, {node, input}); + *simplified_node_name = new_op->name(); + } + + return Status::OK(); + } +}; + +// Replace Mul node with identical inputs with a Square. 
+class ReplaceMulWithSquare : public ArithmeticOptimizerStage { + public: + explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {} + ~ReplaceMulWithSquare() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsMul(*node) && node->input(0) == node->input(1); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const NodeScopeAndName mul = ParseNodeScopeAndName(node->name()); + const string optimized_node_name = OptimizedNodeName(mul); + if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK(); + + const DataType type = GetDataTypeFromAttr(*node, "T"); + bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128); + + string task; + string device; + bool is_on_cpu = + DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) && + str_util::StrContains(device, DEVICE_CPU); + + if (!is_complex || is_on_cpu) { + NodeDef* new_square_node = AddCopyNode(optimized_node_name, node); + new_square_node->set_op("Square"); + for (int i = 1; i < new_square_node->input_size(); ++i) { + new_square_node->set_input(i - 1, new_square_node->input(i)); + } + new_square_node->mutable_input()->RemoveLast(); + for (const string& input : new_square_node->input()) { + ctx().node_map->AddOutput(NodeName(input), new_square_node->name()); + } + *simplified_node_name = new_square_node->name(); + } + + return Status::OK(); + } +}; + +// Simplify aggregation (e.g. AddN) nodes: +// +// 1. Discard aggregate nodes with a single input and no control dependencies. +// +// 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to +// deduping or other rewrites) so we can get rid of the sum entirely. +// +// The expression (using AddN as an example of an aggregate op): +// AddN(x, x, x, ... 
,x) +// <-- N terms --> +// can be rewritten to: +// Mul(Const(N), x)) +// +class SimplifyAggregation : public ArithmeticOptimizerStage { + public: + explicit SimplifyAggregation(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {} + ~SimplifyAggregation() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsAggregate(*node) && NumNonControlInputs(*node) > 0; + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + // 1. Discard aggregate nodes with a single input and no control deps. + if (node->input_size() == 1) { + *simplified_node_name = node->input(0); + return Status::OK(); + } + + // 2. Rewrite aggregations of N >= 2 identical terms. + + // All non-control inputs must be identical. + bool all_equal = true; + int num_inputs = 1; + for (int i = 1; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) break; + ++num_inputs; + if (node->input(i) != node->input(0)) { + all_equal = false; + break; + } + } + if (!all_equal) return Status::OK(); + + // And node should not be optimized earlier. + const NodeScopeAndName node_scope_and_name = + ParseNodeScopeAndName(node->name()); + const string optimized_const_name = + OptimizedNodeName(node_scope_and_name, "Const"); + const string optimized_mul_name = + OptimizedNodeName(node_scope_and_name, "Mul"); + + bool is_already_optimized = + ctx().node_map->NodeExists(optimized_const_name) || + ctx().node_map->NodeExists(optimized_mul_name); + + if (is_already_optimized) return Status::OK(); + + // At this point all preconditions are met, and we safely do the rewrite. + VLOG(3) << "Simplify aggregation with identical inputs: node=" + << node->name() << " num_inputs=" << num_inputs; + + // 1. Create constant node with value N. 
+ const auto type = GetDataTypeFromAttr(*node, "T"); + Tensor t(type, TensorShape({})); + Status status = SetTensorValue(type, num_inputs, &t); + if (!status.ok()) { + return errors::Internal("Failed to create const node: ", + status.error_message()); + } + + TensorValue value(&t); + NodeDef* new_const_node = AddEmptyNode(optimized_const_name); + status = ConstantFolding::CreateNodeDef(new_const_node->name(), value, + new_const_node); + if (!status.ok()) { + return errors::Internal("Failed to create const node: ", + status.error_message()); + } + new_const_node->set_device(node->device()); + MaybeAddControlInput(NodeName(node->input(0)), new_const_node, + ctx().optimized_graph, ctx().node_map); + AddToOptimizationQueue(new_const_node); + + // 2. Replace the aggregate node with Mul(Const(N), x). + NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name); + new_mul_node->set_op("Mul"); + new_mul_node->set_device(node->device()); + SetDataTypeToAttr(type, "T", new_mul_node); + new_mul_node->add_input(new_const_node->name()); + ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name()); + new_mul_node->add_input(node->input(0)); + ctx().node_map->AddOutput(node->input(0), new_mul_node->name()); + + ForwardControlDependencies(new_mul_node, {node}); + *simplified_node_name = new_mul_node->name(); + + return Status::OK(); + } +}; + +class ConvertPowStage : public ArithmeticOptimizerStage { + public: + explicit ConvertPowStage(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {} + + bool IsSupported(const NodeDef* node) const override { + return IsPow(*node) && + ctx().graph_properties->GetInputProperties(node->name()).size() == 2; + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1]; + for (int i = 0; i < p.shape().dim_size(); ++i) { + if (p.shape().dim(i).size() < 0) 
{ + // skip if p is not fully defined. + return Status::OK(); + } + } + if (TensorShape::IsValid(p.shape()) && p.has_value()) { + Tensor pow(p.dtype(), p.shape()); + if (!pow.FromProto(p.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + p.value().DebugString()); + } + + complex128 prev, curr; + for (int i = 0; i < pow.NumElements(); ++i) { + TF_RETURN_IF_ERROR(GetElement(pow, i, &curr)); + if (i != 0 && curr != prev) { + // pow has different values on different elements. Skip. + return Status::OK(); + } + prev = curr; + } + NodeDef *x, *y; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x)); + TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); + if (curr == complex128(2, 0)) { + node->set_op("Square"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(1, 0)) { + node->set_op("Identity"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(0.5, 0)) { + node->set_op("Sqrt"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(0, 0)) { + const auto& b = + ctx().graph_properties->GetInputProperties(node->name())[0]; + for (int i = 0; i < b.shape().dim_size(); ++i) { + if (b.shape().dim(i).size() < 0) { + // skip if b is not fully defined. 
+ return Status::OK(); + } + } + if (TensorShape::IsValid(b.shape()) && b.has_value()) { + Tensor base(b.dtype(), b.shape()); + if (!base.FromProto(b.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + b.value().DebugString()); + } + node->set_op("Const"); + Tensor c(base.dtype(), base.shape()); + for (int i = 0; i < c.NumElements(); ++i) { + TF_RETURN_IF_ERROR(SetElementToOne(i, &c)); + } + (*node->mutable_attr())["dtype"].set_type(base.dtype()); + c.AsProtoTensorContent( + (*node->mutable_attr())["value"].mutable_tensor()); + node->mutable_attr()->erase("T"); + node->set_input(0, AsControlDependency(x->name())); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(x); + AddToOptimizationQueue(y); + } + } else if (curr == complex128(-0.5, 0)) { + node->set_op("Rsqrt"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(-1, 0)) { + node->set_op("Reciprocal"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } + } + return Status::OK(); + } + + private: + Status GetElement(const Tensor& t, int i, complex128* element) { + switch (t.dtype()) { + case DT_INT32: + *element = complex128(t.flat<int32>()(i)); + return Status::OK(); + case DT_INT64: + *element = complex128(t.flat<int64>()(i)); + return Status::OK(); + case DT_FLOAT: + *element = complex128(t.flat<float>()(i)); + return Status::OK(); + case DT_DOUBLE: + *element = complex128(t.flat<double>()(i)); + return Status::OK(); + case DT_COMPLEX64: + *element = complex128(t.flat<complex64>()(i)); + return Status::OK(); + case DT_COMPLEX128: + *element = t.flat<complex128>()(i); + return Status::OK(); + default: + return errors::InvalidArgument("Invalid data type: ", t.dtype()); + } + } + + Status SetElementToOne(int i, Tensor* t) { + switch (t->dtype()) { + case 
DT_INT32: + t->flat<int32>()(i) = 1; + return Status::OK(); + case DT_INT64: + t->flat<int64>()(i) = 1L; + return Status::OK(); + case DT_FLOAT: + t->flat<float>()(i) = 1.0f; + return Status::OK(); + case DT_DOUBLE: + t->flat<double>()(i) = 1.0; + return Status::OK(); + case DT_COMPLEX64: + t->flat<complex64>()(i) = complex64(1); + return Status::OK(); + case DT_COMPLEX128: + t->flat<complex128>()(i) = complex128(1); + return Status::OK(); + default: + return errors::InvalidArgument("Invalid data type: ", t->dtype()); + } + } +}; + +class ConvertLog1pStage : public ArithmeticOptimizerStage { + public: + explicit ConvertLog1pStage(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {} + ~ConvertLog1pStage() override = default; + + bool IsSupported(const NodeDef* node) const override { return IsLog(*node); } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + if (!IsAdd(*input)) { + return Status::OK(); + } + + if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) { + return Status::OK(); + } + + bool modified = false; + TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified)); + if (!modified) { + TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified)); + } + if (modified) { + *simplified_node_name = node->name(); + } + return Status::OK(); + } + + private: + Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j, + bool* modified) { + const auto& t = + ctx().graph_properties->GetInputProperties(input->name())[i]; + const auto& c = + ctx().graph_properties->GetInputProperties(input->name())[j]; + for (int k = 0; k < c.shape().dim_size(); ++k) { + // Skip if c shape is not fully determined. 
+ if (c.shape().dim(k).size() < 0) { + return Status::OK(); + } + } + TensorShapeProto broadcast_shape; + if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) { + return Status::OK(); + } + if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) { + // skip if the non-constant tensor doesn't have the same shape after + // broadcast. + return Status::OK(); + } + if (TensorShape::IsValid(c.shape()) && c.has_value()) { + Tensor constant(c.dtype(), c.shape()); + if (!constant.FromProto(c.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + c.value().DebugString()); + } + complex128 element; + for (int k = 0; k < constant.NumElements(); ++k) { + if (!GetElement(constant, k, &element)) { + // input data type is not supported by log1p. Skip. + return Status::OK(); + } + if (element != complex128(1)) { + // current element is not 1. Skip. + return Status::OK(); + } + } + NodeDef *x, *y; + TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x)); + TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y)); + node->set_op("Log1p"); + node->set_input(0, input->input(i)); + node->add_input(AsControlDependency(y->name())); + ForwardControlDependencies(node, {input}); + + AddToOptimizationQueue(node); + AddToOptimizationQueue(input); + AddToOptimizationQueue(x); + AddToOptimizationQueue(y); + *modified = true; + } + return Status::OK(); + } + + bool GetElement(const Tensor& t, int i, complex128* element) { + switch (t.dtype()) { + case DT_BFLOAT16: + *element = complex128(t.flat<bfloat16>()(i)); + return true; + case DT_HALF: + *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0); + return true; + case DT_FLOAT: + *element = complex128(t.flat<float>()(i)); + return true; + case DT_DOUBLE: + *element = complex128(t.flat<double>()(i)); + return true; + case DT_COMPLEX64: + *element = complex128(t.flat<complex64>()(i)); + return true; + case DT_COMPLEX128: + *element = t.flat<complex128>()(i); + return true; + default: + 
return false; + } + } +}; + +// Performs conversions like: +// Max(Sqrt(x)) => Sqrt(Max(x)) +// Checks for a max/min reduction over element-wise monotonic functions, such +// as Sqrt, Sigmoid, Tanh, etc. +class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { + public: + explicit OptimizeMaxOrMinOfMonotonicStage( + const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx, + ctx_ext) {} + ~OptimizeMaxOrMinOfMonotonicStage() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsMax(*node) || IsMin(*node); + } + + Status TrySimplify(NodeDef* reduction_node, + string* simplified_node_name) override { + NodeDef* inner_function; + TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function)); + // Optimize only if: + // 1. inner_function's Op is element-wise monotonic + // 2. inner_function's output is not being consumed elsewhere. + if (IsElementWiseMonotonic(*inner_function) && + (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) { + // Swap the first inputs of the inner function Op & the reduction Op. 
+ NodeDef* inner_input; + TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input)); + inner_function->set_input(0, reduction_node->name()); + UpdateConsumersAvoidingLoop(inner_function, reduction_node->name()); + reduction_node->set_input(0, inner_input->name()); + UpdateConsumersAvoidingLoop(reduction_node, inner_function->name()); + } + return Status::OK(); + } + + void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) { + const string& node_name = node->name(); + const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name); + for (NodeDef* consumer : consumers) { + for (int i = 0; i < consumer->input_size(); ++i) { + if (consumer->input(i) == node_name && consumer->name() != new_input) { + consumer->set_input(i, new_input); + ctx().node_map->UpdateInput(consumer->name(), node_name, new_input); + } + } + AddToOptimizationQueue(consumer); + } + } +}; + +// Replace a chain of type&shape preserving unary ops with a +// '_UnaryOpsComposition' node. +// TODO(ezhulenev): It should be a part of remapper optimizer because it doesn't +// have to do much with arithmetic (together with FoldMultiplyIntoConv stage?). +class UnaryOpsComposition : public ArithmeticOptimizerStage { + public: + explicit UnaryOpsComposition(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) { + // WARN: This should be consistent with unary_ops_composition.cc. + // clang-format off + supported_ops_ = {// Ops defined via Eigen scalar ops. 
+ {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Acos", {DT_FLOAT, DT_DOUBLE}}, + {"Acosh", {DT_FLOAT, DT_DOUBLE}}, + {"Asin", {DT_FLOAT, DT_DOUBLE}}, + {"Asinh", {DT_FLOAT, DT_DOUBLE}}, + {"Atan", {DT_FLOAT, DT_DOUBLE}}, + {"Atanh", {DT_FLOAT, DT_DOUBLE}}, + {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Cosh", {DT_FLOAT, DT_DOUBLE}}, + {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Rint", {DT_FLOAT, DT_DOUBLE}}, + {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Sinh", {DT_FLOAT, DT_DOUBLE}}, + {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Tan", {DT_FLOAT, DT_DOUBLE}}, + {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + // Additional ops that are not part of the Eigen. + {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}}, + {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}}; + // clang-format on + } + ~UnaryOpsComposition() override = default; + + bool IsSupported(const NodeDef* node) const override { + return CanOptimize(*node) && + // Check that this node was not already a root of a fused chain. If + // graph optimization runs twice without pruning in between, + // fused_nodes_ will not have this information. + !ctx().node_map->NodeExists(OptimizedNodeName(*node)); + } + + Status TrySimplify(NodeDef* root, string* simplified_node_name) override { + DataType dtype = root->attr().at("T").type(); + + // Keep a trace of all supported input nodes that can be fused together. 
+ std::vector<string> op_nodes = {root->name()}; + std::vector<string> op_names = {root->op()}; + + // Check if we should follow input(0) while building an op composition. + const auto predicate_fn = [&](const NodeDef& input) { + if (input.name() == root->name()) return true; + + bool follow_input_node = + dtype == GetDataTypeFromAttr(input, "T") && + NumNonControlDataOutputs(input, *ctx().node_map) == 1 && + CanOptimize(input); + + if (follow_input_node) { + op_nodes.push_back(input.name()); + op_names.push_back(input.op()); + } + + return follow_input_node; + }; + + NodeDef* last_op = GetTailOfChain( + *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn); + + // We were not able to find a chain that can be replaced. + if (op_names.size() == 1) return Status::OK(); + + // Do not add fused nodes to any other chain. + std::for_each(op_nodes.begin(), op_nodes.end(), + [this](const string& name) { AddToFusedNodes(name); }); + + // Reverse the trace to get correct composition computation order. 
+ std::reverse(op_names.begin(), op_names.end()); + + VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=[" + << str_util::Join(op_names, ", ") << "]"; + + NodeDef* composition_node = ctx().optimized_graph->add_node(); + composition_node->set_name(OptimizedNodeName(*root)); + composition_node->set_op("_UnaryOpsComposition"); + composition_node->add_input(last_op->input(0)); + composition_node->set_device(root->device()); + + auto attr = composition_node->mutable_attr(); + SetAttrValue(dtype, &(*attr)["T"]); + SetAttrValue(op_names, &(*attr)["op_names"]); + + ctx().node_map->AddNode(composition_node->name(), composition_node); + ctx().node_map->AddOutput(NodeName(last_op->input(0)), + composition_node->name()); + + *simplified_node_name = composition_node->name(); + + return Status::OK(); + } + + private: + bool CanOptimize(const NodeDef& node) const { + DataType dtype = GetDataTypeFromAttr(node, "T"); + if (!IsSupported(node.op(), dtype)) { + return false; + } + if (IsInPreserveSet(node)) { + return false; + } + if (!NodeIsOnCpu(node)) { + return false; + } + if (NodeIsAlreadyFused(node)) { + return false; + } + return !(IsDrivenByControlDependency(node) || + DrivesControlDependency(node)); + } + + // UnaryOpsComposition is defined only for CPU. + bool NodeIsOnCpu(const NodeDef& node) const { + using str_util::StartsWith; + + string task; + string device; + + return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) && + StartsWith(device, DEVICE_CPU); + } + + bool NodeIsAlreadyFused(const NodeDef& node) const { + return fused_nodes_.count(node.name()) > 0; + } + + string OptimizedNodeName(const NodeDef& node) const { + return strings::StrCat(node.name(), "/unary_ops_composition"); + } + + void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); } + + // Check if an op is supported by the _UnaryOpsComposition for the given type. 
+ bool IsSupported(const string& op_name, DataType dtype) const { + const auto it = supported_ops_.find(op_name); + return it != supported_ops_.end() && it->second.count(dtype) > 0; + } + + std::unordered_map<string, std::set<DataType>> supported_ops_; + std::unordered_set<string> fused_nodes_; +}; + } // namespace class UniqueNodes { @@ -2056,33 +2912,6 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const { return true; } -NodeDef* ArithmeticOptimizer::AddNode(const NodeDef& node, StringPiece suffix, - bool copy_node) { - return AddNode(OptimizedNodeName(node, suffix), copy_node ? &node : nullptr); -} - -NodeDef* ArithmeticOptimizer::AddNode(const string& name, - const NodeDef* node_to_copy) { - NodeDef* new_node = optimized_graph_->add_node(); - node_map_->AddNode(NodeName(name), new_node); - if (node_to_copy != nullptr) { - *new_node = *node_to_copy; - } - new_node->set_name(name); - return new_node; -} - -string ArithmeticOptimizer::OptimizedNodeName(const NodeDef& node, - StringPiece suffix) const { - return AddPrefixToNodeName(strings::StrCat(node.name(), "_", suffix), - kArithmeticOptimizer); -} - -bool ArithmeticOptimizer::OptimizedNodeExists(const NodeDef& node, - StringPiece suffix) const { - return node_map_->NodeExists(OptimizedNodeName(node, suffix)); -} - namespace { bool FeedsInPlaceOp(const SimpleGraphView& graph_view, const NodeDef& node) { @@ -2206,263 +3035,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( DedupControlInputs(target_node); } -// TODO(ezhulenev): extract each individual simplify rewrite into separate -// ArithmeticOptimizerStage -string ArithmeticOptimizer::TrySimplifyAndReplaceUses( - const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) { - // Fold a multiply of a scalar into the following convolution. This folding - // can jump across nodes that merely reorders data (such as reshape and - // transpose). 
For example, we can optimize - // - // - // Conv2D - // / \ - // Transpose weights - // | - // Mul - // / \ - // inputs 255.0 - // - // to - // - // Conv2D - // / \ - // Transpose Mul - // | / \ - // | weights 255.0 - // | - // inputs - // - // when `weights` are constant. `Mul` in the optimized graph can be - // constant-folded. - // - // TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and - // Conv?DBackpropInput. - if (node->op() == "Conv2D" || node->op() == "Conv3D") { - NodeDef* conv = const_cast<NodeDef*>(node); - const NodeDef* weights = node_map_->GetNode(NodeName(conv->input(1))); - // Fold the multiply to conv only when the weights are constant, so the - // multiply can be constant-folded. TODO(jingyue): When the weights aren't - // constant, this should also help performance a bit and memory usage a lot, - // since the weights tend to be smaller than the activations. - if (weights->op() == "Const" && - !OptimizedNodeExists(*weights, StrCat("scaled_", conv->name()))) { - const NodeDef* source = node_map_->GetNode( - GetTailOfValuePreservingChain(*node, *node_map_, nodes_to_preserve_) - ->input(0)); - if (source->op() == "Mul" && - node_map_->GetOutputs(source->name()).size() == 1) { - const NodeDef* mul = source; - // `scale` is the scalar multiplier, and `other` is the other operand. - // TODO(jingyue): handle the case where `scale` is 0-th operand. - const NodeDef* scale = node_map_->GetNode(mul->input(1)); - const NodeDef* other = node_map_->GetNode(mul->input(0)); - if (scale->op() == "Const" && scale->attr().at("dtype").type() == - weights->attr().at("dtype").type()) { - const TensorProto& scale_tensor = scale->attr().at("value").tensor(); - // Test whether `scale` is a scalar. - if (scale_tensor.has_tensor_shape() && - scale_tensor.tensor_shape().dim_size() == 0) { - // Create new node `scaled_weights`. 
- NodeDef* scaled_weights = AddNode( - *weights, StrCat("scaled_", conv->name()), /*copy_node=*/false); - scaled_weights->set_op("Mul"); - scaled_weights->set_device(weights->device()); - (*scaled_weights->mutable_attr())["T"] = - weights->attr().at("dtype"); - nodes_to_simplify->PushBack(scaled_weights); - - // Link in its inputs. - scaled_weights->add_input(conv->input(1)); - node_map_->AddOutput(weights->name(), scaled_weights->name()); - scaled_weights->add_input(mul->input(1)); - node_map_->AddOutput(scale->name(), scaled_weights->name()); - ForwardControlDependencies(scaled_weights, {source}); - - // Update `conv`'s weights to `scaled_weights`. - conv->set_input(1, scaled_weights->name()); - node_map_->UpdateInput(conv->name(), weights->name(), - scaled_weights->name()); - nodes_to_simplify->PushBack(conv); - - // Update `mul`'s consumer to bypass `mul` because it's folded to - // the weights. - CHECK_EQ(node_map_->GetOutputs(mul->name()).size(), 1); - NodeDef* consumer_of_mul = - *node_map_->GetOutputs(mul->name()).begin(); - consumer_of_mul->set_input(0, mul->input(0)); - node_map_->UpdateInput(consumer_of_mul->name(), mul->name(), - other->name()); - nodes_to_simplify->PushBack(consumer_of_mul); - return conv->name(); - } - } - } - } - } - - if (node->op() == "Mul" && node->input(0) == node->input(1) && - !OptimizedNodeExists(*node, "square")) { - const DataType type = GetDataTypeFromAttr(*node, "T"); - bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128); - string dontcare; - string device; - bool is_on_cpu = - DeviceNameUtils::SplitDeviceName(node->device(), &dontcare, &device) && - str_util::StrContains(device, DEVICE_CPU); - if (!is_complex || is_on_cpu) { - NodeDef* new_square_node = AddNode(*node, "square", /*copy_node=*/true); - new_square_node->set_op("Square"); - for (int i = 1; i < new_square_node->input_size(); ++i) { - new_square_node->set_input(i - 1, new_square_node->input(i)); - } - 
new_square_node->mutable_input()->RemoveLast(); - for (const string& input : new_square_node->input()) { - node_map_->AddOutput(NodeName(input), new_square_node->name()); - } - return new_square_node->name(); - } - } - - if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) { - // Discard aggregate nodes with a single input and no control dependencies. - if (node->input_size() == 1) { - return node->input(0); - } - - // Try to rewrite aggregations of N >= 2 identical terms (possibly due - // to deduping or other rewrites) so we can get rid of the sum entirely. - // The expression (using AddN as an example of an aggregate op): - // AddN(x, x, x, ... ,x) - // <-- N terms --> - // can be rewritten to - // Mul(Const(N), x)) - // - bool all_equal = true; - int num_inputs = 1; - for (int i = 1; i < node->input_size(); ++i) { - if (IsControlInput(node->input(i))) { - break; - } - ++num_inputs; - if (node->input(i) != node->input(0)) { - all_equal = false; - break; - } - } - if (all_equal && !OptimizedNodeExists(*node, "const") && - !OptimizedNodeExists(*node, "mul")) { - // 1. Create constant node with value N. - const auto type = GetDataTypeFromAttr(*node, "T"); - Tensor t(type, TensorShape({})); - Status status = SetTensorValue(type, num_inputs, &t); - if (!status.ok()) { - LOG(WARNING) << "Failed to create const node: " - << status.error_message(); - return ""; - } - TensorValue value(&t); - NodeDef* new_const_node = AddNode(*node, "const", /*copy_node=*/false); - status = ConstantFolding::CreateNodeDef(new_const_node->name(), value, - new_const_node); - if (!status.ok()) { - LOG(WARNING) << "Failed to create const node: " - << status.error_message(); - return ""; - } - new_const_node->set_device(node->device()); - MaybeAddControlInput(NodeName(node->input(0)), new_const_node, - optimized_graph_, node_map_.get()); - nodes_to_simplify->PushBack(new_const_node); - - // 2. Replace the aggregate node with Mul(Const(N), x). 
- NodeDef* new_mul_node = AddNode(*node, "mul", /*copy_node=*/false); - new_mul_node->set_op("Mul"); - new_mul_node->set_device(node->device()); - SetDataTypeToAttr(type, "T", new_mul_node); - new_mul_node->add_input(new_const_node->name()); - node_map_->AddOutput(new_const_node->name(), new_mul_node->name()); - new_mul_node->add_input(node->input(0)); - node_map_->AddOutput(node->input(0), new_mul_node->name()); - - ForwardControlDependencies(new_mul_node, {node}); - return new_mul_node->name(); - } - } - - // Fold Transpose into matrix multiplication. - if ((node->op() == "MatMul" || node->op() == "SparseMatMul" || - node->op() == "BatchMatMul") && - !OptimizedNodeExists(*node, "fused")) { - const NodeDef* a = node_map_->GetNode(node->input(0)); - const NodeDef* b = node_map_->GetNode(node->input(1)); - bool is_complex = false; - if (node->op() != "SparseMatMul") { - const DataType type = GetDataTypeFromAttr(*node, "T"); - is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128); - } - const std::set<string> foldable_transpose_ops = - !is_complex ? std::set<string>{"ConjugateTranspose", "Transpose"} - : (node->op() == "BatchMatMul" - ? std::set<string>{"ConjugateTranspose"} - : std::set<string>{"Transpose"}); - const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 && - IsInnerMatrixTransposeNode(*a, node_map_.get()); - const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 && - IsInnerMatrixTransposeNode(*b, node_map_.get()); - if (a_is_foldable || b_is_foldable) { - NodeDef* new_op = AddNode(*node, "fused", /*copy_node=*/true); - if (a_is_foldable) { - const string attr_a = - node->op() == "BatchMatMul" ? "adj_x" : "transpose_a"; - FlipBooleanAttr(attr_a, new_op); - new_op->set_input(0, a->input(0)); - node_map_->UpdateInput(new_op->name(), a->name(), a->input(0)); - } - if (b_is_foldable) { - const string attr_b = - node->op() == "BatchMatMul" ? 
"adj_y" : "transpose_b"; - FlipBooleanAttr(attr_b, new_op); - new_op->set_input(1, b->input(0)); - node_map_->UpdateInput(new_op->name(), b->name(), b->input(0)); - } - std::vector<const NodeDef*> deps_to_forward({node}); - if (a_is_foldable) { - deps_to_forward.push_back(a); - } - if (b_is_foldable) { - deps_to_forward.push_back(b); - } - ForwardControlDependencies(new_op, deps_to_forward); - } - } - - // Fold Conj into Transpose or ConjugateTranspose. - if ((node->op() == "Conj" || node->op() == "Transpose" || - node->op() == "ConjugateTranspose") && - !OptimizedNodeExists(*node, "fused")) { - const NodeDef* input = node_map_->GetNode(node->input(0)); - const NodeDef* transpose_op = node->op() == "Conj" ? input : node; - const NodeDef* conj_op = node->op() == "Conj" ? node : input; - - if ((transpose_op->op() == "Transpose" || - transpose_op->op() == "ConjugateTranspose") && - conj_op->op() == "Conj") { - NodeDef* new_op = - AddNode(OptimizedNodeName(*node, "fused"), transpose_op); - // Flip the type of transpose op to absorb the conjugation. - new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose" - : "Transpose"); - new_op->set_input(0, input->input(0)); - node_map_->UpdateInput(new_op->name(), node->name(), input->input(0)); - ForwardControlDependencies(new_op, {node, input}); - return new_op->name(); - } - } - - return ""; -} - Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { SetVector<NodeDef*> nodes_to_simplify; nodes_to_simplify.Reserve(optimized_graph_->node_size()); @@ -2471,7 +3043,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { } const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_, - graph_properties_.get(), node_map_.get()); + graph_properties_.get(), node_map_.get(), + opt_level_); const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify); // Stop pipeline after first stage returning non-empty simplified tensor name. 
@@ -2480,6 +3053,12 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { if (options_.combine_add_to_addn && can_use_shapes) pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext); + if (options_.fold_conjugate_into_transpose) + pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext); + if (options_.fold_multiply_into_conv) + pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext); + if (options_.fold_transpose_into_matmul) + pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext); if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes) pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext); if (options_.minimize_broadcasts && can_use_shapes) @@ -2496,16 +3075,27 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext); if (options_.remove_negation) pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext); + if (options_.replace_mul_with_square) + pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext); if (options_.remove_logical_not) pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext); if (options_.reorder_cast_and_transpose) pipeline.AddStage<ReorderCastAndTranspose>(ctx, ctx_ext); + if (options_.simplify_aggregation) + pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext); if (options_.hoist_cwise_unary_chains) pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext); if (options_.convert_sqrt_div_to_rsqrt_mul) pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext); if (options_.remove_idempotent) pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext); + if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext); + if (options_.convert_log1p) + pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext); + if (options_.optimize_max_or_min_of_monotonic) + pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext); + if (options_.unary_ops_composition) + pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext); VLOG(1) << "Run " 
<< pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); @@ -2513,19 +3103,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { while (!nodes_to_simplify.Empty()) { NodeDef* node = nodes_to_simplify.PopBack(); - // TODO(ezhulenev): move all rewrites into separate stages string simplified_tensor = ""; - if (options_.enable_try_simplify_and_replace) { - simplified_tensor = TrySimplifyAndReplaceUses(node, &nodes_to_simplify); - } + bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor); - // if it was not simplified try to run it through all configured stages - if (!stop(simplified_tensor)) { - bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor); - if (!optimized) { - continue; - } - } + // If the node was not optimized by any of the stages, go to the next one. + if (!optimized) continue; // re-wire consumers of an old node to the new one if (NodeName(simplified_tensor) != node->name()) { |