author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-11-02 16:32:47 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-11-02 16:40:06 -0700
commit    6ace5e0494d8142dc67ca0714893afc716125917 (patch)
tree      21bf67f21d8318b66b2cfea4cc65d83e3cc9b66b
parent    3a8eaaf6a238e238a7adac9886b1569d7e43ae23 (diff)
* Add an optimization that hoists a common factor out of sums of products involving aggregate ops (AddN, Add, Accumulate), or eliminates the aggregation op entirely.
* Replace trivial aggregations of the form x+x+...+x with Const(N)*x for N > 1.

PiperOrigin-RevId: 174398543
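To make the two rewrites concrete, here is a minimal sketch (not part of the patch; the function name RunBothRewrites and the node names are illustrative) that builds both patterns with the C++ ops API and runs them through the optimizer, in the same style as the new tests in arithmetic_optimizer_test.cc:

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"

using tensorflow::ops::Add;
using tensorflow::ops::Const;
using tensorflow::ops::Identity;
using tensorflow::ops::Mul;

void RunBothRewrites() {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y1 = Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
  auto y2 = Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});

  // Trivial aggregation: Add(x, x) is rewritten to Mul(Const(2), x).
  auto trivial = Add(s.WithOpName("trivial"), x, x);
  auto id1 = Identity(s.WithOpName("id1"), trivial);

  // Common factor: Add(Mul(x, y1), Mul(y2, x)) is rewritten to
  // Mul(x, Add(y1, y2)); the shared factor may sit on either side of each Mul.
  auto hoistable = Add(s.WithOpName("hoistable"),
                       Mul(s.WithOpName("mul1"), x, y1),
                       Mul(s.WithOpName("mul2"), y2, x));
  auto id2 = Identity(s.WithOpName("id2"), hoistable);

  tensorflow::grappler::GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  tensorflow::grappler::ArithmeticOptimizer optimizer;
  tensorflow::GraphDef output;
  TF_CHECK_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));
}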
 tensorflow/core/grappler/optimizers/BUILD                        |   2
 tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc      | 247
 tensorflow/core/grappler/optimizers/arithmetic_optimizer.h       |   5
 tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc |  98
 tensorflow/core/grappler/optimizers/constant_folding.cc          |  22
 tensorflow/core/grappler/optimizers/constant_folding.h           |   7
 tensorflow/core/grappler/optimizers/layout_optimizer.cc          |   3
 tensorflow/core/grappler/utils/BUILD                             |   1
 tensorflow/core/grappler/utils/frame.cc                          |  28
 tensorflow/core/grappler/utils/frame.h                           |  14
 tensorflow/core/grappler/utils/frame_test.cc                     |  12
 tensorflow/python/debug/cli/analyzer_cli_test.py                 |   4
 tensorflow/python/debug/lib/session_debug_file_test.py           |   3
 tensorflow/python/debug/lib/session_debug_grpc_test.py           |   3
 tensorflow/python/debug/lib/session_debug_testlib.py             |   5
 tensorflow/python/debug/lib/stepper_test.py                      |   3
 16 files changed, 390 insertions(+), 67 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 681d26e262..669d02815c 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -161,6 +161,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":constant_folding",
":graph_optimizer",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -170,6 +171,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:frame",
],
)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 78b55237d1..445e5cf972 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -14,8 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
+
+#include <algorithm>
+#include <limits>
#include <unordered_map>
#include <unordered_set>
+
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
@@ -23,6 +27,9 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/utils/frame.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -31,6 +38,45 @@ namespace tensorflow {
namespace grappler {
namespace {
+template <typename T>
+bool SafeSetTensorValue(double value, Tensor* tensor) {
+ using RealType = typename Eigen::NumTraits<T>::Real;
+ if (value > std::numeric_limits<RealType>::max() ||
+ value < std::numeric_limits<RealType>::lowest()) {
+ return false;
+ }
+ tensor->flat<T>()(0) = static_cast<T>(value);
+ return true;
+}
+
+#define HANDLE_CASE(DTYPE) \
+ case DTYPE: \
+ if (!SafeSetTensorValue<EnumToDataType<DTYPE>::Type>( \
+ static_cast<double>(value), tensor)) { \
+ return errors::InvalidArgument("Cannot store value ", value, \
+ " in tensor of type " #DTYPE); \
+ } \
+ break
+
+Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
+ switch (dtype) {
+ // HANDLE_CASE(DT_HALF);
+ HANDLE_CASE(DT_FLOAT);
+ HANDLE_CASE(DT_DOUBLE);
+ HANDLE_CASE(DT_UINT8);
+ HANDLE_CASE(DT_INT8);
+ HANDLE_CASE(DT_UINT16);
+ HANDLE_CASE(DT_INT16);
+ HANDLE_CASE(DT_INT32);
+ HANDLE_CASE(DT_INT64);
+ HANDLE_CASE(DT_COMPLEX64);
+ HANDLE_CASE(DT_COMPLEX128);
+ default:
+ return errors::InvalidArgument("Unexpected type ", DataTypeString(dtype));
+ }
+ return Status::OK();
+}
+
static bool IsInvolution(const NodeDef& node) {
const std::unordered_set<string> involution_ops = {"Conj", "Reciprocal",
"Neg", "LogicalNot"};
@@ -107,14 +153,28 @@ DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {
return attr.type();
}
-bool IsCommutative(const OpDef& op, const NodeDef& input1) {
- if (op.name() == "Add") {
+bool IsCommutative(const NodeDef& node) {
+ if (node.op() == "Add" && node.input_size() > 0) {
// Workaround for "Add" not being marked is_commutative and is_aggregate.
// (See cl/173915048).
- const auto type = GetDataTypeFromAttr(input1, "T");
+ const auto type = GetDataTypeFromAttr(node, "T");
return type != DT_INVALID && type != DT_STRING;
}
- return op.is_commutative();
+ const OpDef* op_def = nullptr;
+ const Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
+ return status.ok() && op_def->is_commutative();
+}
+
+bool IsAggregate(const NodeDef& node) {
+ if (node.op() == "Add" && node.input_size() > 0) {
+ // Workaround for "Add" not being marked is_commutative and is_aggregate.
+ // (See cl/173915048).
+ const auto type = GetDataTypeFromAttr(node, "T");
+ return type != DT_INVALID && type != DT_STRING;
+ }
+ const OpDef* op_def = nullptr;
+ const Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
+ return status.ok() && op_def->is_aggregate();
}
void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
@@ -208,6 +268,30 @@ bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
return true;
}
+// Fix frame dependencies by adding control dependencies from
+// source_for_ctrl_dep to the nodes in sinks_for_control_dep, and update
+// frame_map for all nodes in new_nodes.
+void AddFrameControlDeps(const NodeDef* old_node,
+ const std::vector<NodeDef*>& new_nodes,
+ const string& source_for_ctrl_dep,
+ const std::vector<NodeDef*>& sinks_for_control_dep,
+ GraphDef* graph, NodeMap* node_map,
+ FrameMap* frame_map) {
+ const auto frame_it = frame_map->find(old_node);
+ if (frame_it != frame_map->end()) {
+ for (auto node : new_nodes) {
+ frame_map->emplace(node, frame_it->second);
+ }
+ if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
+ const string ctrl_dep = ConstantFolding::AddControlDependency(
+ source_for_ctrl_dep, graph, node_map);
+ for (auto node : sinks_for_control_dep) {
+ node->add_input(ctrl_dep);
+ }
+ }
+ }
+}
+
} // namespace
class UniqueNodes {
@@ -264,10 +348,7 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
}
// Compare inputs.
- const OpDef* op_def = nullptr;
- Status status = OpRegistry::Global()->LookUpOpDef(node1.op(), &op_def);
- const bool is_commutative = status.ok() && IsCommutative(*op_def, node1);
- if (is_commutative) {
+ if (IsCommutative(node1)) {
std::vector<string> inputs1(node1.input().begin(), node1.input().end());
std::vector<string> inputs2(node2.input().begin(), node2.input().end());
std::sort(inputs1.begin(), inputs1.end());
@@ -399,7 +480,7 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const {
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* node, GraphDef* graph_def, NodeMap* node_map,
- std::vector<const NodeDef*>* new_nodes) const {
+ std::vector<const NodeDef*>* new_nodes, FrameMap* frame_map) const {
// Remove involutions applied twice.
if (IsInvolution(*node)) {
// An involution is a function f(x) that is its own inverse,
@@ -519,6 +600,11 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
new_nodes->push_back(new_transpose);
new_nodes->push_back(new_cast);
+ // Add frame dependencies that the original node might have had.
+ AddFrameControlDeps(node, {new_transpose, new_cast},
+ new_transpose->input(0), {new_transpose},
+ graph_def, node_map, frame_map);
+
return new_cast->name();
}
}
@@ -625,6 +711,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
node_map->AddOutput(weights->name(), scaled_weights->name());
scaled_weights->add_input(mul->input(1));
node_map->AddOutput(scale->name(), scaled_weights->name());
+ AddFrameControlDeps(node, {scaled_weights}, "", {}, graph_def,
+ node_map, frame_map);
// Update `conv`'s weights to `scaled_weights`.
conv->set_input(1, scaled_weights->name());
@@ -648,6 +736,134 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
}
+ if (node->input_size() > 0 && IsAggregate(*node) &&
+ !node_map->GetOutputs(node->name()).empty()) {
+ // Discard aggregate nodes with a single input.
+ if (node->input_size() == 1) {
+ return node->input(0);
+ }
+
+ // Try to rewrite aggregations of N >= 2 identical terms (possibly due
+ // to deduping or other rewrites) so we can get rid of the sum entirely.
+ // The expression (using AddN as an example of an aggregate op):
+ // AddN(x, x, x, ..., x)
+ //      <-- N terms -->
+ // can be rewritten to
+ // Mul(Const(N), x)
+ //
+ bool all_equal = true;
+ for (int i = 1; i < node->input_size(); ++i) {
+ if (node->input(i) != node->input(0)) {
+ all_equal = false;
+ break;
+ }
+ }
+ if (all_equal) {
+ // 1. Create constant node with value N.
+ const int N = node->input_size();
+ const auto type = GetDataTypeFromAttr(*node, "T");
+ Tensor t(type, TensorShape({}));
+ Status status = SetTensorValue(type, N, &t);
+ if (!status.ok()) {
+ LOG(WARNING) << "Failed to create const node: "
+ << status.error_message();
+ return "";
+ }
+ TensorValue value(&t);
+ NodeDef* new_const_node = graph_def->add_node();
+ *new_const_node =
+ ConstantFolding::CreateNodeDef(node->name() + "_const", value);
+ new_const_node->set_device(node->device());
+ node_map->AddNode(new_const_node->name(), new_const_node);
+ new_nodes->push_back(new_const_node);
+
+ // 2. Replace the aggregate node with Mul(Const(N), x).
+ NodeDef* new_mul_node = graph_def->add_node();
+ new_mul_node->set_name(node->name() + "_mul");
+ new_mul_node->set_op("Mul");
+ new_mul_node->set_device(node->device());
+ SetDataTypeToAttr(type, "T", new_mul_node);
+ node_map->AddNode(new_mul_node->name(), new_mul_node);
+ new_nodes->push_back(new_mul_node);
+ new_mul_node->add_input(new_const_node->name());
+ node_map->AddOutput(new_const_node->name(), new_mul_node->name());
+ new_mul_node->add_input(node->input(0));
+ node_map->AddOutput(node->input(0), new_mul_node->name());
+
+ AddFrameControlDeps(node, {new_const_node, new_mul_node}, node->input(0),
+ {new_const_node}, graph_def, node_map, frame_map);
+ return new_mul_node->name();
+ }
+ }
+
+ // Use the commutativity and (left- and right-) distributive property of
+ // multiplication over addition to hoist common factors out of aggregate nodes
+ // where all the inputs are Mul nodes. This pattern occurs frequently in
+ // regularization terms for the gradients during training.
+ if (node->input_size() > 1 && IsAggregate(*node) &&
+ !node_map->GetOutputs(node->name()).empty()) {
+ // Determine the set of common factors if the input nodes are all Mul nodes.
+ std::set<string> common_factors;
+ int i = 0;
+ while (i < node->input_size() && (i == 0 || !common_factors.empty())) {
+ const NodeDef* input = node_map->GetNode(node->input(i));
+ if (input->op() == "Mul") {
+ std::set<string> factors_i{input->input(0), input->input(1)};
+ if (i == 0) {
+ std::swap(common_factors, factors_i);
+ } else {
+ std::set<string> intersection;
+ std::set_intersection(
+ factors_i.begin(), factors_i.end(), common_factors.begin(),
+ common_factors.end(),
+ std::inserter(intersection, intersection.begin()));
+ std::swap(common_factors, intersection);
+ }
+ } else {
+ common_factors.clear();
+ break;
+ }
+ ++i;
+ }
+ if (common_factors.size() == 1) {
+ // In this case we have an expression of the form
+ // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
+ // that can be rewritten as
+ // Mul(x, AddN(y1, y2, y3, ... yn))
+ // 1. Hoist non-shared factors up into AddN node.
+ const string& common_factor = *common_factors.begin();
+ NodeDef* new_mul_node = graph_def->add_node();
+ NodeDef* new_add_node = graph_def->add_node();
+ *new_add_node = *node;
+ new_add_node->set_name(node->name() + "_hoist");
+ new_nodes->push_back(new_add_node);
+ node_map->AddNode(new_add_node->name(), new_add_node);
+ for (int i = 0; i < node->input_size(); ++i) {
+ NodeDef* mul_node = node_map->GetNode(node->input(i));
+ int unique_factor_index = mul_node->input(0) == common_factor ? 1 : 0;
+ const string unique_factor = mul_node->input(unique_factor_index);
+ new_add_node->set_input(i, unique_factor);
+ // 2. Use a copy of the first Mul node for the outer multiplication.
+ if (i == 0) {
+ *new_mul_node = *mul_node;
+ new_mul_node->set_name(new_mul_node->name() + "_hoist");
+ new_mul_node->set_input(0, common_factor);
+ new_mul_node->set_input(1, new_add_node->name());
+ new_nodes->push_back(new_mul_node);
+ node_map->AddNode(new_mul_node->name(), new_mul_node);
+ }
+ }
+ // 3. Set the device of the new nodes to that of the common factor "x".
+ NodeDef* common_factor_node = node_map->GetNode(common_factor);
+ new_add_node->set_device(common_factor_node->device());
+ new_mul_node->set_device(common_factor_node->device());
+
+ // 4. Add frame dependencies that the original node might have had.
+ AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
+ {new_add_node}, graph_def, node_map, frame_map);
+ return new_mul_node->name();
+ }
+ }
return "";
}
@@ -681,9 +897,13 @@ class SetVector {
};
} // namespace
-void ArithmeticOptimizer::SimplifyArithmeticOps(
+Status ArithmeticOptimizer::SimplifyArithmeticOps(
GraphDef* optimized_graph) const {
NodeMap node_map(optimized_graph);
+ FrameMap frame_map;
+ int num_frames;
+ TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph, node_map,
+ &frame_map, &num_frames));
SetVector<const NodeDef*> nodes_to_simplify;
for (int i = 0; i < optimized_graph->node_size(); ++i) {
nodes_to_simplify.PushBack(optimized_graph->mutable_node()->Mutable(i));
@@ -691,8 +911,8 @@ void ArithmeticOptimizer::SimplifyArithmeticOps(
while (!nodes_to_simplify.Empty()) {
const NodeDef* node = nodes_to_simplify.PopBack();
std::vector<const NodeDef*> new_nodes;
- const string simplified_tensor =
- TrySimplifyAndReplaceUses(node, optimized_graph, &node_map, &new_nodes);
+ const string simplified_tensor = TrySimplifyAndReplaceUses(
+ node, optimized_graph, &node_map, &new_nodes, &frame_map);
if (simplified_tensor.empty()) {
continue;
}
@@ -730,6 +950,7 @@ void ArithmeticOptimizer::SimplifyArithmeticOps(
}
}
}
+ return Status::OK();
}
Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
@@ -743,7 +964,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
TF_RETURN_IF_ERROR(graph_properties.AnnotateOutputShapes(optimized_graph));
DedupComputations(optimized_graph);
- SimplifyArithmeticOps(optimized_graph);
+ TF_RETURN_IF_ERROR(SimplifyArithmeticOps(optimized_graph));
// Clear output shapes.
for (int i = 0; i < optimized_graph->node_size(); ++i) {
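The heart of the hoisting pass above is the running set intersection over each Mul's two inputs. Below is a self-contained sketch of just that factor-detection step, using plain strings in place of NodeDef inputs (all names here are illustrative, not part of the patch):

#include <algorithm>
#include <iostream>
#include <iterator>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Given the two inputs of each Mul feeding an aggregate node, compute the
// factors common to all of them; the rewrite fires only when exactly one
// common factor remains. Mirrors the std::set_intersection loop above.
std::set<std::string> CommonFactors(
    const std::vector<std::pair<std::string, std::string>>& mul_inputs) {
  std::set<std::string> common;
  for (size_t i = 0; i < mul_inputs.size(); ++i) {
    std::set<std::string> factors_i{mul_inputs[i].first,
                                    mul_inputs[i].second};
    if (i == 0) {
      common = std::move(factors_i);
    } else {
      std::set<std::string> intersection;
      std::set_intersection(factors_i.begin(), factors_i.end(),
                            common.begin(), common.end(),
                            std::inserter(intersection, intersection.begin()));
      common = std::move(intersection);
    }
    if (common.empty()) break;  // No factor shared by every term.
  }
  return common;
}

int main() {
  // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3)) -> common factor {"x"}.
  for (const auto& f : CommonFactors({{"x", "y1"}, {"y2", "x"}, {"x", "y3"}}))
    std::cout << f << "\n";  // prints: x
}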
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 53cec11ff6..4d2e160ff4 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -46,7 +46,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
void DedupComputations(GraphDef* optimized_graph) const;
// Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse
// transposes.
- void SimplifyArithmeticOps(GraphDef* optimized_graph) const;
+ Status SimplifyArithmeticOps(GraphDef* optimized_graph) const;
// Tries to simplify the expression that roots at `node` and replaces the uses
// of `node` to the simplified expression. Returns the name of the simplified
tensor (e.g. "split:1") or an empty string if no simplification is
@@ -64,7 +64,8 @@ class ArithmeticOptimizer : public GraphOptimizer {
// NodeDef.
string TrySimplifyAndReplaceUses(
const NodeDef* node, GraphDef* graph_def, NodeMap* node_map,
- std::vector<const NodeDef*>* new_nodes) const;
+ std::vector<const NodeDef*>* new_nodes,
+ std::unordered_map<const NodeDef*, std::vector<int>>* frame_map) const;
std::unordered_set<string> nodes_to_preserve_;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 61c8b82ea0..5c3fdd2553 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -58,7 +58,7 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2});
Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2});
- Output add = ops::Add(s.WithOpName("add"), c1, c2);
+ Output mul = ops::Mul(s.WithOpName("mul"), c1, c2);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -70,20 +70,20 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
EXPECT_EQ(2, output.node_size());
const NodeDef& new_c1 = output.node(0);
EXPECT_EQ("c1", new_c1.name());
- const NodeDef& new_add = output.node(1);
- EXPECT_EQ("add", new_add.name());
- EXPECT_EQ(2, new_add.input_size());
- EXPECT_EQ("c1", new_add.input(0));
- EXPECT_EQ("c1", new_add.input(1));
+ const NodeDef& new_mul = output.node(1);
+ EXPECT_EQ("mul", new_mul.name());
+ EXPECT_EQ(2, new_mul.input_size());
+ EXPECT_EQ("c1", new_mul.input(0));
+ EXPECT_EQ("c1", new_mul.input(1));
}
TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});
Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2});
- Output add1 = ops::Add(s.WithOpName("add1"), c1, c2);
- Output add2 = ops::Add(s.WithOpName("add2"), c2, c1);
- Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
+ Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2);
+ Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1);
+ Output mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -97,16 +97,16 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
EXPECT_EQ("c1", new_c1.name());
const NodeDef& new_c2 = output.node(1);
EXPECT_EQ("c2", new_c2.name());
- const NodeDef& new_add1 = output.node(2);
- EXPECT_EQ("add1", new_add1.name());
- EXPECT_EQ(2, new_add1.input_size());
- EXPECT_EQ("c1", new_add1.input(0));
- EXPECT_EQ("c2", new_add1.input(1));
- const NodeDef& new_add3 = output.node(3);
- EXPECT_EQ("add3", new_add3.name());
- EXPECT_EQ(2, new_add3.input_size());
- EXPECT_EQ("add1", new_add3.input(0));
- EXPECT_EQ("add1", new_add3.input(1));
+ const NodeDef& new_mul1 = output.node(2);
+ EXPECT_EQ("mul1", new_mul1.name());
+ EXPECT_EQ(2, new_mul1.input_size());
+ EXPECT_EQ("c1", new_mul1.input(0));
+ EXPECT_EQ("c2", new_mul1.input(1));
+ const NodeDef& new_mul3 = output.node(3);
+ EXPECT_EQ("mul3", new_mul3.name());
+ EXPECT_EQ(2, new_mul3.input_size());
+ EXPECT_EQ("mul1", new_mul3.input(0));
+ EXPECT_EQ("mul1", new_mul3.input(1));
}
TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
@@ -131,6 +131,66 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
EXPECT_EQ("c", output.node(5).input(0));
}
+TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output add = ops::Add(s.WithOpName("add"), x, x);
+ Output id = ops::Identity(s.WithOpName("id"), add);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ArithmeticOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ // VLOG(2) << output.DebugString();
+ EXPECT_EQ(5, output.node_size());
+ const NodeDef& new_const = output.node(3);
+ EXPECT_EQ("add_const", new_const.name());
+ const NodeDef& new_mul = output.node(4);
+ EXPECT_EQ("add_mul", new_mul.name());
+ EXPECT_EQ("add_const", new_mul.input(0));
+ EXPECT_EQ("x", new_mul.input(1));
+ const NodeDef& new_id = output.node(2);
+ EXPECT_EQ("id", new_id.name());
+ EXPECT_EQ("add_mul", new_id.input(0));
+}
+
+TEST_F(ArithmeticOptimizerTest, SimplifyHoistFactor) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
+ Output y2 = ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
+ Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
+ Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
+ Output add = ops::Add(s.WithOpName("add"), mul1, mul2);
+ Output id = ops::Identity(s.WithOpName("id"), add);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ArithmeticOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ LOG(INFO) << output.DebugString();
+ EXPECT_EQ(9, output.node_size());
+ const NodeDef& new_add = output.node(8);
+ EXPECT_EQ("add_hoist", new_add.name());
+ EXPECT_EQ("y1", new_add.input(0));
+ EXPECT_EQ("y2", new_add.input(1));
+ const NodeDef& new_mul = output.node(7);
+ EXPECT_EQ("mul1_hoist", new_mul.name());
+ EXPECT_EQ("x", new_mul.input(0));
+ EXPECT_EQ("add_hoist", new_mul.input(1));
+ const NodeDef& new_id = output.node(6);
+ EXPECT_EQ("id", new_id.name());
+ EXPECT_EQ("mul1_hoist", new_id.input(0));
+}
+
TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index ea03660440..e8ffff07c6 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -100,8 +100,11 @@ ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
resource_mgr_.reset(new ResourceMgr());
}
-string ConstantFolding::AddControlDependency(const string& input_name) {
- const NodeDef* node = node_map_->GetNode(input_name);
+// static
+string ConstantFolding::AddControlDependency(const string& input_name,
+ GraphDef* graph,
+ NodeMap* node_map) {
+ const NodeDef* node = node_map->GetNode(input_name);
if (!IsSwitch(*node)) {
return AsControlDependency(*node);
} else {
@@ -111,7 +114,7 @@ string ConstantFolding::AddControlDependency(const string& input_name) {
// dependency is only triggered when the corresponding output is triggered.
// We start by looking for an identity node connected to the output of the
// switch node, and use it to anchor the control dependency.
- auto outputs = node_map_->GetOutputs(node->name());
+ auto outputs = node_map->GetOutputs(node->name());
for (const NodeDef* node : outputs) {
if (IsIdentity(*node)) {
CHECK_EQ(1, node->input_size());
@@ -128,15 +131,15 @@ string ConstantFolding::AddControlDependency(const string& input_name) {
ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
const DataType output_type = node->attr().at("T").type();
- NodeDef* added_node = graph_.add_node();
+ NodeDef* added_node = graph->add_node();
added_node->set_name(ctrl_dep_name);
added_node->set_op("Identity");
added_node->set_device(node->device());
(*added_node->mutable_attr())["T"].set_type(output_type);
*added_node->add_input() = input_name;
- node_map_->AddNode(added_node->name(), added_node);
- node_map_->AddOutput(node->name(), added_node->name());
+ node_map->AddNode(added_node->name(), added_node);
+ node_map->AddOutput(node->name(), added_node->name());
return AsControlDependency(*added_node);
}
}
@@ -233,7 +236,8 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
// ensure that the constant value will only be run in the
// cases where the shape/rank/size would have been run in
// the original graph. Additional inputs are extra control
- string ctrl_dep = AddControlDependency(node.input(0));
+ string ctrl_dep =
+ AddControlDependency(node.input(0), &graph_, node_map_.get());
node.set_input(0, ctrl_dep);
node_map_->AddOutput(NodeName(ctrl_dep), node.name());
} else {
@@ -259,7 +263,8 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
// We add a control dependency to the original ShapeN node,
// so that the node will only be run if all inputs of the
// original ShapeN node are run.
- string ctrl_dep = AddControlDependency(node.name());
+ string ctrl_dep = AddControlDependency(node.name(), &graph_,
+ node_map_.get());
*added_node->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
}
@@ -370,6 +375,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
return true;
}
+// static
NodeDef ConstantFolding::CreateNodeDef(const string& name,
const TensorValue& tensor) {
NodeDef node;
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index b115e51dbf..30d778789a 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -32,6 +32,10 @@ const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
// Constant folding optimization for a graph.
class ConstantFolding : public GraphOptimizer {
public:
+ static NodeDef CreateNodeDef(const string& name, const TensorValue& tensor);
+ static string AddControlDependency(const string& input_name, GraphDef* graph,
+ NodeMap* node_map);
+
ConstantFolding(DeviceBase* cpu_device);
~ConstantFolding() override {}
@@ -45,14 +49,11 @@ class ConstantFolding : public GraphOptimizer {
const GraphDef& optimize_output, double result) override;
private:
- string AddControlDependency(const string& input_name);
Status MaterializeShapes(const GrapplerItem& item,
const GraphProperties& properties);
bool IsFoldable(const NodeDef& node) const;
- NodeDef CreateNodeDef(const string& name, const TensorValue& tensor);
-
Status EvaluateNode(const NodeDef& node,
const gtl::InlinedVector<TensorValue, 4>& inputs,
gtl::InlinedVector<TensorValue, 4>* output) const;
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index d7d7218319..1ca296da0a 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -1233,7 +1233,8 @@ class DataLayoutOptimizer : GraphProcessor {
Status Expand() {
int node_size_original = graph_->node_size();
std::unordered_map<const NodeDef*, std::vector<int>> frames;
- IdentifyFrames(*graph_, &frames);
+ int num_frames;
+ TF_RETURN_IF_ERROR(IdentifyFrames(*graph_, &frames, &num_frames));
// This is the first pass where we expand the nodes which support NCHW.
std::set<string> ops_format_supported = GetOpsFormatSupported();
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index bb161bf9a4..21243833ac 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -78,6 +78,7 @@ cc_library(
hdrs = ["frame.h"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:op_types",
diff --git a/tensorflow/core/grappler/utils/frame.cc b/tensorflow/core/grappler/utils/frame.cc
index 7655d0bee5..df5f4ff7cf 100644
--- a/tensorflow/core/grappler/utils/frame.cc
+++ b/tensorflow/core/grappler/utils/frame.cc
@@ -20,27 +20,32 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace grappler {
-int IdentifyFrames(
- const GraphDef& graph,
- std::unordered_map<const NodeDef*, std::vector<int>>* frames) {
+Status IdentifyFrames(const GraphDef& graph, FrameMap* frame_map,
+ int* num_frames) {
NodeMap node_map(const_cast<GraphDef*>(&graph));
+ return IdentifyFramesWithNodeMap(graph, node_map, frame_map, num_frames);
+}
+
+Status IdentifyFramesWithNodeMap(const GraphDef& graph, const NodeMap& node_map,
+ FrameMap* frame_map, int* num_frames) {
std::deque<std::pair<const NodeDef*, std::vector<int>>> ready_nodes;
for (const NodeDef& node : graph.node()) {
if (node.input_size() == 0) {
std::vector<int> empty;
ready_nodes.emplace_back(&node, empty);
- (*frames)[&node] = empty;
+ (*frame_map)[&node] = empty;
}
}
std::map<string, int> name_to_id;
while (!ready_nodes.empty()) {
auto ready_node = ready_nodes.front();
for (const auto& fanout : node_map.GetOutputs(ready_node.first->name())) {
- if (frames->count(fanout) < 1) {
+ if (frame_map->count(fanout) < 1) {
std::vector<int> frame_ids = ready_node.second;
if (IsExit(*ready_node.first)) {
frame_ids.pop_back();
@@ -59,9 +64,9 @@ int IdentifyFrames(
frame_ids.push_back(id);
}
ready_nodes.emplace_back(fanout, frame_ids);
- (*frames)[fanout] = frame_ids;
+ (*frame_map)[fanout] = frame_ids;
} else {
- auto frame_ids_fanout = (*frames)[fanout];
+ auto frame_ids_fanout = (*frame_map)[fanout];
auto frame_ids_node = ready_node.second;
if (IsEnter(*fanout)) {
frame_ids_fanout.pop_back();
@@ -69,12 +74,17 @@ int IdentifyFrames(
if (IsExit(*ready_node.first)) {
frame_ids_node.pop_back();
}
- CHECK(frame_ids_node == frame_ids_fanout);
+ if (frame_ids_node != frame_ids_fanout) {
+ return errors::InvalidArgument(
+ "Invalid graph: Frame ids for node ", ready_node.first->name(),
+ " does not match frame ids for it's fanout.");
+ }
}
}
ready_nodes.pop_front();
}
- return name_to_id.size();
+ *num_frames = name_to_id.size();
+ return Status::OK();
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/utils/frame.h b/tensorflow/core/grappler/utils/frame.h
index d9e046a969..be726ae795 100644
--- a/tensorflow/core/grappler/utils/frame.h
+++ b/tensorflow/core/grappler/utils/frame.h
@@ -18,16 +18,24 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
+using FrameMap = std::unordered_map<const NodeDef*, std::vector<int>>;
+
// Stores the number of frames present in the graph in 'num_frames', and
// populates 'frame_map' with the collection of frames (denoted by their
// frame ids) each node belongs to, in outermost-to-innermost order. Frame
// ids are arbitrary.
-int IdentifyFrames(
- const GraphDef& graph,
- std::unordered_map<const NodeDef*, std::vector<int>>* frames);
+Status IdentifyFrames(const GraphDef& graph, FrameMap* frame_map,
+ int* num_frames);
+
+// As above, but use an existing NodeMap for graph instead of building it
+// from scratch.
+Status IdentifyFramesWithNodeMap(const GraphDef& graph, const NodeMap& node_map,
+ FrameMap* frame_map, int* num_frames);
} // namespace grappler
} // namespace tensorflow
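Callers migrate as in the layout_optimizer.cc hunk above; a minimal sketch of the new calling convention (AnnotateFrames is a hypothetical caller, not part of the patch):

// Hypothetical caller: the frame count now arrives via an out-parameter and
// failures propagate as a Status instead of crashing on a CHECK.
Status AnnotateFrames(const GraphDef& graph) {
  FrameMap frame_map;
  int num_frames = 0;
  TF_RETURN_IF_ERROR(IdentifyFrames(graph, &frame_map, &num_frames));
  VLOG(1) << "Graph contains " << num_frames << " frames.";
  return Status::OK();
}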
diff --git a/tensorflow/core/grappler/utils/frame_test.cc b/tensorflow/core/grappler/utils/frame_test.cc
index 30673eed7a..df76083fc3 100644
--- a/tensorflow/core/grappler/utils/frame_test.cc
+++ b/tensorflow/core/grappler/utils/frame_test.cc
@@ -78,7 +78,8 @@ TEST_F(IdentifyFramesTest, NestedLoop) {
*graph.add_node() = CreateNode("17", {"16"});
std::unordered_map<const NodeDef*, std::vector<int>> frames;
- int num_frames = IdentifyFrames(graph, &frames);
+ int num_frames;
+ EXPECT_TRUE(IdentifyFrames(graph, &frames, &num_frames).ok());
std::unordered_map<string, std::vector<int>> expected = {
{"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {0}},
{"4", {0}}, {"5", {0}}, {"6", {0}}, {"7", {0, 1}},
@@ -108,7 +109,8 @@ TEST_F(IdentifyFramesTest, MultipleInputsToEnter) {
*graph.add_node() = CreateNode("3", "Exit", {"2"});
std::unordered_map<const NodeDef*, std::vector<int>> frames;
- int num_frames = IdentifyFrames(graph, &frames);
+ int num_frames;
+ EXPECT_TRUE(IdentifyFrames(graph, &frames, &num_frames).ok());
std::unordered_map<string, std::vector<int>> expected = {
{"0", {}}, {"1", {}}, {"2", {0}}, {"3", {0}}};
EXPECT_EQ(num_frames, 1);
@@ -135,7 +137,8 @@ TEST_F(IdentifyFramesTest, ExitOutput) {
*graph.add_node() = CreateNode("4", {"2", "3"});
std::unordered_map<const NodeDef*, std::vector<int>> frames;
- int num_frames = IdentifyFrames(graph, &frames);
+ int num_frames;
+ EXPECT_TRUE(IdentifyFrames(graph, &frames, &num_frames).ok());
std::unordered_map<string, std::vector<int>> expected = {
{"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {}}, {"4", {}}};
EXPECT_EQ(num_frames, 1);
@@ -167,7 +170,8 @@ TEST_F(IdentifyFramesTest, MultipleEnterNodes) {
*graph.add_node() = CreateNode("9", "Exit", {"7"});
std::unordered_map<const NodeDef*, std::vector<int>> frames;
- int num_frames = IdentifyFrames(graph, &frames);
+ int num_frames;
+ EXPECT_TRUE(IdentifyFrames(graph, &frames, &num_frames).ok());
std::unordered_map<string, std::vector<int>> expected = {
{"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {0}}, {"4", {0}},
{"5", {}}, {"6", {0}}, {"7", {0}}, {"8", {0}}, {"9", {0}}};
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index a7c1d35399..847f9ec401 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -54,7 +54,9 @@ def _cli_config_from_temp_file():
def no_rewrite_session_config():
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
- constant_folding=rewriter_config_pb2.RewriterConfig.OFF)
+ constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py
index aa5314dda5..1a6bedbbcb 100644
--- a/tensorflow/python/debug/lib/session_debug_file_test.py
+++ b/tensorflow/python/debug/lib/session_debug_file_test.py
@@ -38,7 +38,8 @@ class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
def _no_rewrite_session_config(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True)
+ disable_model_pruning=True,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
index fd958367cb..e1ddd4ee64 100644
--- a/tensorflow/python/debug/lib/session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -53,7 +53,8 @@ from tensorflow.python.training import monitored_session
def no_rewrite_session_config():
rewriter_config = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True)
+ disable_model_pruning=True,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py
index 3b9a5d07c2..ed31a8c8cd 100644
--- a/tensorflow/python/debug/lib/session_debug_testlib.py
+++ b/tensorflow/python/debug/lib/session_debug_testlib.py
@@ -57,7 +57,8 @@ from tensorflow.python.training import gradient_descent
def no_rewrite_session_config():
rewriter_config = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True)
+ disable_model_pruning=True,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
@@ -837,7 +838,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertIsNone(dump.find_some_path("delta", "v"))
def testCausalityCheckOnDumpsDetectsWrongTemporalOrder(self):
- with session.Session() as sess:
+ with session.Session(config=no_rewrite_session_config()) as sess:
u_name = "testDumpCausalityCheck/u"
v_name = "testDumpCausalityCheck/v"
w_name = "testDumpCausalityCheck/w"
diff --git a/tensorflow/python/debug/lib/stepper_test.py b/tensorflow/python/debug/lib/stepper_test.py
index 863af0b924..9a3d0efabf 100644
--- a/tensorflow/python/debug/lib/stepper_test.py
+++ b/tensorflow/python/debug/lib/stepper_test.py
@@ -56,6 +56,7 @@ class StepperTest(test_util.TensorFlowTestCase):
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
config = config_pb2.ConfigProto(graph_options=graph_options)
@@ -590,6 +591,7 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase):
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
config = config_pb2.ConfigProto(graph_options=graph_options)
@@ -722,6 +724,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
config = config_pb2.ConfigProto(graph_options=graph_options)