author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-10 11:09:37 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-04-10 11:12:23 -0700
commit    afa17984849881f39fb56c6e3500d866539924d5 (patch)
tree      fc1d4fddfcde926ad23f2fede369c30715973d47 /tensorflow
parent    c276b8314cd3161c5626d845edcfb6697cefd043 (diff)
Adds support for hoisting out a common denominator in arithmetic_optimizer
PiperOrigin-RevId: 192314177
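For intuition: the new rewrite relies on the identity y1/x + y2/x == (y1 + y2)/x,
which holds for floating-point division but not for integer division (with int
operands, 1/2 + 1/2 == 0 while (1 + 1)/2 == 1); hence the float/double guard in
the code below.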
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc      | 103
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc |  85
2 files changed, 161 insertions(+), 27 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index fa0f7c1c6e..463c332858 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -695,15 +695,20 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
}
};
-// Use the commutativity and (left- and right-) distributive property of
-// multiplication over addition to hoist common factors out of aggregate nodes
-// where all the inputs are Mul nodes. This pattern occurs frequently in
-// regularization terms for the gradients during training.
+// Use the distributive property of multiplication and division over addition,
+// along with commutativity of the former, to hoist common factors/denominators
+// out of aggregate nodes where ALL the inputs are Mul/Div nodes.
+// This pattern occurs frequently in regularization terms for the gradients
+// during training.
//
// For example, we can rewrite an expression of the form:
// AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
// to the following:
// Mul(x, AddN(y1, y2, y3, ... yn))
+// For division, we can rewrite
+// AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
+// to:
+// Div(AddN(y1, y2, y3, ... yn), x)
class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
public:
explicit HoistCommonFactorOutOfAggregation(
@@ -720,9 +725,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
+ bool common_factor_is_denominator = false;
std::set<string> common_factors;
std::vector<string> ctrl_deps;
- TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors, &ctrl_deps));
+ TF_RETURN_IF_ERROR(GetCommonFactors(
+ node, &common_factors, &common_factor_is_denominator, &ctrl_deps));
if (common_factors.size() == 1) {
const string& common_factor = *common_factors.begin();
@@ -730,24 +737,31 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
// Gather up the non-shared factors
bool shapes_match = true;
std::vector<string> unique_factors;
- TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor, &shapes_match,
- &unique_factors));
+ TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
+ common_factor_is_denominator,
+ &shapes_match, &unique_factors));
if (shapes_match) {
NodeDef* input_0;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));
- // Use a copy of the first Mul node for the outer multiplication.
- NodeDef* new_mul_node = AddCopyNode(OuterMulNodeName(node), input_0);
+ // Use a copy of the first node for the outer multiplication/division.
+ NodeDef* new_outer_node = AddCopyNode(
+ OuterNodeName(node, common_factor_is_denominator), input_0);
// And a copy of the aggregation node as one of the inner operands
NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);
- new_mul_node->set_device(node->device());
- new_mul_node->set_input(0, common_factor);
- new_mul_node->set_input(1, new_add_node->name());
+ new_outer_node->set_device(node->device());
+ if (common_factor_is_denominator) {
+ new_outer_node->set_input(0, new_add_node->name());
+ new_outer_node->set_input(1, common_factor);
+ } else {
+ new_outer_node->set_input(0, common_factor);
+ new_outer_node->set_input(1, new_add_node->name());
+ }
- ctx_.node_map->AddOutput(common_factor, new_mul_node->name());
- ctx_.node_map->AddOutput(new_add_node->name(), new_mul_node->name());
+ ctx_.node_map->AddOutput(common_factor, new_outer_node->name());
+ ctx_.node_map->AddOutput(new_add_node->name(), new_outer_node->name());
// Hoist non-shared factors up into the new AddN node.
for (int i = 0; i < unique_factors.size(); ++i) {
@@ -766,17 +780,18 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
AddToOptimizationQueue(new_add_node);
// do not optimize the same node twice
rewritten_nodes_.insert(node->name());
- *simplified_node_name = new_mul_node->name();
+ *simplified_node_name = new_outer_node->name();
}
}
return Status::OK();
}
private:
- // Get a name for new outer Mul node
- string OuterMulNodeName(const NodeDef* node) const {
+ // Get a name for the new outer node
+ string OuterNodeName(const NodeDef* node, bool is_div) const {
auto scope_and_name = ParseNodeScopeAndName(node->name());
- return OptimizedNodeName(scope_and_name, "Mul");
+ return is_div ? OptimizedNodeName(scope_and_name, "Div")
+ : OptimizedNodeName(scope_and_name, "Mul");
}
// Get a name for the new inner Add node
@@ -785,11 +800,17 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
return OptimizedNodeName(scope_and_name, "Add");
}
- // Determine the set of common factors if the input nodes are all Mul nodes.
+ // Determine the set of common factors if the input nodes are all Mul or
+ // Div nodes.
Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
+ bool* common_factor_is_denominator,
std::vector<string>* ctrl_deps) const {
CHECK(common_factors->empty());
+ CHECK_NOTNULL(common_factor_is_denominator);
+ *common_factor_is_denominator = false;
+ bool has_mul = false;
+ bool has_div = false;
for (int i = 0; i < node->input_size(); ++i) {
if (i > 0 && common_factors->empty()) break;
if (IsControlInput(node->input(i))) {
@@ -799,12 +820,36 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
- if (!IsMul(*input)) {
+ if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
+ (IsAnyDiv(*input) && has_mul)) {
+ // Break if the input is neither a Mul nor a Div, or if there are both
+ // Mul and Div ops among the inputs.
common_factors->clear();
break;
+ } else if (IsAnyDiv(*input)) {
+ has_div = true;
+ // In the case of a possible common denominator, we avoid hoisting if the
+ // inputs are not float/double, since integer division is not
+ // distributive over addition.
+ OpInfo::TensorProperties properties0, properties1;
+ TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
+ TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
+ if (properties0.dtype() != DT_FLOAT &&
+ properties0.dtype() != DT_DOUBLE &&
+ properties1.dtype() != DT_FLOAT &&
+ properties1.dtype() != DT_DOUBLE) {
+ common_factors->clear();
+ break;
+ }
+ } else if (IsMul(*input)) {
+ has_mul = true;
}
- std::set<string> factors_i{input->input(0), input->input(1)};
+ // If any op is a Div, we only consider common factors coming from the
+ // denominators.
+ std::set<string> factors_i =
+ has_mul ? std::set<string>{input->input(0), input->input(1)}
+ : std::set<string>{input->input(1)};
if (i == 0) {
std::swap(*common_factors, factors_i);
} else {
@@ -819,6 +864,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
ctrl_deps->push_back(input->input(i));
}
}
+
+ *common_factor_is_denominator = has_div;
return Status::OK();
}
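In isolation, the intersection logic above amounts to the following minimal
standalone sketch (illustrative names, not the optimizer's API):

#include <algorithm>
#include <iterator>
#include <set>
#include <string>
#include <vector>

// Start from the first input's candidate factors, then repeatedly intersect
// with the next input's candidates; stop early once the intersection is empty.
std::set<std::string> IntersectFactors(
    const std::vector<std::set<std::string>>& factors_per_input) {
  std::set<std::string> common;
  for (size_t i = 0; i < factors_per_input.size(); ++i) {
    if (i == 0) {
      common = factors_per_input[0];
    } else {
      std::set<std::string> intersection;
      std::set_intersection(
          common.begin(), common.end(), factors_per_input[i].begin(),
          factors_per_input[i].end(),
          std::inserter(intersection, intersection.begin()));
      std::swap(common, intersection);
    }
    if (common.empty()) break;
  }
  return common;
}

For a Mul input both operands are candidates; for a Div input only input(1)
(the denominator) is, since division is not commutative.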
@@ -827,6 +874,7 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
// have the same shape since the other aggregation ops do not support
// broadcasting.
Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
+ const bool common_factor_is_denominator,
bool* shapes_match,
std::vector<string>* unique_factors) const {
*shapes_match = true;
@@ -837,11 +885,13 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
if (IsControlInput(input)) {
break;
}
- NodeDef* mul_node;
- TF_RETURN_IF_ERROR(GetInputNode(input, &mul_node));
+ NodeDef* inner_node;
+ TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
const int unique_factor_index =
- mul_node->input(0) == common_factor ? 1 : 0;
- unique_factors->push_back(mul_node->input(unique_factor_index));
+ common_factor_is_denominator
+ ? 0
+ : (inner_node->input(0) == common_factor ? 1 : 0);
+ unique_factors->push_back(inner_node->input(unique_factor_index));
if (i > 0 && !IsAdd(*node)) {
OpInfo::TensorProperties lhs;
OpInfo::TensorProperties rhs;
@@ -857,7 +907,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
// If graph rewrite happens in multiple passes without graph pruning between
// them, it's possible that the rewritten node already exists in the graph.
return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
- ctx_.node_map->NodeExists(OuterMulNodeName(node));
+ ctx_.node_map->NodeExists(OuterNodeName(node, false)) ||
+ ctx_.node_map->NodeExists(OuterNodeName(node, true));
}
// keep names of the nodes that were optimized by this stage
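Two details in the stage above are easy to miss: Div is not commutative, so
when the common factor is a denominator the hoisted outer node must keep it in
the second input slot (set_input(1, common_factor)), while for Mul either slot
would do; and IsRewritten now checks both the "Mul" and the "Div" optimized
node names, so a node cannot be rewritten twice across passes regardless of
which form an earlier pass produced.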
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 9677175d2e..e639812858 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -31,6 +31,9 @@ namespace grappler {
namespace {
+constexpr char kHoistFactorOptimizerDiv[] =
+ "ArithmeticOptimizer/HoistCommonFactor_Div_";
+
constexpr char kHoistFactorOptimizerMul[] =
"ArithmeticOptimizer/HoistCommonFactor_Mul_";
@@ -42,6 +45,11 @@ string HoistMulName(const string& name) {
return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
}
+// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation
+string HoistDivName(const string& name) {
+ return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
+}
+
// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation
string HoistAddName(const string& name) {
return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
@@ -558,7 +566,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
EXPECT_EQ("^Placeholder", add_1_const_node->input(0));
}
-TEST_F(ArithmeticOptimizerTest, HoistFactor) {
+TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
for (bool matching_shapes : {true, false}) {
for (bool use_addn : {true, false}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -625,6 +633,81 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
}
}
+TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
+ for (bool matching_shapes : {true, false}) {
+ for (bool use_addn : {true, false}) {
+ for (bool use_ints : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = use_ints
+ ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
+ : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output y1 = use_ints
+ ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
+ : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
+ Output y2;
+ if (matching_shapes) {
+ y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
+ : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
+ } else {
+ y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
+ : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
+ }
+ Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
+ Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
+ Output id =
+ use_addn
+ ? ops::Identity(s.WithOpName("id"),
+ ops::AddN(s.WithOpName("add"), {div1, div2}))
+ : ops::Identity(s.WithOpName("id"),
+ ops::Add(s.WithOpName("add"), div1, div2));
+
+ GrapplerItem item;
+ item.fetch = {"id"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ArithmeticOptimizer optimizer;
+ EnableOnlyHoistCommonFactor(&optimizer);
+
+ GraphDef output;
+ OptimizeTwice(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ //      Add                 Div
+ //     /   \               /   \
+ //   Div   Div     ->    Add    x
+ //   / \   / \           /  \
+ //  y1  x y2  x         y1   y2
+ //
+ // If "root" op is AddN and shapes does not match, this rewrite is not
+ // possible and graph should stay intact.
+ NodeMap node_map(&output);
+
+ if ((use_addn && !matching_shapes) || use_ints) {
+ VerifyGraphsMatch(item.graph, output, __LINE__);
+ } else {
+ EXPECT_EQ(9, output.node_size());
+
+ const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
+ ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
+ EXPECT_EQ("y1", new_add_node->input(0));
+ EXPECT_EQ("y2", new_add_node->input(1));
+
+ const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
+ ASSERT_TRUE(new_div_node != nullptr) << "Hoisted Div node not found";
+ EXPECT_EQ(new_add_node->name(), new_div_node->input(0));
+ EXPECT_EQ("x", new_div_node->input(1));
+
+ const NodeDef* id_node = node_map.GetNode("id");
+ ASSERT_TRUE(id_node != nullptr) << "Id node not found";
+ EXPECT_EQ("id", id_node->name());
+ EXPECT_EQ(HoistDivName("add"), id_node->input(0));
+ }
+ }
+ }
+ }
+}
+
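In the positive case, the expected count of 9 nodes presumably breaks down as
the seven original nodes (x, y1, y2, div1, div2, add, id) plus the two hoisted
copies (the new inner Add and outer Div); the now-dead div1, div2 and add nodes
stay in the graph because this stage does not prune.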
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});