diff options
author | 2018-03-09 18:50:06 -0800 | |
---|---|---|
committer | 2018-03-09 18:54:18 -0800 | |
commit | 9d1d5057b9d3fb335a4b20193bb364737e2b5140 (patch) | |
tree | 14f940946abf2d9d4d95c96264566e41a6edbb3f | |
parent | 2426308fa58ebf473092918cc8ffa215325c4079 (diff) |
Move optimizations to arithmetic optimizer stages
1) Redundant Bitcast
2) Redundant Cast
3) Remove inverse transpose
PiperOrigin-RevId: 188569367
-rw-r--r-- | tensorflow/core/grappler/op_types.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.h | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 207 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 8 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 140 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/BUILD | 16 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/grappler_test.cc | 15 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/grappler_test.h | 8 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/grappler_test_test.cc | 100 |
10 files changed, 370 insertions, 131 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 8cf1402ae8..ca56833ef6 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -78,6 +78,10 @@ bool IsConstant(const NodeDef& node) { return node.op() == "Const"; } bool IsConj(const NodeDef& node) { return node.op() == "Conj"; } +bool IsConjugateTranspose(const NodeDef& node) { + return node.op() == "ConjugateTranspose"; +} + bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; } bool IsConv2DBackpropFilter(const NodeDef& node) { diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index a7c33ef97b..a0946ee1ad 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -40,6 +40,8 @@ bool IsCast(const NodeDef& node); bool IsComplex(const NodeDef& node); bool IsComplexAbs(const NodeDef& node); bool IsConj(const NodeDef& node); +bool IsConjugateTranspose(const NodeDef& node); +bool IsConcat(const NodeDef& node); bool IsConcatOffset(const NodeDef& node); bool IsConstant(const NodeDef& node); bool IsConv2D(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 7ec137373b..6ded261c7d 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -248,6 +248,7 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "//tensorflow/core/grappler/utils:grappler_test", ], ) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 3cf42fde41..177b0735e9 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -45,19 +45,6 @@ namespace tensorflow { namespace grappler { namespace { -template <typename T> -bool AreInversePermutations(const std::vector<T>& a, const std::vector<T>& b) { - if (a.size() != b.size()) { - return false; - } - for (int i = 0; i < a.size(); ++i) { - if (a[b[i]] != i) { - return false; - } - } - return true; -} - // Extract values from a Const op to `values`. Returns true if succeeds. template <typename T> bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) { @@ -431,9 +418,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { Status TrySimplify(const NodeDef* node, string* simplified_node_name) override { - CHECK(IsSupported(node)) - << "Node " << node->name() - << " is not supported by add ops group optimizer step"; + CHECK(IsSupported(node)); AddOpsGroup group; TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group)); @@ -650,6 +635,130 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { std::unordered_set<string> rewritten_nodes_; }; +// Removes inverse transpose nodes +class RemoveInverseTranspose : public ArithmeticOptimizerStage { + public: + explicit RemoveInverseTranspose(ArithmeticOptimizerContext ctx) + : ArithmeticOptimizerStage(ctx) {} + ~RemoveInverseTranspose() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsTranspose(*node) || IsConjugateTranspose(*node); + } + + Status TrySimplify(const NodeDef* node, + string* simplified_node_name) override { + CHECK(IsSupported(node)); + + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + + if (input->op() == node->op()) { + NodeDef* node_perm; + NodeDef* input_perm; + + TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm)); + TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm)); + + // Try 32-bit indices. + std::vector<int> node_perm_values; + std::vector<int> input_perm_values; + if (ValuesFromConstNode(*node_perm, &node_perm_values) && + ValuesFromConstNode(*input_perm, &input_perm_values) && + AreInversePermutations(node_perm_values, input_perm_values)) { + *simplified_node_name = input->input(0); + } + // Try 64-bit indices. + std::vector<int64> node_perm_values64; + std::vector<int64> input_perm_values64; + if (ValuesFromConstNode(*node_perm, &node_perm_values64) && + ValuesFromConstNode(*input_perm, &input_perm_values64) && + AreInversePermutations(node_perm_values64, input_perm_values64)) { + *simplified_node_name = input->input(0); + } + } + + return Status::OK(); + } + + private: + template <typename T> + bool AreInversePermutations(const std::vector<T>& a, + const std::vector<T>& b) { + if (a.size() != b.size()) { + return false; + } + for (int i = 0; i < a.size(); ++i) { + if (a[b[i]] != i) { + return false; + } + } + return true; + } +}; + +// Remove redundant Bitcasts. +// 1) Remove Bitcast whose source type and destination type are equal +// 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2) +class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage { + public: + explicit RemoveRedundantBitcastStage(ArithmeticOptimizerContext ctx) + : ArithmeticOptimizerStage(ctx) {} + ~RemoveRedundantBitcastStage() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsBitcast(*node); + } + + Status TrySimplify(const NodeDef* node, + string* simplified_node_name) override { + CHECK(IsSupported(node)); + + // Bypass Bitcast whose source type and destination type are equal. + if (GetSourceDataType(*node) == GetDestinationDataType(*node)) { + *simplified_node_name = node->input(0); + return Status::OK(); + } + + NodeDef* bitcast; + TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast)); + NodeDef* operand; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand)); + + if (IsBitcast(*operand)) { + // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2) + bitcast->set_input(0, operand->input(0)); + SetSourceDataType(GetSourceDataType(*operand), bitcast); + ctx_.node_map->UpdateInput(bitcast->name(), bitcast->input(0), + operand->input(0)); + AddToOptimizationQueue(bitcast); + *simplified_node_name = bitcast->name(); + } + + return Status::OK(); + } +}; + +// Remove Casts whose source type and destination type are equal. +class RemoveRedundantCastStage : public ArithmeticOptimizerStage { + public: + explicit RemoveRedundantCastStage(ArithmeticOptimizerContext ctx) + : ArithmeticOptimizerStage(ctx) {} + ~RemoveRedundantCastStage() override = default; + + bool IsSupported(const NodeDef* node) const override { return IsCast(*node); } + + Status TrySimplify(const NodeDef* node, + string* simplified_node_name) override { + CHECK(IsSupported(node)); + // Bypass Cast whose source type and destination type are equal. + if (GetSourceDataType(*node) == GetDestinationDataType(*node)) { + *simplified_node_name = node->input(0); + } + return Status::OK(); + } +}; + } // namespace class UniqueNodes { @@ -903,31 +1012,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } } - // Remove inverse transposes. - if (node->op() == "Transpose" || node->op() == "ConjugateTranspose") { - NodeDef* input = node_map_->GetNode(node->input(0)); - if (input->op() == node->op()) { - const NodeDef* node_perm = node_map_->GetNode(node->input(1)); - const NodeDef* input_perm = node_map_->GetNode(input->input(1)); - // Try 32-bit indices. - std::vector<int> node_perm_values; - std::vector<int> input_perm_values; - if (ValuesFromConstNode(*node_perm, &node_perm_values) && - ValuesFromConstNode(*input_perm, &input_perm_values) && - AreInversePermutations(node_perm_values, input_perm_values)) { - return input->input(0); - } - // Try 64-bit indices. - std::vector<int64> node_perm_values64; - std::vector<int64> input_perm_values64; - if (ValuesFromConstNode(*node_perm, &node_perm_values64) && - ValuesFromConstNode(*input_perm, &input_perm_values64) && - AreInversePermutations(node_perm_values64, input_perm_values64)) { - return input->input(0); - } - } - } - if (node->op() == "Reshape") { // Reshape // ^ @@ -1024,32 +1108,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } } - if (node->op() == "Bitcast") { - NodeDef* bitcast = node_map_->GetNode(node->name()); - // Bypass bitcasts whose source type and destination type are equal. - if (GetSourceDataType(*bitcast) == GetDestinationDataType(*bitcast)) { - return bitcast->input(0); - } - - const NodeDef* operand = node_map_->GetNode(bitcast->input(0)); - if (operand->op() == bitcast->op()) { - // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2) - bitcast->set_input(0, operand->input(0)); - SetSourceDataType(GetSourceDataType(*operand), bitcast); - node_map_->UpdateInput(bitcast->name(), bitcast->input(0), - operand->input(0)); - nodes_to_simplify->PushBack(bitcast); - return bitcast->name(); - } - } - - if (node->op() == "Cast") { - // Bypass casts whose source type and destination type are equal. - if (GetSourceDataType(*node) == GetDestinationDataType(*node)) { - return node->input(0); - } - } - // Fold a multiply of a scalar into the following convolution. This folding // can jump across nodes that merely reorders data (such as reshape and // transpose). For example, we can optimize @@ -1391,11 +1449,22 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() { std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages; - // Add/AddN tree rewrites - if (options_.enable_add_to_addn_combining) { + if (options_.combine_add_to_addn) { stages.push_back( std::unique_ptr<ArithmeticOptimizerStage>(new AddOpsRewriteStage(ctx))); } + if (options_.remove_inverse_transpose) { + stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>( + new RemoveInverseTranspose(ctx))); + } + if (options_.remove_redundant_bitcast) { + stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>( + new RemoveRedundantBitcastStage(ctx))); + } + if (options_.remove_redundant_cast) { + stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>( + new RemoveRedundantCastStage(ctx))); + } VLOG(1) << "Simplify arithmetic ops using " << stages.size() << " arithmetic optimization stages"; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 9cff8ca9d0..787084454d 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -55,14 +55,16 @@ class ArithmeticOptimizer : public GraphOptimizer { // Granular control for arithmetic optimizer stages struct ArithmeticOptimizerOptions { - // rewrite a tree of Add/AddN ops with a single AddN - bool enable_add_to_addn_combining; + bool combine_add_to_addn = true; + bool remove_inverse_transpose = true; + bool remove_redundant_bitcast = true; + bool remove_redundant_cast = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. static ArithmeticOptimizerOptions Default( RewriterConfig::Toggle opt_level) { - return {/*enable_add_to_addn_combining*/ true}; + return ArithmeticOptimizerOptions(); } }; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index a56351c18a..98842b29f1 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -49,7 +50,7 @@ void VerifyGraphsMatch(const GraphDef& original_graph, } } // namespace -class ArithmeticOptimizerTest : public ::testing::Test { +class ArithmeticOptimizerTest : public GrapplerTest { protected: // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no // longer have any output consumers. @@ -63,14 +64,32 @@ class ArithmeticOptimizerTest : public ::testing::Test { // TODO(ezhulenev): Make private. After migration to stages each test // should explicitly enable required optimization for tests isolation void DisableAllStages(ArithmeticOptimizer* optimizer) { - ArithmeticOptimizer::ArithmeticOptimizerOptions options{ - /*enable_add_to_addn_combining*/ false}; + ArithmeticOptimizer::ArithmeticOptimizerOptions options; + options.combine_add_to_addn = false; + options.remove_inverse_transpose = false; + options.remove_redundant_bitcast = false; + options.remove_redundant_cast = false; optimizer->options_ = options; } - void EnableAddToAddNCombining(ArithmeticOptimizer* optimizer) { + void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); - optimizer->options_.enable_add_to_addn_combining = true; + optimizer->options_.combine_add_to_addn = true; + } + + void EnableOnlyRemoveInverseTranspose(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_inverse_transpose = true; + } + + void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_bitcast = true; + } + + void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_cast = true; } }; @@ -658,9 +677,7 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); - EXPECT_EQ(0, std::count_if( - output.node().begin(), output.node().end(), - [](const NodeDef& node) { return node.op() == "Reshape"; })); + EXPECT_EQ(0, CountOpNodes(output, "Reshape")); } TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { @@ -682,9 +699,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); - EXPECT_EQ(1, std::count_if( - output.node().begin(), output.node().end(), - [](const NodeDef& node) { return node.op() == "Reshape"; })); + EXPECT_EQ(1, CountOpNodes(output, "Reshape")); } TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { @@ -704,9 +719,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); - EXPECT_EQ(1, std::count_if( - output.node().begin(), output.node().end(), - [](const NodeDef& node) { return node.op() == "Reshape"; })); + EXPECT_EQ(1, CountOpNodes(output, "Reshape")); } TEST_F(ArithmeticOptimizerTest, CombineReshapes) { @@ -737,9 +750,7 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) { item.graph.Swap(&output); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); - EXPECT_EQ(1, std::count_if( - output.node().begin(), output.node().end(), - [](const NodeDef& node) { return node.op() == "Reshape"; })); + EXPECT_EQ(1, CountOpNodes(output, "Reshape")); } TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) { @@ -826,10 +837,9 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposes) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveInverseTranspose(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); std::set<string> nodes_after_optimization; for (const NodeDef& node : output.node()) { @@ -859,10 +869,9 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveInverseTranspose(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); for (const NodeDef& node : output.node()) { if (node.op() == "Concat") { @@ -886,10 +895,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { GrapplerItem item; item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveInverseTranspose(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); NodeMap node_map(&output); const NodeDef* outputs_node = node_map.GetNode("outputs"); @@ -915,10 +925,9 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveInverseTranspose(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); EXPECT_EQ(6, output.node_size()); } @@ -1133,10 +1142,10 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) { TEST_F(ArithmeticOptimizerTest, CombineBitcasts) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output inputs = - ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({2, 3})); - Output bc1 = ops::Bitcast(s, inputs, DT_QINT8); - Output bc2 = ops::Bitcast(s, bc1, DT_INT8); + Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_UINT8, + ops::Placeholder::Shape({2, 3})); + Output bc1 = ops::Bitcast(s.WithOpName("bc1"), inputs, DT_QINT8); + Output bc2 = ops::Bitcast(s.WithOpName("bc2"), bc1, DT_INT8); Output outputs = ops::Identity(s.WithOpName("outputs"), bc2); GrapplerItem item; @@ -1144,18 +1153,22 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantBitcast(&optimizer); + + OptimizeAndPrune(&optimizer, &item, &output); + NodeMap node_map(&output); - EXPECT_EQ(1, std::count_if( - output.node().begin(), output.node().end(), - [](const NodeDef& node) { return node.op() == "Bitcast"; })); + // Bitcasts combined into a single op and inputs redirected to updated Bitcast + EXPECT_EQ(3, output.node_size()); + EXPECT_EQ(1, CountOpNodes(output, "Bitcast")); + EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2")); } TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3})); + Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8, + ops::Placeholder::Shape({2, 3})); Output bc1 = ops::Bitcast(s, inputs, DT_QINT8); Output bc2 = ops::Bitcast(s, bc1, DT_INT8); Output outputs = ops::Identity(s.WithOpName("outputs"), bc2); @@ -1163,33 +1176,42 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { GrapplerItem item; item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantBitcast(&optimizer); + + OptimizeAndPrune(&optimizer, &item, &output); + NodeMap node_map(&output); - EXPECT_EQ(0, std::count_if( - output.node().begin(), output.node().end(), - [](const NodeDef& node) { return node.op() == "Bitcast"; })); + // Bitcasts removed and inputs redirected to outputs + EXPECT_EQ(2, output.node_size()); + EXPECT_EQ(0, CountOpNodes(output, "Bitcast")); + EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs")); } TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3})); + Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8, + ops::Placeholder::Shape({2, 3})); Output cast = ops::Cast(s, inputs, DT_INT8); Output outputs = ops::Identity(s.WithOpName("outputs"), cast); GrapplerItem item; item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantCast(&optimizer); - EXPECT_EQ(0, std::count_if( - output.node().begin(), output.node().end(), - [](const NodeDef& node) { return node.op() == "Cast"; })); + OptimizeAndPrune(&optimizer, &item, &output); + NodeMap node_map(&output); + + // Cast removed and inputs redirected to outputs + EXPECT_EQ(2, output.node_size()); + EXPECT_EQ(0, CountOpNodes(output, "Cast")); + EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs")); } TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) { @@ -1211,7 +1233,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) { GraphDef output; ArithmeticOptimizer optimizer; - EnableAddToAddNCombining(&optimizer); + EnableOnlyAddToAddNCombining(&optimizer); OptimizeAndPrune(&optimizer, &item, &output); @@ -1266,7 +1288,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) { GraphDef output; ArithmeticOptimizer optimizer; - EnableAddToAddNCombining(&optimizer); + EnableOnlyAddToAddNCombining(&optimizer); OptimizeAndPrune(&optimizer, &item, &output); @@ -1329,7 +1351,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) { GraphDef output; ArithmeticOptimizer optimizer; - EnableAddToAddNCombining(&optimizer); + EnableOnlyAddToAddNCombining(&optimizer); OptimizeAndPrune(&optimizer, &item, &output); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 3dbad40cae..939031c44b 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -147,6 +147,22 @@ cc_library( ], ) +tf_cc_test( + name = "grappler_test_test", + size = "small", + srcs = ["grappler_test_test.cc"], + deps = [ + ":grappler_test", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:direct_session", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:utils", + ], +) + cc_library( name = "functions", srcs = [ diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc index 79b2aa2808..89c3aa82bf 100644 --- a/tensorflow/core/grappler/utils/grappler_test.cc +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -90,5 +90,20 @@ void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) { } } +bool GrapplerTest::IsNodesDirectlyConnected(const NodeMap& node_map, + const string& src, + const string& dst, int position) { + const NodeDef* src_node = node_map.GetNode(src); + const NodeDef* dst_node = node_map.GetNode(dst); + EXPECT_TRUE(src_node != nullptr) << src << " node not found"; + EXPECT_TRUE(dst_node != nullptr) << dst << " node not found"; + return src_node && dst_node && dst_node->input(position) == src_node->name(); +} + +int GrapplerTest::CountOpNodes(const GraphDef& graph, const string& op) { + return std::count_if(graph.node().begin(), graph.node().end(), + [&op](const NodeDef& node) { return node.op() == op; }); +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h index fd6809b6e2..3df6625d5c 100644 --- a/tensorflow/core/grappler/utils/grappler_test.h +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -37,6 +38,13 @@ class GrapplerTest : public ::testing::Test { const std::vector<string>& inputs, GraphDef* graph); void CompareGraphs(GraphDef want, GraphDef got); + + // Check if node 'src' is directly connected to the input($position) of 'dst'. + bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src, + const string& dst, int position = 0); + + // Count nodes of the given op-type in a graph. + int CountOpNodes(const GraphDef& graph, const string& op); }; } // end namespace grappler diff --git a/tensorflow/core/grappler/utils/grappler_test_test.cc b/tensorflow/core/grappler/utils/grappler_test_test.cc new file mode 100644 index 0000000000..677fa5a798 --- /dev/null +++ b/tensorflow/core/grappler/utils/grappler_test_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +// TODO(ezhulenev): add tests for all methods in GrapplerTest +class GrapplerTestTest : public GrapplerTest {}; + +TEST_F(GrapplerTestTest, CompareIdenticalGraphs) { + tensorflow::Scope s1 = tensorflow::Scope::NewRootScope(); + auto s1_a = ops::Variable(s1.WithOpName("a"), {2, 2}, DT_FLOAT); + auto s1_b = ops::Variable(s1.WithOpName("b"), {2, 2}, DT_FLOAT); + auto s1_add = ops::Add(s1.WithOpName("Add_1"), s1_a, s1_b); + + tensorflow::Scope s2 = tensorflow::Scope::NewRootScope(); + auto s2_a = ops::Variable(s2.WithOpName("a"), {2, 2}, DT_FLOAT); + auto s2_b = ops::Variable(s2.WithOpName("b"), {2, 2}, DT_FLOAT); + auto s2_add = ops::Add(s2.WithOpName("Add_1"), s2_a, s2_b); + + GraphDef graph1; + TF_ASSERT_OK(s1.ToGraphDef(&graph1)); + + GraphDef graph2; + TF_ASSERT_OK(s2.ToGraphDef(&graph2)); + + CompareGraphs(graph1, graph2); +} + +TEST_F(GrapplerTestTest, CheckNodesConnectivity) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT); + auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT); + auto add_1 = ops::Add(s.WithOpName("Add_1"), a, b); + auto add_2 = ops::Add(s.WithOpName("Add_2"), add_1, b); + + GraphDef graph; + TF_ASSERT_OK(s.ToGraphDef(&graph)); + + NodeMap node_map(&graph); + + EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "a", "Add_1", 0)); + EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "b", "Add_1", 1)); + EXPECT_FALSE(IsNodesDirectlyConnected(node_map, "a", "Add_2", 0)); + EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "b", "Add_2", 1)); +} + +TEST_F(GrapplerTestTest, CountOpNodes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT); + auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT); + auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT); + + auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b); + auto add_bc = ops::Add(s.WithOpName("Add_bc"), b, c); + + auto mul_ab = ops::Mul(s.WithOpName("Mull_ab"), a, b); + auto mul_bc = ops::Mul(s.WithOpName("Mull_bc"), a, b); + + InputList inputs{ + Output(add_ab), + Output(add_bc), + Output(mul_ab), + Output(mul_bc), + }; + auto add_all = ops::AddN(s.WithOpName("Add_all"), inputs); + + GraphDef graph; + TF_ASSERT_OK(s.ToGraphDef(&graph)); + + EXPECT_EQ(2, CountOpNodes(graph, "Add")); + EXPECT_EQ(2, CountOpNodes(graph, "Mul")); + EXPECT_EQ(1, CountOpNodes(graph, "AddN")); + EXPECT_EQ(0, CountOpNodes(graph, "Transpose")); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow
\ No newline at end of file |