diff options
author | 2017-11-02 16:32:47 -0700 | |
---|---|---|
committer | 2017-11-02 16:40:06 -0700 | |
commit | 6ace5e0494d8142dc67ca0714893afc716125917 (patch) | |
tree | 21bf67f21d8318b66b2cfea4cc65d83e3cc9b66b | |
parent | 3a8eaaf6a238e238a7adac9886b1569d7e43ae23 (diff) |
* Add optimization to hoist a common factor out of sums of products involving aggregate ops (AddN, Add, Accumulate) or eliminate the aggregation op entirely.
* Replace trivial aggregations of the form x+x+x... with const(N)*x for N > 1.
PiperOrigin-RevId: 174398543
16 files changed, 390 insertions, 67 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 681d26e262..669d02815c 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -161,6 +161,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":constant_folding", ":graph_optimizer", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -170,6 +171,7 @@ cc_library( "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:frame", ], ) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 78b55237d1..445e5cf972 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -14,8 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" + +#include <algorithm> +#include <limits> #include <unordered_map> #include <unordered_set> + #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -23,6 +27,9 @@ limitations under the License. 
#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/utils/frame.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/util/device_name_utils.h" @@ -31,6 +38,45 @@ namespace tensorflow { namespace grappler { namespace { +template <typename T> +bool SafeSetTensorValue(double value, Tensor* tensor) { + using RealType = typename Eigen::NumTraits<T>::Real; + if (value > std::numeric_limits<RealType>::max() || + value < std::numeric_limits<RealType>::lowest()) { + return false; + } + tensor->flat<T>()(0) = static_cast<T>(value); + return true; +} + +#define HANDLE_CASE(DTYPE) \ + case DTYPE: \ + if (!SafeSetTensorValue<EnumToDataType<DTYPE>::Type>( \ + static_cast<double>(value), tensor)) { \ + return errors::InvalidArgument("Cannot store value ", value, \ + " in tensor of type " #DTYPE); \ + } \ + break + +Status SetTensorValue(DataType dtype, int value, Tensor* tensor) { + switch (dtype) { + // HANDLE_CASE(DT_HALF); + HANDLE_CASE(DT_FLOAT); + HANDLE_CASE(DT_DOUBLE); + HANDLE_CASE(DT_UINT8); + HANDLE_CASE(DT_INT8); + HANDLE_CASE(DT_UINT16); + HANDLE_CASE(DT_INT16); + HANDLE_CASE(DT_INT32); + HANDLE_CASE(DT_INT64); + HANDLE_CASE(DT_COMPLEX64); + HANDLE_CASE(DT_COMPLEX128); + default: + return errors::InvalidArgument("Unexpected type ", DataTypeString(dtype)); + } + return Status::OK(); +} + static bool IsInvolution(const NodeDef& node) { const std::unordered_set<string> involution_ops = {"Conj", "Reciprocal", "Neg", "LogicalNot"}; @@ -107,14 +153,28 @@ DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) { return attr.type(); } -bool IsCommutative(const OpDef& op, const NodeDef& input1) { - if (op.name() == "Add") { +bool 
IsCommutative(const NodeDef& node) { + if (node.op() == "Add" && node.input_size() > 0) { // Workaround for "Add" not being marked is_commutative and is_aggregate. // (See cl/173915048). - const auto type = GetDataTypeFromAttr(input1, "T"); + const auto type = GetDataTypeFromAttr(node, "T"); return type != DT_INVALID && type != DT_STRING; } - return op.is_commutative(); + const OpDef* op_def = nullptr; + const Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + return status.ok() && op_def->is_commutative(); +} + +bool IsAggregate(const NodeDef& node) { + if (node.op() == "Add" && node.input_size() > 0) { + // Workaround for "Add" not being marked is_commutative and is_aggregate. + // (See cl/173915048). + const auto type = GetDataTypeFromAttr(node, "T"); + return type != DT_INVALID && type != DT_STRING; + } + const OpDef* op_def = nullptr; + const Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + return status.ok() && op_def->is_aggregate(); } void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) { @@ -208,6 +268,30 @@ bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input, return true; } +// Fix frame dependencies by adding control dependencies from +// source_for_ctrl_dep to the nodes in sinks_for_control_dep, and update +// frame_map for all nodes in new_nodes. 
+void AddFrameControlDeps(const NodeDef* old_node, + const std::vector<NodeDef*>& new_nodes, + const string& source_for_ctrl_dep, + const std::vector<NodeDef*>& sinks_for_control_dep, + GraphDef* graph, NodeMap* node_map, + FrameMap* frame_map) { + const auto frame_it = frame_map->find(old_node); + if (frame_it != frame_map->end()) { + for (auto node : new_nodes) { + frame_map->emplace(node, frame_it->second); + } + if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) { + const string ctrl_dep = ConstantFolding::AddControlDependency( + source_for_ctrl_dep, graph, node_map); + for (auto node : sinks_for_control_dep) { + node->add_input(ctrl_dep); + } + } + } +} + } // namespace class UniqueNodes { @@ -264,10 +348,7 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const { } // Compare inputs. - const OpDef* op_def = nullptr; - Status status = OpRegistry::Global()->LookUpOpDef(node1.op(), &op_def); - const bool is_commutative = status.ok() && IsCommutative(*op_def, node1); - if (is_commutative) { + if (IsCommutative(node1)) { std::vector<string> inputs1(node1.input().begin(), node1.input().end()); std::vector<string> inputs2(node2.input().begin(), node2.input().end()); std::sort(inputs1.begin(), inputs1.end()); @@ -399,7 +480,7 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const { string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, GraphDef* graph_def, NodeMap* node_map, - std::vector<const NodeDef*>* new_nodes) const { + std::vector<const NodeDef*>* new_nodes, FrameMap* frame_map) const { // Remove involutions applied twice. if (IsInvolution(*node)) { // An involution is a function f(x) that is its own inverse, @@ -519,6 +600,11 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( new_nodes->push_back(new_transpose); new_nodes->push_back(new_cast); + // Add frame dependencies that the original node might have had. 
+ AddFrameControlDeps(node, {new_transpose, new_cast}, + new_transpose->input(0), {new_transpose}, + graph_def, node_map, frame_map); + return new_cast->name(); } } @@ -625,6 +711,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( node_map->AddOutput(weights->name(), scaled_weights->name()); scaled_weights->add_input(mul->input(1)); node_map->AddOutput(scale->name(), scaled_weights->name()); + AddFrameControlDeps(node, {scaled_weights}, "", {}, graph_def, + node_map, frame_map); // Update `conv`'s weights to `scaled_weights`. conv->set_input(1, scaled_weights->name()); @@ -648,6 +736,134 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } } + if (node->input_size() > 0 && IsAggregate(*node) && + !node_map->GetOutputs(node->name()).empty()) { + // Discard aggregate nodes with a single input. + if (node->input_size() == 1) { + return node->input(0); + } + + // Try to rewrite aggregations of N >= 2 identical terms (possibly due + // to deduping or other rewrites) so we can get rid of the sum entirely. + // The expression (using AddN as an example of an aggregate op): + // AddN(x, x, x, ... ,x) + // <-- N terms --> + // can be rewritten to + // Mul(Const(N), x)) + // + bool all_equal = true; + for (int i = 1; i < node->input_size(); ++i) { + if (node->input(i) != node->input(0)) { + all_equal = false; + break; + } + } + if (all_equal) { + // 1. Create constant node with value N. 
+ const int N = node->input_size(); + const auto type = GetDataTypeFromAttr(*node, "T"); + Tensor t(type, TensorShape({})); + Status status = SetTensorValue(type, N, &t); + if (!status.ok()) { + LOG(WARNING) << "Failed to create const node: " + << status.error_message(); + return ""; + } + TensorValue value(&t); + NodeDef* new_const_node = graph_def->add_node(); + *new_const_node = + ConstantFolding::CreateNodeDef(node->name() + "_const", value); + new_const_node->set_device(node->device()); + node_map->AddNode(new_const_node->name(), new_const_node); + new_nodes->push_back(new_const_node); + + // 2. Replace the aggregate node with Mul(Const(N), x). + NodeDef* new_mul_node = graph_def->add_node(); + new_mul_node->set_name(node->name() + "_mul"); + new_mul_node->set_op("Mul"); + new_mul_node->set_device(node->device()); + SetDataTypeToAttr(type, "T", new_mul_node); + node_map->AddNode(new_mul_node->name(), new_mul_node); + new_nodes->push_back(new_mul_node); + new_mul_node->add_input(new_const_node->name()); + node_map->AddOutput(new_const_node->name(), new_mul_node->name()); + new_mul_node->add_input(node->input(0)); + node_map->AddOutput(node->input(0), new_mul_node->name()); + + AddFrameControlDeps(node, {new_const_node, new_mul_node}, node->input(0), + {new_const_node}, graph_def, node_map, frame_map); + return new_mul_node->name(); + } + } + + // Use the commutativity and (left- and right-) distributive property of + // multiplication over addition to hoist common factors out of aggregate nodes + // where all the inputs are Mul nodes. This pattern occurs frequently in + // regularization terms for the gradients during training. + if (node->input_size() > 1 && IsAggregate(*node) && + !node_map->GetOutputs(node->name()).empty()) { + // Determine the set of common factors if the input nodes are all Mul nodes. 
+ std::set<string> common_factors; + int i = 0; + while (i < node->input_size() && (i == 0 || !common_factors.empty())) { + const NodeDef* input = node_map->GetNode(node->input(i)); + if (input->op() == "Mul") { + std::set<string> factors_i{input->input(0), input->input(1)}; + if (i == 0) { + std::swap(common_factors, factors_i); + } else { + std::set<string> intersection; + std::set_intersection( + factors_i.begin(), factors_i.end(), common_factors.begin(), + common_factors.end(), + std::inserter(intersection, intersection.begin())); + std::swap(common_factors, intersection); + } + } else { + common_factors.clear(); + break; + } + ++i; + } + if (common_factors.size() == 1) { + // In this case we have an expression of the form + // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn)) + // that can be rewritten as + // Mul(x, AddN(y1, y2, y3, ... yn)) + // 1. Hoist non-shared factors up into AddN node. + const string& common_factor = *common_factors.begin(); + NodeDef* new_mul_node = graph_def->add_node(); + NodeDef* new_add_node = graph_def->add_node(); + *new_add_node = *node; + new_add_node->set_name(node->name() + "_hoist"); + new_nodes->push_back(new_add_node); + node_map->AddNode(new_add_node->name(), new_add_node); + for (int i = 0; i < node->input_size(); ++i) { + NodeDef* mul_node = node_map->GetNode(node->input(i)); + int unique_factor_index = mul_node->input(0) == common_factor ? 1 : 0; + const string unique_factor = mul_node->input(unique_factor_index); + new_add_node->set_input(i, unique_factor); + // 2. Use a copy of the first Mul node for the outer multiplication. + if (i == 0) { + *new_mul_node = *mul_node; + new_mul_node->set_name(new_mul_node->name() + "_hoist"); + new_mul_node->set_input(0, common_factor); + new_mul_node->set_input(1, new_add_node->name()); + new_nodes->push_back(new_mul_node); + node_map->AddNode(new_mul_node->name(), new_mul_node); + } + } + // 3. Set the device of the new nodes to that of the common factor "x". 
+ NodeDef* common_factor_node = node_map->GetNode(common_factor); + new_add_node->set_device(common_factor_node->device()); + new_mul_node->set_device(common_factor_node->device()); + + // 4. Add frame dependencies that the original node might have had. + AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor, + {new_add_node}, graph_def, node_map, frame_map); + return new_mul_node->name(); + } + } return ""; } @@ -681,9 +897,13 @@ class SetVector { }; } // namespace -void ArithmeticOptimizer::SimplifyArithmeticOps( +Status ArithmeticOptimizer::SimplifyArithmeticOps( GraphDef* optimized_graph) const { NodeMap node_map(optimized_graph); + FrameMap frame_map; + int num_frames; + TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph, node_map, + &frame_map, &num_frames)); SetVector<const NodeDef*> nodes_to_simplify; for (int i = 0; i < optimized_graph->node_size(); ++i) { nodes_to_simplify.PushBack(optimized_graph->mutable_node()->Mutable(i)); @@ -691,8 +911,8 @@ void ArithmeticOptimizer::SimplifyArithmeticOps( while (!nodes_to_simplify.Empty()) { const NodeDef* node = nodes_to_simplify.PopBack(); std::vector<const NodeDef*> new_nodes; - const string simplified_tensor = - TrySimplifyAndReplaceUses(node, optimized_graph, &node_map, &new_nodes); + const string simplified_tensor = TrySimplifyAndReplaceUses( + node, optimized_graph, &node_map, &new_nodes, &frame_map); if (simplified_tensor.empty()) { continue; } @@ -730,6 +950,7 @@ void ArithmeticOptimizer::SimplifyArithmeticOps( } } } + return Status::OK(); } Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, @@ -743,7 +964,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, TF_RETURN_IF_ERROR(graph_properties.AnnotateOutputShapes(optimized_graph)); DedupComputations(optimized_graph); - SimplifyArithmeticOps(optimized_graph); + TF_RETURN_IF_ERROR(SimplifyArithmeticOps(optimized_graph)); // Clear output shapes. 
for (int i = 0; i < optimized_graph->node_size(); ++i) { diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 53cec11ff6..4d2e160ff4 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -46,7 +46,7 @@ class ArithmeticOptimizer : public GraphOptimizer { void DedupComputations(GraphDef* optimized_graph) const; // Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse // transposes. - void SimplifyArithmeticOps(GraphDef* optimized_graph) const; + Status SimplifyArithmeticOps(GraphDef* optimized_graph) const; // Tries to simplify the expression that roots at `node` and replaces the uses // of `node` to the simplified expression. Returns the name of the simplified // tensor (e.g. "split:1") or an empty string if no simplification is @@ -64,7 +64,8 @@ class ArithmeticOptimizer : public GraphOptimizer { // NodeDef. 
string TrySimplifyAndReplaceUses( const NodeDef* node, GraphDef* graph_def, NodeMap* node_map, - std::vector<const NodeDef*>* new_nodes) const; + std::vector<const NodeDef*>* new_nodes, + std::unordered_map<const NodeDef*, std::vector<int>>* frame_map) const; std::unordered_set<string> nodes_to_preserve_; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 61c8b82ea0..5c3fdd2553 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -58,7 +58,7 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2}); Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2}); - Output add = ops::Add(s.WithOpName("add"), c1, c2); + Output mul = ops::Mul(s.WithOpName("mul"), c1, c2); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -70,20 +70,20 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) { EXPECT_EQ(2, output.node_size()); const NodeDef& new_c1 = output.node(0); EXPECT_EQ("c1", new_c1.name()); - const NodeDef& new_add = output.node(1); - EXPECT_EQ("add", new_add.name()); - EXPECT_EQ(2, new_add.input_size()); - EXPECT_EQ("c1", new_add.input(0)); - EXPECT_EQ("c1", new_add.input(1)); + const NodeDef& new_mul = output.node(1); + EXPECT_EQ("mul", new_mul.name()); + EXPECT_EQ(2, new_mul.input_size()); + EXPECT_EQ("c1", new_mul.input(0)); + EXPECT_EQ("c1", new_mul.input(1)); } TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2}); Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2}); - Output add1 = ops::Add(s.WithOpName("add1"), c1, c2); - Output add2 = ops::Add(s.WithOpName("add2"), c2, c1); - Output add3 = 
ops::Add(s.WithOpName("add3"), add1, add2); + Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2); + Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1); + Output mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -97,16 +97,16 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { EXPECT_EQ("c1", new_c1.name()); const NodeDef& new_c2 = output.node(1); EXPECT_EQ("c2", new_c2.name()); - const NodeDef& new_add1 = output.node(2); - EXPECT_EQ("add1", new_add1.name()); - EXPECT_EQ(2, new_add1.input_size()); - EXPECT_EQ("c1", new_add1.input(0)); - EXPECT_EQ("c2", new_add1.input(1)); - const NodeDef& new_add3 = output.node(3); - EXPECT_EQ("add3", new_add3.name()); - EXPECT_EQ(2, new_add3.input_size()); - EXPECT_EQ("add1", new_add3.input(0)); - EXPECT_EQ("add1", new_add3.input(1)); + const NodeDef& new_mul1 = output.node(2); + EXPECT_EQ("mul1", new_mul1.name()); + EXPECT_EQ(2, new_mul1.input_size()); + EXPECT_EQ("c1", new_mul1.input(0)); + EXPECT_EQ("c2", new_mul1.input(1)); + const NodeDef& new_mul3 = output.node(3); + EXPECT_EQ("mul3", new_mul3.name()); + EXPECT_EQ(2, new_mul3.input_size()); + EXPECT_EQ("mul1", new_mul3.input(0)); + EXPECT_EQ("mul1", new_mul3.input(1)); } TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) { @@ -131,6 +131,66 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) { EXPECT_EQ("c", output.node(5).input(0)); } +TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output add = ops::Add(s.WithOpName("add"), x, x); + Output id = ops::Identity(s.WithOpName("id"), add); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // VLOG(2) << output.DebugString(); + EXPECT_EQ(5, 
output.node_size()); + const NodeDef& new_const = output.node(3); + EXPECT_EQ("add_const", new_const.name()); + const NodeDef& new_mul = output.node(4); + EXPECT_EQ("add_mul", new_mul.name()); + EXPECT_EQ("add_const", new_mul.input(0)); + EXPECT_EQ("x", new_mul.input(1)); + const NodeDef& new_id = output.node(2); + EXPECT_EQ("id", new_id.name()); + EXPECT_EQ("add_mul", new_id.input(0)); +} + +TEST_F(ArithmeticOptimizerTest, SimplifyHoistFactor) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2}); + Output y2 = ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2}); + Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1); + Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x); + Output add = ops::Add(s.WithOpName("add"), mul1, mul2); + Output id = ops::Identity(s.WithOpName("id"), add); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + LOG(INFO) << output.DebugString(); + EXPECT_EQ(9, output.node_size()); + const NodeDef& new_add = output.node(8); + EXPECT_EQ("add_hoist", new_add.name()); + EXPECT_EQ("y1", new_add.input(0)); + EXPECT_EQ("y2", new_add.input(1)); + const NodeDef& new_mul = output.node(7); + EXPECT_EQ("mul1_hoist", new_mul.name()); + EXPECT_EQ("x", new_mul.input(0)); + EXPECT_EQ("add_hoist", new_mul.input(1)); + const NodeDef& new_id = output.node(6); + EXPECT_EQ("id", new_id.name()); + EXPECT_EQ("mul1_hoist", new_id.input(0)); +} + TEST_F(ArithmeticOptimizerTest, IdentityReshape) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index ea03660440..e8ffff07c6 100644 --- 
a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -100,8 +100,11 @@ ConstantFolding::ConstantFolding(DeviceBase* cpu_device) resource_mgr_.reset(new ResourceMgr()); } -string ConstantFolding::AddControlDependency(const string& input_name) { - const NodeDef* node = node_map_->GetNode(input_name); +// static +string ConstantFolding::AddControlDependency(const string& input_name, + GraphDef* graph, + NodeMap* node_map) { + const NodeDef* node = node_map->GetNode(input_name); if (!IsSwitch(*node)) { return AsControlDependency(*node); } else { @@ -111,7 +114,7 @@ string ConstantFolding::AddControlDependency(const string& input_name) { // dependency is only triggered when the corresponding output is triggered. // We start by looking for an identity node connected to the output of the // switch node, and use it to anchor the control dependency. - auto outputs = node_map_->GetOutputs(node->name()); + auto outputs = node_map->GetOutputs(node->name()); for (const NodeDef* node : outputs) { if (IsIdentity(*node)) { CHECK_EQ(1, node->input_size()); @@ -128,15 +131,15 @@ string ConstantFolding::AddControlDependency(const string& input_name) { ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl); const DataType output_type = node->attr().at("T").type(); - NodeDef* added_node = graph_.add_node(); + NodeDef* added_node = graph->add_node(); added_node->set_name(ctrl_dep_name); added_node->set_op("Identity"); added_node->set_device(node->device()); (*added_node->mutable_attr())["T"].set_type(output_type); *added_node->add_input() = input_name; - node_map_->AddNode(added_node->name(), added_node); - node_map_->AddOutput(node->name(), added_node->name()); + node_map->AddNode(added_node->name(), added_node); + node_map->AddOutput(node->name(), added_node->name()); return AsControlDependency(*added_node); } } @@ -233,7 +236,8 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item, 
// ensure that the constant value will only be run in the // cases where the shape/rank/size would have been run in // the original graph. Additional inputs are extra control - string ctrl_dep = AddControlDependency(node.input(0)); + string ctrl_dep = + AddControlDependency(node.input(0), &graph_, node_map_.get()); node.set_input(0, ctrl_dep); node_map_->AddOutput(NodeName(ctrl_dep), node.name()); } else { @@ -259,7 +263,8 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item, // We add a control dependency to the original ShapeN node, // so that the node will only be run if all inputs of the // original ShapeN node are run. - string ctrl_dep = AddControlDependency(node.name()); + string ctrl_dep = AddControlDependency(node.name(), &graph_, + node_map_.get()); *added_node->add_input() = ctrl_dep; node_map_->AddOutput(NodeName(ctrl_dep), added_node->name()); } @@ -370,6 +375,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return true; } +// static NodeDef ConstantFolding::CreateNodeDef(const string& name, const TensorValue& tensor) { NodeDef node; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index b115e51dbf..30d778789a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -32,6 +32,10 @@ const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl"; // Constant folding optimization for a graph. 
class ConstantFolding : public GraphOptimizer { public: + static NodeDef CreateNodeDef(const string& name, const TensorValue& tensor); + static string AddControlDependency(const string& input_name, GraphDef* graph, + NodeMap* node_map); + ConstantFolding(DeviceBase* cpu_device); ~ConstantFolding() override {} @@ -45,14 +49,11 @@ class ConstantFolding : public GraphOptimizer { const GraphDef& optimize_output, double result) override; private: - string AddControlDependency(const string& input_name); Status MaterializeShapes(const GrapplerItem& item, const GraphProperties& properties); bool IsFoldable(const NodeDef& node) const; - NodeDef CreateNodeDef(const string& name, const TensorValue& tensor); - Status EvaluateNode(const NodeDef& node, const gtl::InlinedVector<TensorValue, 4>& inputs, gtl::InlinedVector<TensorValue, 4>* output) const; diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index d7d7218319..1ca296da0a 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -1233,7 +1233,8 @@ class DataLayoutOptimizer : GraphProcessor { Status Expand() { int node_size_original = graph_->node_size(); std::unordered_map<const NodeDef*, std::vector<int>> frames; - IdentifyFrames(*graph_, &frames); + int num_frames; + TF_RETURN_IF_ERROR(IdentifyFrames(*graph_, &frames, &num_frames)); // This is the first pass where we expand the nodes which support NCHW. 
std::set<string> ops_format_supported = GetOpsFormatSupported(); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index bb161bf9a4..21243833ac 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -78,6 +78,7 @@ cc_library( hdrs = ["frame.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:op_types", diff --git a/tensorflow/core/grappler/utils/frame.cc b/tensorflow/core/grappler/utils/frame.cc index 7655d0bee5..df5f4ff7cf 100644 --- a/tensorflow/core/grappler/utils/frame.cc +++ b/tensorflow/core/grappler/utils/frame.cc @@ -20,27 +20,32 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { namespace grappler { -int IdentifyFrames( - const GraphDef& graph, - std::unordered_map<const NodeDef*, std::vector<int>>* frames) { +Status IdentifyFrames(const GraphDef& graph, FrameMap* frame_map, + int* num_frames) { NodeMap node_map(const_cast<GraphDef*>(&graph)); + return IdentifyFramesWithNodeMap(graph, node_map, frame_map, num_frames); +} + +Status IdentifyFramesWithNodeMap(const GraphDef& graph, const NodeMap& node_map, + FrameMap* frame_map, int* num_frames) { std::deque<std::pair<const NodeDef*, std::vector<int>>> ready_nodes; for (const NodeDef& node : graph.node()) { if (node.input_size() == 0) { std::vector<int> empty; ready_nodes.emplace_back(&node, empty); - (*frames)[&node] = empty; + (*frame_map)[&node] = empty; } } std::map<string, int> name_to_id; while (!ready_nodes.empty()) { auto ready_node = ready_nodes.front(); for (const auto& fanout : node_map.GetOutputs(ready_node.first->name())) { - if (frames->count(fanout) < 1) { + if (frame_map->count(fanout) < 1) { 
std::vector<int> frame_ids = ready_node.second; if (IsExit(*ready_node.first)) { frame_ids.pop_back(); @@ -59,9 +64,9 @@ int IdentifyFrames( frame_ids.push_back(id); } ready_nodes.emplace_back(fanout, frame_ids); - (*frames)[fanout] = frame_ids; + (*frame_map)[fanout] = frame_ids; } else { - auto frame_ids_fanout = (*frames)[fanout]; + auto frame_ids_fanout = (*frame_map)[fanout]; auto frame_ids_node = ready_node.second; if (IsEnter(*fanout)) { frame_ids_fanout.pop_back(); @@ -69,12 +74,17 @@ int IdentifyFrames( if (IsExit(*ready_node.first)) { frame_ids_node.pop_back(); } - CHECK(frame_ids_node == frame_ids_fanout); + if (frame_ids_node != frame_ids_fanout) { + return errors::InvalidArgument( + "Invalid graph: Frame ids for node ", ready_node.first->name(), + " does not match frame ids for it's fanout."); + } } } ready_nodes.pop_front(); } - return name_to_id.size(); + *num_frames = name_to_id.size(); + return Status::OK(); } } // namespace grappler diff --git a/tensorflow/core/grappler/utils/frame.h b/tensorflow/core/grappler/utils/frame.h index d9e046a969..be726ae795 100644 --- a/tensorflow/core/grappler/utils/frame.h +++ b/tensorflow/core/grappler/utils/frame.h @@ -18,16 +18,24 @@ limitations under the License. #include <unordered_map> #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace grappler { +using FrameMap = std::unordered_map<const NodeDef*, std::vector<int>>; + // Returns the number of frames present in the graph, and populates // the 'frames' argument with the collection of frames (denoted by their // frame ids) in the outermost-to-innermost order. Frame ids are arbitrary. 
-int IdentifyFrames( - const GraphDef& graph, - std::unordered_map<const NodeDef*, std::vector<int>>* frames); +Status IdentifyFrames(const GraphDef& graph, FrameMap* frame_map, + int* num_frames); + +// As above, but use an existing NodeMap for graph instead of building it +// from scratch. +Status IdentifyFramesWithNodeMap(const GraphDef& graph, const NodeMap& node_map, + FrameMap* frame_map, int* num_frames); } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/frame_test.cc b/tensorflow/core/grappler/utils/frame_test.cc index 30673eed7a..df76083fc3 100644 --- a/tensorflow/core/grappler/utils/frame_test.cc +++ b/tensorflow/core/grappler/utils/frame_test.cc @@ -78,7 +78,8 @@ TEST_F(IdentifyFramesTest, NestedLoop) { *graph.add_node() = CreateNode("17", {"16"}); std::unordered_map<const NodeDef*, std::vector<int>> frames; - int num_frames = IdentifyFrames(graph, &frames); + int num_frames; + EXPECT_TRUE(IdentifyFrames(graph, &frames, &num_frames).ok()); std::unordered_map<string, std::vector<int>> expected = { {"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {0}}, {"4", {0}}, {"5", {0}}, {"6", {0}}, {"7", {0, 1}}, @@ -108,7 +109,8 @@ TEST_F(IdentifyFramesTest, MultipleInputsToEnter) { *graph.add_node() = CreateNode("3", "Exit", {"2"}); std::unordered_map<const NodeDef*, std::vector<int>> frames; - int num_frames = IdentifyFrames(graph, &frames); + int num_frames; + EXPECT_TRUE(IdentifyFrames(graph, &frames, &num_frames).ok()); std::unordered_map<string, std::vector<int>> expected = { {"0", {}}, {"1", {}}, {"2", {0}}, {"3", {0}}}; EXPECT_EQ(num_frames, 1); @@ -135,7 +137,8 @@ TEST_F(IdentifyFramesTest, ExitOutput) { *graph.add_node() = CreateNode("4", {"2", "3"}); std::unordered_map<const NodeDef*, std::vector<int>> frames; - int num_frames = IdentifyFrames(graph, &frames); + int num_frames; + EXPECT_TRUE(IdentifyFrames(graph, &frames, &num_frames).ok()); std::unordered_map<string, std::vector<int>> expected = { {"0", {}}, {"1", 
{0}}, {"2", {0}}, {"3", {}}, {"4", {}}}; EXPECT_EQ(num_frames, 1); @@ -167,7 +170,8 @@ TEST_F(IdentifyFramesTest, MultipleEnterNodes) { *graph.add_node() = CreateNode("9", "Exit", {"7"}); std::unordered_map<const NodeDef*, std::vector<int>> frames; - int num_frames = IdentifyFrames(graph, &frames); + int num_frames; + EXPECT_TRUE(IdentifyFrames(graph, &frames, &num_frames).ok()); std::unordered_map<string, std::vector<int>> expected = { {"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {0}}, {"4", {0}}, {"5", {}}, {"6", {0}}, {"7", {0}}, {"8", {0}}, {"9", {0}}}; diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py index a7c1d35399..847f9ec401 100644 --- a/tensorflow/python/debug/cli/analyzer_cli_test.py +++ b/tensorflow/python/debug/cli/analyzer_cli_test.py @@ -54,7 +54,9 @@ def _cli_config_from_temp_file(): def no_rewrite_session_config(): rewriter_config = rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, - constant_folding=rewriter_config_pb2.RewriterConfig.OFF) + constant_folding=rewriter_config_pb2.RewriterConfig.OFF, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py index aa5314dda5..1a6bedbbcb 100644 --- a/tensorflow/python/debug/lib/session_debug_file_test.py +++ b/tensorflow/python/debug/lib/session_debug_file_test.py @@ -38,7 +38,8 @@ class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase): def _no_rewrite_session_config(self): rewriter_config = rewriter_config_pb2.RewriterConfig( - disable_model_pruning=True) + disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return 
config_pb2.ConfigProto(graph_options=graph_options) diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py index fd958367cb..e1ddd4ee64 100644 --- a/tensorflow/python/debug/lib/session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py @@ -53,7 +53,8 @@ from tensorflow.python.training import monitored_session def no_rewrite_session_config(): rewriter_config = rewriter_config_pb2.RewriterConfig( - disable_model_pruning=True) + disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py index 3b9a5d07c2..ed31a8c8cd 100644 --- a/tensorflow/python/debug/lib/session_debug_testlib.py +++ b/tensorflow/python/debug/lib/session_debug_testlib.py @@ -57,7 +57,8 @@ from tensorflow.python.training import gradient_descent def no_rewrite_session_config(): rewriter_config = rewriter_config_pb2.RewriterConfig( - disable_model_pruning=True) + disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) @@ -837,7 +838,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase): self.assertIsNone(dump.find_some_path("delta", "v")) def testCausalityCheckOnDumpsDetectsWrongTemporalOrder(self): - with session.Session() as sess: + with session.Session(config=no_rewrite_session_config()) as sess: u_name = "testDumpCausalityCheck/u" v_name = "testDumpCausalityCheck/v" w_name = "testDumpCausalityCheck/w" diff --git a/tensorflow/python/debug/lib/stepper_test.py b/tensorflow/python/debug/lib/stepper_test.py index 863af0b924..9a3d0efabf 100644 --- 
a/tensorflow/python/debug/lib/stepper_test.py +++ b/tensorflow/python/debug/lib/stepper_test.py @@ -56,6 +56,7 @@ class StepperTest(test_util.TensorFlowTestCase): rewriter_config = rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, constant_folding=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) config = config_pb2.ConfigProto(graph_options=graph_options) @@ -590,6 +591,7 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase): rewriter_config = rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, constant_folding=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) config = config_pb2.ConfigProto(graph_options=graph_options) @@ -722,6 +724,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): rewriter_config = rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, constant_folding=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) config = config_pb2.ConfigProto(graph_options=graph_options) |