aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc1312
1 files changed, 947 insertions, 365 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index ca3f84a81d..97862d1ed0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -101,38 +101,6 @@ bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
return false;
}
-template <typename T>
-bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
- const T n = perm.size();
- if (n < 2) {
- return false;
- }
- for (T i = 0; i < n - 2; ++i) {
- if (perm[i] != i) {
- return false;
- }
- }
- return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
-}
-
-bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
- const NodeMap* node_map) {
- if (transpose_node.op() != "Transpose" &&
- transpose_node.op() != "ConjugateTranspose") {
- return false;
- }
- const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
- std::vector<int> perm32;
- if (ValuesFromConstNode(*perm_node, &perm32)) {
- return IsInnerMatrixTranspose(perm32);
- }
- std::vector<int64> perm64;
- if (ValuesFromConstNode(*perm_node, &perm64)) {
- return IsInnerMatrixTranspose(perm64);
- }
- return false;
-}
-
bool MaybeAddControlInput(const string& new_input, NodeDef* node,
GraphDef* graph, NodeMap* node_map) {
bool already_exists = false;
@@ -155,12 +123,6 @@ void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
(*node->mutable_attr())[attr_name].set_type(dtype);
}
-void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
- const bool old_value =
- !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
- (*node->mutable_attr())[attr_name].set_b(!old_value);
-}
-
string SourceDataTypeAttrName(const NodeDef& node) {
if (node.op() == "Bitcast") {
return "T";
@@ -265,6 +227,27 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
ctx().nodes_to_preserve->end();
}
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool IsDrivenByControlDependency(const NodeDef& node) const {
+ return std::any_of(node.input().begin(), node.input().end(),
+ IsControlInput);
+ }
+
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool DrivesControlDependency(const NodeDef& node) const {
+ int position;
+ for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
+ for (int i = 0; i < output->input_size(); ++i) {
+ auto input = output->input(i);
+ string name = ParseNodeName(input, &position);
+ if (name == node.name() && /*control input*/ position < 0) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
private:
// Extended context required for ArithmeticOptimizer.
const ArithmeticOptimizerContext ctx_ext_;
@@ -395,27 +378,6 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
is_broadcastable);
}
- // TODO(ezhulenev): move to GraphOptimizerStage?
- bool IsDrivenByControlDependency(const NodeDef& node) const {
- return std::any_of(node.input().begin(), node.input().end(),
- IsControlInput);
- }
-
- // TODO(ezhulenev): move to GraphOptimizerStage?
- bool DrivesControlDependency(const NodeDef& node) const {
- int position;
- for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
- for (int i = 0; i < output->input_size(); ++i) {
- auto input = output->input(i);
- string name = ParseNodeName(input, &position);
- if (name == node.name() && /*control input*/ position < 0) {
- return true;
- }
- }
- }
- return false;
- }
-
string ShapeSignature(const TensorShapeProto& shape) const {
string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
for (int i = 0; i < shape.dim_size(); ++i)
@@ -1122,8 +1084,11 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* tail = node;
- tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
- *ctx().nodes_to_preserve);
+ // TODO(rmlarsen): Enable after debugging breakage in Bayesflow.
+ if (ctx().opt_level == RewriterConfig::AGGRESSIVE) {
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
+ }
NodeDef* first_transpose;
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
@@ -1757,19 +1722,15 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage {
~RemoveIdempotentStage() override = default;
bool IsSupported(const NodeDef* node) const override {
- return IsIdempotent(*node) && !IsInPreserveSet(*node);
+ return node->input_size() == 1 && IsIdempotent(*node) &&
+ !IsInPreserveSet(*node);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
- auto root_scope_and_name = ParseNodeScopeAndName(node->name());
- const string new_name = OptimizedNodeName(root_scope_and_name);
- if (input->op() == node->op() && input->device() == node->device() &&
- IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) {
- NodeDef* new_input_node = AddCopyNode(new_name, input);
- ForwardControlDependencies(new_input_node, {node});
- *simplified_node_name = new_input_node->name();
+ if (input->op() == node->op() && input->device() == node->device()) {
+ *simplified_node_name = node->input(0);
}
return Status::OK();
}
@@ -1958,6 +1919,901 @@ class ReorderCastAndTranspose : public ArithmeticOptimizerStage {
bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
};
+// Fold a multiply of a scalar into the following convolution. This folding
+// can jump across nodes that merely reorder data (such as reshape and
+// transpose). For example, we can optimize
+//
+//
+// Conv2D Conv2D
+// / \ / \
+// Transpose weights* -> Transpose Mul
+// | | / \
+// Mul | weights scale
+// / \ |
+// input scale** input
+//
+// *) weights must be a const
+// **) scale must be a const scalar
+//
+// When `weights` and `scale` are constant, `Mul` in the optimized graph can be
+// constant-folded, also weights tend to be smaller than the activations.
+//
+// TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
+// Conv?DBackpropInput.
+class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
+ public:
+ explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {}
+ ~FoldMultiplyIntoConv() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsConv2D(*node) || IsConv3D(*node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+#define TF_RETURN_IF_TRUE(...) \
+ if ((__VA_ARGS__)) return Status::OK()
+
+ NodeDef* conv = node;
+
+ NodeDef* weights;
+ TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights));
+
+ // Fold the multiply to conv only when the weights are constant, so the
+ // multiply can be constant-folded.
+ //
+ // TODO(jingyue): When the weights aren't constant, this should also help
+ // performance a bit and memory usage a lot, since the weights tend to be
+ // smaller than the activations.
+ TF_RETURN_IF_TRUE(!IsConstant(*weights));
+
+ // Verify that this node was not already optimized.
+ const string scaled_weights_node_name =
+ OptimizedNodeName(ParseNodeScopeAndName(weights->name()),
+ strings::StrCat("scaled", "_", conv->name()));
+
+ TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name));
+
+ // Find the tail of value preserving chain entering the Conv node.
+ NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map,
+ *ctx().nodes_to_preserve);
+
+ NodeDef* source;
+ TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source));
+
+ // Check that value preserving chain is the only consumer of the Mul output.
+ TF_RETURN_IF_TRUE(!IsMul(*source));
+ TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);
+
+ const NodeDef* mul = source;
+
+ // TODO(jingyue): handle the case where `scale` is 0-th operand.
+ NodeDef* scale; // scalar multiplier for the input tensor
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(mul->input(1), &scale));
+ TF_RETURN_IF_ERROR(GetInputNode(mul->input(0), &input));
+
+ // Check that 'scale * weight' can be const folded.
+ TF_RETURN_IF_TRUE(!IsConstant(*scale));
+ TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
+ weights->attr().at("dtype").type());
+
+ // Check that `scale` is a scalar.
+ const TensorProto& scale_tensor = scale->attr().at("value").tensor();
+ bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
+ scale_tensor.tensor_shape().dim_size() == 0;
+ TF_RETURN_IF_TRUE(!scale_is_a_scalar);
+
+ // At this point all preconditions are met, and we safely do the rewrite.
+ VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
+ << " mul=" << mul->name() << " weights=" << weights->name();
+
+ // Create new node `scaled_weights`.
+ NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name);
+ scaled_weights->set_op("Mul");
+ scaled_weights->set_device(weights->device());
+ (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype");
+ AddToOptimizationQueue(scaled_weights);
+
+ // Link in its inputs.
+ scaled_weights->add_input(conv->input(1));
+ ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
+ scaled_weights->add_input(mul->input(1));
+ ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
+ ForwardControlDependencies(scaled_weights, {source});
+
+ // Update `conv`'s weights to `scaled_weights`.
+ conv->set_input(1, scaled_weights->name());
+ ctx().node_map->UpdateInput(conv->name(), weights->name(),
+ scaled_weights->name());
+ AddToOptimizationQueue(conv);
+
+ // Update `tail` node to bypass `mul` because it's folded to the weights.
+ tail->set_input(0, mul->input(0));
+ ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
+ AddToOptimizationQueue(tail);
+ *simplified_node_name = conv->name();
+
+ return Status::OK();
+#undef TF_RETURN_IF_TRUE
+ }
+};
+
+// Fold Transpose into matrix multiplication.
+class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
+ public:
+ explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {}
+ ~FoldTransposeIntoMatMul() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsMatMul(*node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
+ const string optimized_node_name = OptimizedNodeName(matmul);
+ if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
+
+ NodeDef* a;
+ NodeDef* b;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a));
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b));
+
+ bool is_complex = false;
+ if (node->op() != "SparseMatMul") {
+ const DataType type = GetDataTypeFromAttr(*node, "T");
+ is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
+ }
+
+ const std::set<string> foldable_transpose_ops =
+ !is_complex ? std::set<string>{"ConjugateTranspose", "Transpose"}
+ : (node->op() == "BatchMatMul"
+ ? std::set<string>{"ConjugateTranspose"}
+ : std::set<string>{"Transpose"});
+
+ const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
+ IsInnerMatrixTransposeNode(*a, ctx().node_map);
+ const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
+ IsInnerMatrixTransposeNode(*b, ctx().node_map);
+ if (!a_is_foldable && !b_is_foldable) return Status::OK();
+
+ NodeDef* new_op = AddCopyNode(optimized_node_name, node);
+
+ if (a_is_foldable) {
+ const string attr_a =
+ node->op() == "BatchMatMul" ? "adj_x" : "transpose_a";
+ FlipBooleanAttr(attr_a, new_op);
+ new_op->set_input(0, a->input(0));
+ ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
+ }
+
+ if (b_is_foldable) {
+ const string attr_b =
+ node->op() == "BatchMatMul" ? "adj_y" : "transpose_b";
+ FlipBooleanAttr(attr_b, new_op);
+ new_op->set_input(1, b->input(0));
+ ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
+ }
+
+ std::vector<const NodeDef*> deps_to_forward = {node};
+ if (a_is_foldable) deps_to_forward.push_back(a);
+ if (b_is_foldable) deps_to_forward.push_back(b);
+ ForwardControlDependencies(new_op, deps_to_forward);
+
+ return Status::OK();
+ }
+
+ private:
+ void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
+ const bool old_value =
+ !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
+ (*node->mutable_attr())[attr_name].set_b(!old_value);
+ }
+
+ template <typename T>
+ bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
+ const T n = perm.size();
+ if (n < 2) {
+ return false;
+ }
+ for (T i = 0; i < n - 2; ++i) {
+ if (perm[i] != i) {
+ return false;
+ }
+ }
+ return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
+ }
+
+ bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
+ const NodeMap* node_map) {
+ if (transpose_node.op() != "Transpose" &&
+ transpose_node.op() != "ConjugateTranspose") {
+ return false;
+ }
+ const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
+ std::vector<int> perm32;
+ if (ValuesFromConstNode(*perm_node, &perm32)) {
+ return IsInnerMatrixTranspose(perm32);
+ }
+ std::vector<int64> perm64;
+ if (ValuesFromConstNode(*perm_node, &perm64)) {
+ return IsInnerMatrixTranspose(perm64);
+ }
+ return false;
+ }
+};
+
+// Fold Conj into an adjacent Transpose (or ConjugateTranspose): the pair is
+// rewritten as a single ConjugateTranspose (or Transpose) that absorbs the
+// conjugation.
+class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage {
+ public:
+ explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {}
+ ~FoldConjugateIntoTranspose() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
+ const string optimized_node_name = OptimizedNodeName(matmul);
+ if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
+
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+
+ const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
+ const NodeDef* conj_op = node->op() == "Conj" ? node : input;
+
+ if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) &&
+ IsConj(*conj_op)) {
+ NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op);
+
+ // Flip the type of transpose op to absorb the conjugation.
+ new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
+ : "Transpose");
+ new_op->set_input(0, input->input(0));
+ ctx().node_map->UpdateInput(new_op->name(), node->name(),
+ input->input(0));
+ ForwardControlDependencies(new_op, {node, input});
+ *simplified_node_name = new_op->name();
+ }
+
+ return Status::OK();
+ }
+};
+
+// Replace Mul node with identical inputs with a Square.
+class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
+ public:
+ explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
+ ~ReplaceMulWithSquare() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsMul(*node) && node->input(0) == node->input(1);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
+ const string optimized_node_name = OptimizedNodeName(mul);
+ if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
+
+ const DataType type = GetDataTypeFromAttr(*node, "T");
+ bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
+
+ string task;
+ string device;
+ bool is_on_cpu =
+ DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
+ str_util::StrContains(device, DEVICE_CPU);
+
+ if (!is_complex || is_on_cpu) {
+ NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
+ new_square_node->set_op("Square");
+ for (int i = 1; i < new_square_node->input_size(); ++i) {
+ new_square_node->set_input(i - 1, new_square_node->input(i));
+ }
+ new_square_node->mutable_input()->RemoveLast();
+ for (const string& input : new_square_node->input()) {
+ ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
+ }
+ *simplified_node_name = new_square_node->name();
+ }
+
+ return Status::OK();
+ }
+};
+
+// Simplify aggregation (e.g. AddN) nodes:
+//
+// 1. Discard aggregate nodes with a single input and no control dependencies.
+//
+// 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to
+// deduping or other rewrites) so we can get rid of the sum entirely.
+//
+// The expression (using AddN as an example of an aggregate op):
+// AddN(x, x, x, ... ,x)
+// <-- N terms -->
+// can be rewritten to:
+// Mul(Const(N), x))
+//
+class SimplifyAggregation : public ArithmeticOptimizerStage {
+ public:
+ explicit SimplifyAggregation(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {}
+ ~SimplifyAggregation() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsAggregate(*node) && NumNonControlInputs(*node) > 0;
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ // 1. Discard aggregate nodes with a single input and no control deps.
+ if (node->input_size() == 1) {
+ *simplified_node_name = node->input(0);
+ return Status::OK();
+ }
+
+ // 2. Rewrite aggregations of N >= 2 identical terms.
+
+ // All non-control inputs must be identical.
+ bool all_equal = true;
+ int num_inputs = 1;
+ for (int i = 1; i < node->input_size(); ++i) {
+ if (IsControlInput(node->input(i))) break;
+ ++num_inputs;
+ if (node->input(i) != node->input(0)) {
+ all_equal = false;
+ break;
+ }
+ }
+ if (!all_equal) return Status::OK();
+
+ // And node should not be optimized earlier.
+ const NodeScopeAndName node_scope_and_name =
+ ParseNodeScopeAndName(node->name());
+ const string optimized_const_name =
+ OptimizedNodeName(node_scope_and_name, "Const");
+ const string optimized_mul_name =
+ OptimizedNodeName(node_scope_and_name, "Mul");
+
+ bool is_already_optimized =
+ ctx().node_map->NodeExists(optimized_const_name) ||
+ ctx().node_map->NodeExists(optimized_mul_name);
+
+ if (is_already_optimized) return Status::OK();
+
+ // At this point all preconditions are met, and we safely do the rewrite.
+ VLOG(3) << "Simplify aggregation with identical inputs: node="
+ << node->name() << " num_inputs=" << num_inputs;
+
+ // 1. Create constant node with value N.
+ const auto type = GetDataTypeFromAttr(*node, "T");
+ Tensor t(type, TensorShape({}));
+ Status status = SetTensorValue(type, num_inputs, &t);
+ if (!status.ok()) {
+ return errors::Internal("Failed to create const node: ",
+ status.error_message());
+ }
+
+ TensorValue value(&t);
+ NodeDef* new_const_node = AddEmptyNode(optimized_const_name);
+ status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
+ new_const_node);
+ if (!status.ok()) {
+ return errors::Internal("Failed to create const node: ",
+ status.error_message());
+ }
+ new_const_node->set_device(node->device());
+ MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
+ ctx().optimized_graph, ctx().node_map);
+ AddToOptimizationQueue(new_const_node);
+
+ // 2. Replace the aggregate node with Mul(Const(N), x).
+ NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name);
+ new_mul_node->set_op("Mul");
+ new_mul_node->set_device(node->device());
+ SetDataTypeToAttr(type, "T", new_mul_node);
+ new_mul_node->add_input(new_const_node->name());
+ ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name());
+ new_mul_node->add_input(node->input(0));
+ ctx().node_map->AddOutput(node->input(0), new_mul_node->name());
+
+ ForwardControlDependencies(new_mul_node, {node});
+ *simplified_node_name = new_mul_node->name();
+
+ return Status::OK();
+ }
+};
+
+class ConvertPowStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertPowStage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsPow(*node) &&
+ ctx().graph_properties->GetInputProperties(node->name()).size() == 2;
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
+ for (int i = 0; i < p.shape().dim_size(); ++i) {
+ if (p.shape().dim(i).size() < 0) {
+ // skip if p is not fully defined.
+ return Status::OK();
+ }
+ }
+ if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+ Tensor pow(p.dtype(), p.shape());
+ if (!pow.FromProto(p.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ p.value().DebugString());
+ }
+
+ complex128 prev, curr;
+ for (int i = 0; i < pow.NumElements(); ++i) {
+ TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
+ if (i != 0 && curr != prev) {
+ // pow has different values on different elements. Skip.
+ return Status::OK();
+ }
+ prev = curr;
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+ if (curr == complex128(2, 0)) {
+ node->set_op("Square");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(1, 0)) {
+ node->set_op("Identity");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(0.5, 0)) {
+ node->set_op("Sqrt");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(0, 0)) {
+ const auto& b =
+ ctx().graph_properties->GetInputProperties(node->name())[0];
+ for (int i = 0; i < b.shape().dim_size(); ++i) {
+ if (b.shape().dim(i).size() < 0) {
+ // skip if b is not fully defined.
+ return Status::OK();
+ }
+ }
+ if (TensorShape::IsValid(b.shape()) && b.has_value()) {
+ Tensor base(b.dtype(), b.shape());
+ if (!base.FromProto(b.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ b.value().DebugString());
+ }
+ node->set_op("Const");
+ Tensor c(base.dtype(), base.shape());
+ for (int i = 0; i < c.NumElements(); ++i) {
+ TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
+ }
+ (*node->mutable_attr())["dtype"].set_type(base.dtype());
+ c.AsProtoTensorContent(
+ (*node->mutable_attr())["value"].mutable_tensor());
+ node->mutable_attr()->erase("T");
+ node->set_input(0, AsControlDependency(x->name()));
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ }
+ } else if (curr == complex128(-0.5, 0)) {
+ node->set_op("Rsqrt");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(-1, 0)) {
+ node->set_op("Reciprocal");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_INT32:
+ *element = complex128(t.flat<int32>()(i));
+ return Status::OK();
+ case DT_INT64:
+ *element = complex128(t.flat<int64>()(i));
+ return Status::OK();
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return Status::OK();
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return Status::OK();
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return Status::OK();
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t.dtype());
+ }
+ }
+
+ Status SetElementToOne(int i, Tensor* t) {
+ switch (t->dtype()) {
+ case DT_INT32:
+ t->flat<int32>()(i) = 1;
+ return Status::OK();
+ case DT_INT64:
+ t->flat<int64>()(i) = 1L;
+ return Status::OK();
+ case DT_FLOAT:
+ t->flat<float>()(i) = 1.0f;
+ return Status::OK();
+ case DT_DOUBLE:
+ t->flat<double>()(i) = 1.0;
+ return Status::OK();
+ case DT_COMPLEX64:
+ t->flat<complex64>()(i) = complex64(1);
+ return Status::OK();
+ case DT_COMPLEX128:
+ t->flat<complex128>()(i) = complex128(1);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t->dtype());
+ }
+ }
+};
+
+class ConvertLog1pStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
+ ~ConvertLog1pStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+ if (!IsAdd(*input)) {
+ return Status::OK();
+ }
+
+ if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
+ return Status::OK();
+ }
+
+ bool modified = false;
+ TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
+ if (!modified) {
+ TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
+ }
+ if (modified) {
+ *simplified_node_name = node->name();
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
+ bool* modified) {
+ const auto& t =
+ ctx().graph_properties->GetInputProperties(input->name())[i];
+ const auto& c =
+ ctx().graph_properties->GetInputProperties(input->name())[j];
+ for (int k = 0; k < c.shape().dim_size(); ++k) {
+ // Skip if c shape is not fully determined.
+ if (c.shape().dim(k).size() < 0) {
+ return Status::OK();
+ }
+ }
+ TensorShapeProto broadcast_shape;
+ if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
+ return Status::OK();
+ }
+ if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
+ // skip if the non-constant tensor doesn't have the same shape after
+ // broadcast.
+ return Status::OK();
+ }
+ if (TensorShape::IsValid(c.shape()) && c.has_value()) {
+ Tensor constant(c.dtype(), c.shape());
+ if (!constant.FromProto(c.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ c.value().DebugString());
+ }
+ complex128 element;
+ for (int k = 0; k < constant.NumElements(); ++k) {
+ if (!GetElement(constant, k, &element)) {
+ // input data type is not supported by log1p. Skip.
+ return Status::OK();
+ }
+ if (element != complex128(1)) {
+ // current element is not 1. Skip.
+ return Status::OK();
+ }
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
+ node->set_op("Log1p");
+ node->set_input(0, input->input(i));
+ node->add_input(AsControlDependency(y->name()));
+ ForwardControlDependencies(node, {input});
+
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(input);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ *modified = true;
+ }
+ return Status::OK();
+ }
+
+ bool GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_BFLOAT16:
+ *element = complex128(t.flat<bfloat16>()(i));
+ return true;
+ case DT_HALF:
+ *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+ return true;
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return true;
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return true;
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return true;
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return true;
+ default:
+ return false;
+ }
+ }
+};
+
+// Performs conversions like:
+// Max(Sqrt(x)) => Sqrt(Max(x))
+// Checks for a max/min reduction over element-wise monotonic functions, such
+// as Sqrt, Sigmoid, Tanh, etc.
+class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
+ public:
+ explicit OptimizeMaxOrMinOfMonotonicStage(
+ const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
+ ctx_ext) {}
+ ~OptimizeMaxOrMinOfMonotonicStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsMax(*node) || IsMin(*node);
+ }
+
+ Status TrySimplify(NodeDef* reduction_node,
+ string* simplified_node_name) override {
+ NodeDef* inner_function;
+ TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
+ // Optimize only if:
+ // 1. inner_function's Op is element-wise monotonic
+ // 2. inner_function's output is not being consumed elsewhere.
+ if (IsElementWiseMonotonic(*inner_function) &&
+ (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ // Swap the first inputs of the inner function Op & the reduction Op.
+ NodeDef* inner_input;
+ TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
+ reduction_node->set_input(0, inner_input->name());
+ UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ }
+ return Status::OK();
+ }
+
+ void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ const string& node_name = node->name();
+ const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
+ for (NodeDef* consumer : consumers) {
+ for (int i = 0; i < consumer->input_size(); ++i) {
+ if (consumer->input(i) == node_name && consumer->name() != new_input) {
+ consumer->set_input(i, new_input);
+ ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
+ }
+ }
+ AddToOptimizationQueue(consumer);
+ }
+ }
+};
+
+// Replace a chain of type&shape preserving unary ops with a
+// '_UnaryOpsComposition' node.
+// TODO(ezhulenev): It should be a part of remapper optimizer because it doesn't
+// have to do much with arithmetic (together with FoldMultiplyIntoConv stage?).
+class UnaryOpsComposition : public ArithmeticOptimizerStage {
+ public:
+ explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
+ // WARN: This should be consistent with unary_ops_composition.cc.
+ // clang-format off
+ supported_ops_ = {// Ops defined via Eigen scalar ops.
+ {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Acos", {DT_FLOAT, DT_DOUBLE}},
+ {"Acosh", {DT_FLOAT, DT_DOUBLE}},
+ {"Asin", {DT_FLOAT, DT_DOUBLE}},
+ {"Asinh", {DT_FLOAT, DT_DOUBLE}},
+ {"Atan", {DT_FLOAT, DT_DOUBLE}},
+ {"Atanh", {DT_FLOAT, DT_DOUBLE}},
+ {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Cosh", {DT_FLOAT, DT_DOUBLE}},
+ {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Rint", {DT_FLOAT, DT_DOUBLE}},
+ {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sinh", {DT_FLOAT, DT_DOUBLE}},
+ {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Tan", {DT_FLOAT, DT_DOUBLE}},
+ {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+                      // Additional ops that are not part of Eigen.
+ {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
+ // clang-format on
+ }
+ ~UnaryOpsComposition() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return CanOptimize(*node) &&
+ // Check that this node was not already a root of a fused chain. If
+ // graph optimization runs twice without pruning in between,
+ // fused_nodes_ will not have this information.
+ !ctx().node_map->NodeExists(OptimizedNodeName(*node));
+ }
+
+ Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
+ DataType dtype = root->attr().at("T").type();
+
+ // Keep a trace of all supported input nodes that can be fused together.
+ std::vector<string> op_nodes = {root->name()};
+ std::vector<string> op_names = {root->op()};
+
+ // Check if we should follow input(0) while building an op composition.
+ const auto predicate_fn = [&](const NodeDef& input) {
+ if (input.name() == root->name()) return true;
+
+ bool follow_input_node =
+ dtype == GetDataTypeFromAttr(input, "T") &&
+ NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
+ CanOptimize(input);
+
+ if (follow_input_node) {
+ op_nodes.push_back(input.name());
+ op_names.push_back(input.op());
+ }
+
+ return follow_input_node;
+ };
+
+ NodeDef* last_op = GetTailOfChain(
+ *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);
+
+ // We were not able to find a chain that can be replaced.
+ if (op_names.size() == 1) return Status::OK();
+
+ // Do not add fused nodes to any other chain.
+ std::for_each(op_nodes.begin(), op_nodes.end(),
+ [this](const string& name) { AddToFusedNodes(name); });
+
+ // Reverse the trace to get correct composition computation order.
+ std::reverse(op_names.begin(), op_names.end());
+
+ VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
+ << str_util::Join(op_names, ", ") << "]";
+
+ NodeDef* composition_node = ctx().optimized_graph->add_node();
+ composition_node->set_name(OptimizedNodeName(*root));
+ composition_node->set_op("_UnaryOpsComposition");
+ composition_node->add_input(last_op->input(0));
+ composition_node->set_device(root->device());
+
+ auto attr = composition_node->mutable_attr();
+ SetAttrValue(dtype, &(*attr)["T"]);
+ SetAttrValue(op_names, &(*attr)["op_names"]);
+
+ ctx().node_map->AddNode(composition_node->name(), composition_node);
+ ctx().node_map->AddOutput(NodeName(last_op->input(0)),
+ composition_node->name());
+
+ *simplified_node_name = composition_node->name();
+
+ return Status::OK();
+ }
+
+ private:
+ bool CanOptimize(const NodeDef& node) const {
+ DataType dtype = GetDataTypeFromAttr(node, "T");
+ if (!IsSupported(node.op(), dtype)) {
+ return false;
+ }
+ if (IsInPreserveSet(node)) {
+ return false;
+ }
+ if (!NodeIsOnCpu(node)) {
+ return false;
+ }
+ if (NodeIsAlreadyFused(node)) {
+ return false;
+ }
+ return !(IsDrivenByControlDependency(node) ||
+ DrivesControlDependency(node));
+ }
+
+ // UnaryOpsComposition is defined only for CPU.
+ bool NodeIsOnCpu(const NodeDef& node) const {
+ using str_util::StartsWith;
+
+ string task;
+ string device;
+
+ return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
+ StartsWith(device, DEVICE_CPU);
+ }
+
+ bool NodeIsAlreadyFused(const NodeDef& node) const {
+ return fused_nodes_.count(node.name()) > 0;
+ }
+
+ string OptimizedNodeName(const NodeDef& node) const {
+ return strings::StrCat(node.name(), "/unary_ops_composition");
+ }
+
+ void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }
+
+ // Check if an op is supported by the _UnaryOpsComposition for the given type.
+ bool IsSupported(const string& op_name, DataType dtype) const {
+ const auto it = supported_ops_.find(op_name);
+ return it != supported_ops_.end() && it->second.count(dtype) > 0;
+ }
+
+ std::unordered_map<string, std::set<DataType>> supported_ops_;
+ std::unordered_set<string> fused_nodes_;
+};
+
} // namespace
class UniqueNodes {
@@ -2056,33 +2912,6 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
return true;
}
-NodeDef* ArithmeticOptimizer::AddNode(const NodeDef& node, StringPiece suffix,
- bool copy_node) {
- return AddNode(OptimizedNodeName(node, suffix), copy_node ? &node : nullptr);
-}
-
-NodeDef* ArithmeticOptimizer::AddNode(const string& name,
- const NodeDef* node_to_copy) {
- NodeDef* new_node = optimized_graph_->add_node();
- node_map_->AddNode(NodeName(name), new_node);
- if (node_to_copy != nullptr) {
- *new_node = *node_to_copy;
- }
- new_node->set_name(name);
- return new_node;
-}
-
-string ArithmeticOptimizer::OptimizedNodeName(const NodeDef& node,
- StringPiece suffix) const {
- return AddPrefixToNodeName(strings::StrCat(node.name(), "_", suffix),
- kArithmeticOptimizer);
-}
-
-bool ArithmeticOptimizer::OptimizedNodeExists(const NodeDef& node,
- StringPiece suffix) const {
- return node_map_->NodeExists(OptimizedNodeName(node, suffix));
-}
-
namespace {
bool FeedsInPlaceOp(const SimpleGraphView& graph_view, const NodeDef& node) {
@@ -2206,263 +3035,6 @@ void ArithmeticOptimizer::ForwardControlDependencies(
DedupControlInputs(target_node);
}
-// TODO(ezhulenev): extract each individual simplify rewrite into separate
-// ArithmeticOptimizerStage
-string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
- const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
- // Fold a multiply of a scalar into the following convolution. This folding
- // can jump across nodes that merely reorders data (such as reshape and
- // transpose). For example, we can optimize
- //
- //
- // Conv2D
- // / \
- // Transpose weights
- // |
- // Mul
- // / \
- // inputs 255.0
- //
- // to
- //
- // Conv2D
- // / \
- // Transpose Mul
- // | / \
- // | weights 255.0
- // |
- // inputs
- //
- // when `weights` are constant. `Mul` in the optimized graph can be
- // constant-folded.
- //
- // TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
- // Conv?DBackpropInput.
- if (node->op() == "Conv2D" || node->op() == "Conv3D") {
- NodeDef* conv = const_cast<NodeDef*>(node);
- const NodeDef* weights = node_map_->GetNode(NodeName(conv->input(1)));
- // Fold the multiply to conv only when the weights are constant, so the
- // multiply can be constant-folded. TODO(jingyue): When the weights aren't
- // constant, this should also help performance a bit and memory usage a lot,
- // since the weights tend to be smaller than the activations.
- if (weights->op() == "Const" &&
- !OptimizedNodeExists(*weights, StrCat("scaled_", conv->name()))) {
- const NodeDef* source = node_map_->GetNode(
- GetTailOfValuePreservingChain(*node, *node_map_, nodes_to_preserve_)
- ->input(0));
- if (source->op() == "Mul" &&
- node_map_->GetOutputs(source->name()).size() == 1) {
- const NodeDef* mul = source;
- // `scale` is the scalar multiplier, and `other` is the other operand.
- // TODO(jingyue): handle the case where `scale` is 0-th operand.
- const NodeDef* scale = node_map_->GetNode(mul->input(1));
- const NodeDef* other = node_map_->GetNode(mul->input(0));
- if (scale->op() == "Const" && scale->attr().at("dtype").type() ==
- weights->attr().at("dtype").type()) {
- const TensorProto& scale_tensor = scale->attr().at("value").tensor();
- // Test whether `scale` is a scalar.
- if (scale_tensor.has_tensor_shape() &&
- scale_tensor.tensor_shape().dim_size() == 0) {
- // Create new node `scaled_weights`.
- NodeDef* scaled_weights = AddNode(
- *weights, StrCat("scaled_", conv->name()), /*copy_node=*/false);
- scaled_weights->set_op("Mul");
- scaled_weights->set_device(weights->device());
- (*scaled_weights->mutable_attr())["T"] =
- weights->attr().at("dtype");
- nodes_to_simplify->PushBack(scaled_weights);
-
- // Link in its inputs.
- scaled_weights->add_input(conv->input(1));
- node_map_->AddOutput(weights->name(), scaled_weights->name());
- scaled_weights->add_input(mul->input(1));
- node_map_->AddOutput(scale->name(), scaled_weights->name());
- ForwardControlDependencies(scaled_weights, {source});
-
- // Update `conv`'s weights to `scaled_weights`.
- conv->set_input(1, scaled_weights->name());
- node_map_->UpdateInput(conv->name(), weights->name(),
- scaled_weights->name());
- nodes_to_simplify->PushBack(conv);
-
- // Update `mul`'s consumer to bypass `mul` because it's folded to
- // the weights.
- CHECK_EQ(node_map_->GetOutputs(mul->name()).size(), 1);
- NodeDef* consumer_of_mul =
- *node_map_->GetOutputs(mul->name()).begin();
- consumer_of_mul->set_input(0, mul->input(0));
- node_map_->UpdateInput(consumer_of_mul->name(), mul->name(),
- other->name());
- nodes_to_simplify->PushBack(consumer_of_mul);
- return conv->name();
- }
- }
- }
- }
- }
-
- if (node->op() == "Mul" && node->input(0) == node->input(1) &&
- !OptimizedNodeExists(*node, "square")) {
- const DataType type = GetDataTypeFromAttr(*node, "T");
- bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
- string dontcare;
- string device;
- bool is_on_cpu =
- DeviceNameUtils::SplitDeviceName(node->device(), &dontcare, &device) &&
- str_util::StrContains(device, DEVICE_CPU);
- if (!is_complex || is_on_cpu) {
- NodeDef* new_square_node = AddNode(*node, "square", /*copy_node=*/true);
- new_square_node->set_op("Square");
- for (int i = 1; i < new_square_node->input_size(); ++i) {
- new_square_node->set_input(i - 1, new_square_node->input(i));
- }
- new_square_node->mutable_input()->RemoveLast();
- for (const string& input : new_square_node->input()) {
- node_map_->AddOutput(NodeName(input), new_square_node->name());
- }
- return new_square_node->name();
- }
- }
-
- if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) {
- // Discard aggregate nodes with a single input and no control dependencies.
- if (node->input_size() == 1) {
- return node->input(0);
- }
-
- // Try to rewrite aggregations of N >= 2 identical terms (possibly due
- // to deduping or other rewrites) so we can get rid of the sum entirely.
- // The expression (using AddN as an example of an aggregate op):
- // AddN(x, x, x, ... ,x)
- // <-- N terms -->
- // can be rewritten to
- // Mul(Const(N), x))
- //
- bool all_equal = true;
- int num_inputs = 1;
- for (int i = 1; i < node->input_size(); ++i) {
- if (IsControlInput(node->input(i))) {
- break;
- }
- ++num_inputs;
- if (node->input(i) != node->input(0)) {
- all_equal = false;
- break;
- }
- }
- if (all_equal && !OptimizedNodeExists(*node, "const") &&
- !OptimizedNodeExists(*node, "mul")) {
- // 1. Create constant node with value N.
- const auto type = GetDataTypeFromAttr(*node, "T");
- Tensor t(type, TensorShape({}));
- Status status = SetTensorValue(type, num_inputs, &t);
- if (!status.ok()) {
- LOG(WARNING) << "Failed to create const node: "
- << status.error_message();
- return "";
- }
- TensorValue value(&t);
- NodeDef* new_const_node = AddNode(*node, "const", /*copy_node=*/false);
- status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
- new_const_node);
- if (!status.ok()) {
- LOG(WARNING) << "Failed to create const node: "
- << status.error_message();
- return "";
- }
- new_const_node->set_device(node->device());
- MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
- optimized_graph_, node_map_.get());
- nodes_to_simplify->PushBack(new_const_node);
-
- // 2. Replace the aggregate node with Mul(Const(N), x).
- NodeDef* new_mul_node = AddNode(*node, "mul", /*copy_node=*/false);
- new_mul_node->set_op("Mul");
- new_mul_node->set_device(node->device());
- SetDataTypeToAttr(type, "T", new_mul_node);
- new_mul_node->add_input(new_const_node->name());
- node_map_->AddOutput(new_const_node->name(), new_mul_node->name());
- new_mul_node->add_input(node->input(0));
- node_map_->AddOutput(node->input(0), new_mul_node->name());
-
- ForwardControlDependencies(new_mul_node, {node});
- return new_mul_node->name();
- }
- }
-
- // Fold Transpose into matrix multiplication.
- if ((node->op() == "MatMul" || node->op() == "SparseMatMul" ||
- node->op() == "BatchMatMul") &&
- !OptimizedNodeExists(*node, "fused")) {
- const NodeDef* a = node_map_->GetNode(node->input(0));
- const NodeDef* b = node_map_->GetNode(node->input(1));
- bool is_complex = false;
- if (node->op() != "SparseMatMul") {
- const DataType type = GetDataTypeFromAttr(*node, "T");
- is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
- }
- const std::set<string> foldable_transpose_ops =
- !is_complex ? std::set<string>{"ConjugateTranspose", "Transpose"}
- : (node->op() == "BatchMatMul"
- ? std::set<string>{"ConjugateTranspose"}
- : std::set<string>{"Transpose"});
- const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
- IsInnerMatrixTransposeNode(*a, node_map_.get());
- const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
- IsInnerMatrixTransposeNode(*b, node_map_.get());
- if (a_is_foldable || b_is_foldable) {
- NodeDef* new_op = AddNode(*node, "fused", /*copy_node=*/true);
- if (a_is_foldable) {
- const string attr_a =
- node->op() == "BatchMatMul" ? "adj_x" : "transpose_a";
- FlipBooleanAttr(attr_a, new_op);
- new_op->set_input(0, a->input(0));
- node_map_->UpdateInput(new_op->name(), a->name(), a->input(0));
- }
- if (b_is_foldable) {
- const string attr_b =
- node->op() == "BatchMatMul" ? "adj_y" : "transpose_b";
- FlipBooleanAttr(attr_b, new_op);
- new_op->set_input(1, b->input(0));
- node_map_->UpdateInput(new_op->name(), b->name(), b->input(0));
- }
- std::vector<const NodeDef*> deps_to_forward({node});
- if (a_is_foldable) {
- deps_to_forward.push_back(a);
- }
- if (b_is_foldable) {
- deps_to_forward.push_back(b);
- }
- ForwardControlDependencies(new_op, deps_to_forward);
- }
- }
-
- // Fold Conj into Transpose or ConjugateTranspose.
- if ((node->op() == "Conj" || node->op() == "Transpose" ||
- node->op() == "ConjugateTranspose") &&
- !OptimizedNodeExists(*node, "fused")) {
- const NodeDef* input = node_map_->GetNode(node->input(0));
- const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
- const NodeDef* conj_op = node->op() == "Conj" ? node : input;
-
- if ((transpose_op->op() == "Transpose" ||
- transpose_op->op() == "ConjugateTranspose") &&
- conj_op->op() == "Conj") {
- NodeDef* new_op =
- AddNode(OptimizedNodeName(*node, "fused"), transpose_op);
- // Flip the type of transpose op to absorb the conjugation.
- new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
- : "Transpose");
- new_op->set_input(0, input->input(0));
- node_map_->UpdateInput(new_op->name(), node->name(), input->input(0));
- ForwardControlDependencies(new_op, {node, input});
- return new_op->name();
- }
- }
-
- return "";
-}
-
Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
SetVector<NodeDef*> nodes_to_simplify;
nodes_to_simplify.Reserve(optimized_graph_->node_size());
@@ -2471,7 +3043,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
}
const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
- graph_properties_.get(), node_map_.get());
+ graph_properties_.get(), node_map_.get(),
+ opt_level_);
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
// Stop pipeline after first stage returning non-empty simplified tensor name.
@@ -2480,6 +3053,12 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.combine_add_to_addn && can_use_shapes)
pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
+ if (options_.fold_conjugate_into_transpose)
+ pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext);
+ if (options_.fold_multiply_into_conv)
+ pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext);
+ if (options_.fold_transpose_into_matmul)
+ pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext);
if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes)
pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
if (options_.minimize_broadcasts && can_use_shapes)
@@ -2496,16 +3075,27 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
+ if (options_.replace_mul_with_square)
+ pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
if (options_.remove_logical_not)
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
if (options_.reorder_cast_and_transpose)
pipeline.AddStage<ReorderCastAndTranspose>(ctx, ctx_ext);
+ if (options_.simplify_aggregation)
+ pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
if (options_.hoist_cwise_unary_chains)
pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
if (options_.convert_sqrt_div_to_rsqrt_mul)
pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
if (options_.remove_idempotent)
pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
+ if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
+ if (options_.convert_log1p)
+ pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
+ if (options_.optimize_max_or_min_of_monotonic)
+ pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
+ if (options_.unary_ops_composition)
+ pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
@@ -2513,19 +3103,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
while (!nodes_to_simplify.Empty()) {
NodeDef* node = nodes_to_simplify.PopBack();
- // TODO(ezhulenev): move all rewrites into separate stages
string simplified_tensor = "";
- if (options_.enable_try_simplify_and_replace) {
- simplified_tensor = TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
- }
+ bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);
- // if it was not simplified try to run it through all configured stages
- if (!stop(simplified_tensor)) {
- bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);
- if (!optimized) {
- continue;
- }
- }
+ // If the node was not optimized by any of the stages, go to the next one.
+ if (!optimized) continue;
// re-wire consumers of an old node to the new one
if (NodeName(simplified_tensor) != node->name()) {