about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-09 18:50:06 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-09 18:54:18 -0800
commit 9d1d5057b9d3fb335a4b20193bb364737e2b5140 (patch)
tree   14f940946abf2d9d4d95c96264566e41a6edbb3f
parent 2426308fa58ebf473092918cc8ffa215325c4079 (diff)
Move optimizations to arithmetic optimizer stages
1) Redundant Bitcast
2) Redundant Cast
3) Remove inverse transpose

PiperOrigin-RevId: 188569367
-rw-r--r--  tensorflow/core/grappler/op_types.cc                             |   4
-rw-r--r--  tensorflow/core/grappler/op_types.h                              |   2
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD                        |   1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc      | 207
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h       |   8
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 140
-rw-r--r--  tensorflow/core/grappler/utils/BUILD                             |  16
-rw-r--r--  tensorflow/core/grappler/utils/grappler_test.cc                  |  15
-rw-r--r--  tensorflow/core/grappler/utils/grappler_test.h                   |   8
-rw-r--r--  tensorflow/core/grappler/utils/grappler_test_test.cc             | 100
10 files changed, 370 insertions, 131 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 8cf1402ae8..ca56833ef6 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -78,6 +78,10 @@ bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
+bool IsConjugateTranspose(const NodeDef& node) {
+ return node.op() == "ConjugateTranspose";
+}
+
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
bool IsConv2DBackpropFilter(const NodeDef& node) {
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index a7c33ef97b..a0946ee1ad 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -40,6 +40,8 @@ bool IsCast(const NodeDef& node);
bool IsComplex(const NodeDef& node);
bool IsComplexAbs(const NodeDef& node);
bool IsConj(const NodeDef& node);
+bool IsConjugateTranspose(const NodeDef& node);
+bool IsConcat(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node);
bool IsConv2D(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 7ec137373b..6ded261c7d 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -248,6 +248,7 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ "//tensorflow/core/grappler/utils:grappler_test",
],
)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 3cf42fde41..177b0735e9 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -45,19 +45,6 @@ namespace tensorflow {
namespace grappler {
namespace {
-template <typename T>
-bool AreInversePermutations(const std::vector<T>& a, const std::vector<T>& b) {
- if (a.size() != b.size()) {
- return false;
- }
- for (int i = 0; i < a.size(); ++i) {
- if (a[b[i]] != i) {
- return false;
- }
- }
- return true;
-}
-
// Extract values from a Const op to `values`. Returns true if succeeds.
template <typename T>
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
@@ -431,9 +418,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
Status TrySimplify(const NodeDef* node,
string* simplified_node_name) override {
- CHECK(IsSupported(node))
- << "Node " << node->name()
- << " is not supported by add ops group optimizer step";
+ CHECK(IsSupported(node));
AddOpsGroup group;
TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
@@ -650,6 +635,130 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
std::unordered_set<string> rewritten_nodes_;
};
+// Removes inverse transpose nodes
+class RemoveInverseTranspose : public ArithmeticOptimizerStage {
+ public:
+ explicit RemoveInverseTranspose(ArithmeticOptimizerContext ctx)
+ : ArithmeticOptimizerStage(ctx) {}
+ ~RemoveInverseTranspose() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsTranspose(*node) || IsConjugateTranspose(*node);
+ }
+
+ Status TrySimplify(const NodeDef* node,
+ string* simplified_node_name) override {
+ CHECK(IsSupported(node));
+
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+
+ if (input->op() == node->op()) {
+ NodeDef* node_perm;
+ NodeDef* input_perm;
+
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm));
+
+ // Try 32-bit indices.
+ std::vector<int> node_perm_values;
+ std::vector<int> input_perm_values;
+ if (ValuesFromConstNode(*node_perm, &node_perm_values) &&
+ ValuesFromConstNode(*input_perm, &input_perm_values) &&
+ AreInversePermutations(node_perm_values, input_perm_values)) {
+ *simplified_node_name = input->input(0);
+ }
+ // Try 64-bit indices.
+ std::vector<int64> node_perm_values64;
+ std::vector<int64> input_perm_values64;
+ if (ValuesFromConstNode(*node_perm, &node_perm_values64) &&
+ ValuesFromConstNode(*input_perm, &input_perm_values64) &&
+ AreInversePermutations(node_perm_values64, input_perm_values64)) {
+ *simplified_node_name = input->input(0);
+ }
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ template <typename T>
+ bool AreInversePermutations(const std::vector<T>& a,
+ const std::vector<T>& b) {
+ if (a.size() != b.size()) {
+ return false;
+ }
+ for (int i = 0; i < a.size(); ++i) {
+ if (a[b[i]] != i) {
+ return false;
+ }
+ }
+ return true;
+ }
+};
+
+// Remove redundant Bitcasts.
+// 1) Remove Bitcast whose source type and destination type are equal
+// 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
+class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
+ public:
+ explicit RemoveRedundantBitcastStage(ArithmeticOptimizerContext ctx)
+ : ArithmeticOptimizerStage(ctx) {}
+ ~RemoveRedundantBitcastStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsBitcast(*node);
+ }
+
+ Status TrySimplify(const NodeDef* node,
+ string* simplified_node_name) override {
+ CHECK(IsSupported(node));
+
+ // Bypass Bitcast whose source type and destination type are equal.
+ if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
+ *simplified_node_name = node->input(0);
+ return Status::OK();
+ }
+
+ NodeDef* bitcast;
+ TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
+ NodeDef* operand;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));
+
+ if (IsBitcast(*operand)) {
+ // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
+ bitcast->set_input(0, operand->input(0));
+ SetSourceDataType(GetSourceDataType(*operand), bitcast);
+ ctx_.node_map->UpdateInput(bitcast->name(), bitcast->input(0),
+ operand->input(0));
+ AddToOptimizationQueue(bitcast);
+ *simplified_node_name = bitcast->name();
+ }
+
+ return Status::OK();
+ }
+};
+
+// Remove Casts whose source type and destination type are equal.
+class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
+ public:
+ explicit RemoveRedundantCastStage(ArithmeticOptimizerContext ctx)
+ : ArithmeticOptimizerStage(ctx) {}
+ ~RemoveRedundantCastStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override { return IsCast(*node); }
+
+ Status TrySimplify(const NodeDef* node,
+ string* simplified_node_name) override {
+ CHECK(IsSupported(node));
+ // Bypass Cast whose source type and destination type are equal.
+ if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
+ *simplified_node_name = node->input(0);
+ }
+ return Status::OK();
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -903,31 +1012,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
}
- // Remove inverse transposes.
- if (node->op() == "Transpose" || node->op() == "ConjugateTranspose") {
- NodeDef* input = node_map_->GetNode(node->input(0));
- if (input->op() == node->op()) {
- const NodeDef* node_perm = node_map_->GetNode(node->input(1));
- const NodeDef* input_perm = node_map_->GetNode(input->input(1));
- // Try 32-bit indices.
- std::vector<int> node_perm_values;
- std::vector<int> input_perm_values;
- if (ValuesFromConstNode(*node_perm, &node_perm_values) &&
- ValuesFromConstNode(*input_perm, &input_perm_values) &&
- AreInversePermutations(node_perm_values, input_perm_values)) {
- return input->input(0);
- }
- // Try 64-bit indices.
- std::vector<int64> node_perm_values64;
- std::vector<int64> input_perm_values64;
- if (ValuesFromConstNode(*node_perm, &node_perm_values64) &&
- ValuesFromConstNode(*input_perm, &input_perm_values64) &&
- AreInversePermutations(node_perm_values64, input_perm_values64)) {
- return input->input(0);
- }
- }
- }
-
if (node->op() == "Reshape") {
// Reshape
// ^
@@ -1024,32 +1108,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
}
- if (node->op() == "Bitcast") {
- NodeDef* bitcast = node_map_->GetNode(node->name());
- // Bypass bitcasts whose source type and destination type are equal.
- if (GetSourceDataType(*bitcast) == GetDestinationDataType(*bitcast)) {
- return bitcast->input(0);
- }
-
- const NodeDef* operand = node_map_->GetNode(bitcast->input(0));
- if (operand->op() == bitcast->op()) {
- // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
- bitcast->set_input(0, operand->input(0));
- SetSourceDataType(GetSourceDataType(*operand), bitcast);
- node_map_->UpdateInput(bitcast->name(), bitcast->input(0),
- operand->input(0));
- nodes_to_simplify->PushBack(bitcast);
- return bitcast->name();
- }
- }
-
- if (node->op() == "Cast") {
- // Bypass casts whose source type and destination type are equal.
- if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
- return node->input(0);
- }
- }
-
// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorders data (such as reshape and
// transpose). For example, we can optimize
@@ -1391,11 +1449,22 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
- // Add/AddN tree rewrites
- if (options_.enable_add_to_addn_combining) {
+ if (options_.combine_add_to_addn) {
stages.push_back(
std::unique_ptr<ArithmeticOptimizerStage>(new AddOpsRewriteStage(ctx)));
}
+ if (options_.remove_inverse_transpose) {
+ stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
+ new RemoveInverseTranspose(ctx)));
+ }
+ if (options_.remove_redundant_bitcast) {
+ stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
+ new RemoveRedundantBitcastStage(ctx)));
+ }
+ if (options_.remove_redundant_cast) {
+ stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
+ new RemoveRedundantCastStage(ctx)));
+ }
VLOG(1) << "Simplify arithmetic ops using " << stages.size()
<< " arithmetic optimization stages";
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 9cff8ca9d0..787084454d 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -55,14 +55,16 @@ class ArithmeticOptimizer : public GraphOptimizer {
// Granular control for arithmetic optimizer stages
struct ArithmeticOptimizerOptions {
- // rewrite a tree of Add/AddN ops with a single AddN
- bool enable_add_to_addn_combining;
+ bool combine_add_to_addn = true;
+ bool remove_inverse_transpose = true;
+ bool remove_redundant_bitcast = true;
+ bool remove_redundant_cast = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
static ArithmeticOptimizerOptions Default(
RewriterConfig::Toggle opt_level) {
- return {/*enable_add_to_addn_combining*/ true};
+ return ArithmeticOptimizerOptions();
}
};
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index a56351c18a..98842b29f1 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -49,7 +50,7 @@ void VerifyGraphsMatch(const GraphDef& original_graph,
}
} // namespace
-class ArithmeticOptimizerTest : public ::testing::Test {
+class ArithmeticOptimizerTest : public GrapplerTest {
protected:
// Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
// longer have any output consumers.
@@ -63,14 +64,32 @@ class ArithmeticOptimizerTest : public ::testing::Test {
// TODO(ezhulenev): Make private. After migration to stages each test
// should explicitly enable required optimization for tests isolation
void DisableAllStages(ArithmeticOptimizer* optimizer) {
- ArithmeticOptimizer::ArithmeticOptimizerOptions options{
- /*enable_add_to_addn_combining*/ false};
+ ArithmeticOptimizer::ArithmeticOptimizerOptions options;
+ options.combine_add_to_addn = false;
+ options.remove_inverse_transpose = false;
+ options.remove_redundant_bitcast = false;
+ options.remove_redundant_cast = false;
optimizer->options_ = options;
}
- void EnableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
+ void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
- optimizer->options_.enable_add_to_addn_combining = true;
+ optimizer->options_.combine_add_to_addn = true;
+ }
+
+ void EnableOnlyRemoveInverseTranspose(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.remove_inverse_transpose = true;
+ }
+
+ void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.remove_redundant_bitcast = true;
+ }
+
+ void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.remove_redundant_cast = true;
}
};
@@ -658,9 +677,7 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
- EXPECT_EQ(0, std::count_if(
- output.node().begin(), output.node().end(),
- [](const NodeDef& node) { return node.op() == "Reshape"; }));
+ EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
}
TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
@@ -682,9 +699,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
- EXPECT_EQ(1, std::count_if(
- output.node().begin(), output.node().end(),
- [](const NodeDef& node) { return node.op() == "Reshape"; }));
+ EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
}
TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
@@ -704,9 +719,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
- EXPECT_EQ(1, std::count_if(
- output.node().begin(), output.node().end(),
- [](const NodeDef& node) { return node.op() == "Reshape"; }));
+ EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
}
TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
@@ -737,9 +750,7 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
- EXPECT_EQ(1, std::count_if(
- output.node().begin(), output.node().end(),
- [](const NodeDef& node) { return node.op() == "Reshape"; }));
+ EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
}
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) {
@@ -826,10 +837,9 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposes) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
-
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveInverseTranspose(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
std::set<string> nodes_after_optimization;
for (const NodeDef& node : output.node()) {
@@ -859,10 +869,9 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
-
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveInverseTranspose(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
for (const NodeDef& node : output.node()) {
if (node.op() == "Concat") {
@@ -886,10 +895,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveInverseTranspose(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
NodeMap node_map(&output);
const NodeDef* outputs_node = node_map.GetNode("outputs");
@@ -915,10 +925,9 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
-
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveInverseTranspose(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
EXPECT_EQ(6, output.node_size());
}
@@ -1133,10 +1142,10 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output inputs =
- ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({2, 3}));
- Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
- Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
+ Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_UINT8,
+ ops::Placeholder::Shape({2, 3}));
+ Output bc1 = ops::Bitcast(s.WithOpName("bc1"), inputs, DT_QINT8);
+ Output bc2 = ops::Bitcast(s.WithOpName("bc2"), bc1, DT_INT8);
Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);
GrapplerItem item;
@@ -1144,18 +1153,22 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantBitcast(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+ NodeMap node_map(&output);
- EXPECT_EQ(1, std::count_if(
- output.node().begin(), output.node().end(),
- [](const NodeDef& node) { return node.op() == "Bitcast"; }));
+ // Bitcasts combined into a single op and inputs redirected to updated Bitcast
+ EXPECT_EQ(3, output.node_size());
+ EXPECT_EQ(1, CountOpNodes(output, "Bitcast"));
+ EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));
}
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3}));
+ Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8,
+ ops::Placeholder::Shape({2, 3}));
Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);
@@ -1163,33 +1176,42 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantBitcast(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+ NodeMap node_map(&output);
- EXPECT_EQ(0, std::count_if(
- output.node().begin(), output.node().end(),
- [](const NodeDef& node) { return node.op() == "Bitcast"; }));
+ // Bitcasts removed and inputs redirected to outputs
+ EXPECT_EQ(2, output.node_size());
+ EXPECT_EQ(0, CountOpNodes(output, "Bitcast"));
+ EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
}
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3}));
+ Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8,
+ ops::Placeholder::Shape({2, 3}));
Output cast = ops::Cast(s, inputs, DT_INT8);
Output outputs = ops::Identity(s.WithOpName("outputs"), cast);
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantCast(&optimizer);
- EXPECT_EQ(0, std::count_if(
- output.node().begin(), output.node().end(),
- [](const NodeDef& node) { return node.op() == "Cast"; }));
+ OptimizeAndPrune(&optimizer, &item, &output);
+ NodeMap node_map(&output);
+
+ // Cast removed and inputs redirected to outputs
+ EXPECT_EQ(2, output.node_size());
+ EXPECT_EQ(0, CountOpNodes(output, "Cast"));
+ EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
@@ -1211,7 +1233,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
GraphDef output;
ArithmeticOptimizer optimizer;
- EnableAddToAddNCombining(&optimizer);
+ EnableOnlyAddToAddNCombining(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
@@ -1266,7 +1288,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
GraphDef output;
ArithmeticOptimizer optimizer;
- EnableAddToAddNCombining(&optimizer);
+ EnableOnlyAddToAddNCombining(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
@@ -1329,7 +1351,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
GraphDef output;
ArithmeticOptimizer optimizer;
- EnableAddToAddNCombining(&optimizer);
+ EnableOnlyAddToAddNCombining(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index 3dbad40cae..939031c44b 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -147,6 +147,22 @@ cc_library(
],
)
+tf_cc_test(
+ name = "grappler_test_test",
+ size = "small",
+ srcs = ["grappler_test_test.cc"],
+ deps = [
+ ":grappler_test",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:direct_session",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:utils",
+ ],
+)
+
cc_library(
name = "functions",
srcs = [
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 79b2aa2808..89c3aa82bf 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -90,5 +90,20 @@ void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) {
}
}
+bool GrapplerTest::IsNodesDirectlyConnected(const NodeMap& node_map,
+ const string& src,
+ const string& dst, int position) {
+ const NodeDef* src_node = node_map.GetNode(src);
+ const NodeDef* dst_node = node_map.GetNode(dst);
+ EXPECT_TRUE(src_node != nullptr) << src << " node not found";
+ EXPECT_TRUE(dst_node != nullptr) << dst << " node not found";
+ return src_node && dst_node && dst_node->input(position) == src_node->name();
+}
+
+int GrapplerTest::CountOpNodes(const GraphDef& graph, const string& op) {
+ return std::count_if(graph.node().begin(), graph.node().end(),
+ [&op](const NodeDef& node) { return node.op() == op; });
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h
index fd6809b6e2..3df6625d5c 100644
--- a/tensorflow/core/grappler/utils/grappler_test.h
+++ b/tensorflow/core/grappler/utils/grappler_test.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -37,6 +38,13 @@ class GrapplerTest : public ::testing::Test {
const std::vector<string>& inputs, GraphDef* graph);
void CompareGraphs(GraphDef want, GraphDef got);
+
+ // Check if node 'src' is directly connected to the input($position) of 'dst'.
+ bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src,
+ const string& dst, int position = 0);
+
+ // Count nodes of the given op-type in a graph.
+ int CountOpNodes(const GraphDef& graph, const string& op);
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/utils/grappler_test_test.cc b/tensorflow/core/grappler/utils/grappler_test_test.cc
new file mode 100644
index 0000000000..677fa5a798
--- /dev/null
+++ b/tensorflow/core/grappler/utils/grappler_test_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+// TODO(ezhulenev): add tests for all methods in GrapplerTest
+class GrapplerTestTest : public GrapplerTest {};
+
+TEST_F(GrapplerTestTest, CompareIdenticalGraphs) {
+ tensorflow::Scope s1 = tensorflow::Scope::NewRootScope();
+ auto s1_a = ops::Variable(s1.WithOpName("a"), {2, 2}, DT_FLOAT);
+ auto s1_b = ops::Variable(s1.WithOpName("b"), {2, 2}, DT_FLOAT);
+ auto s1_add = ops::Add(s1.WithOpName("Add_1"), s1_a, s1_b);
+
+ tensorflow::Scope s2 = tensorflow::Scope::NewRootScope();
+ auto s2_a = ops::Variable(s2.WithOpName("a"), {2, 2}, DT_FLOAT);
+ auto s2_b = ops::Variable(s2.WithOpName("b"), {2, 2}, DT_FLOAT);
+ auto s2_add = ops::Add(s2.WithOpName("Add_1"), s2_a, s2_b);
+
+ GraphDef graph1;
+ TF_ASSERT_OK(s1.ToGraphDef(&graph1));
+
+ GraphDef graph2;
+ TF_ASSERT_OK(s2.ToGraphDef(&graph2));
+
+ CompareGraphs(graph1, graph2);
+}
+
+TEST_F(GrapplerTestTest, CheckNodesConnectivity) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
+ auto add_1 = ops::Add(s.WithOpName("Add_1"), a, b);
+ auto add_2 = ops::Add(s.WithOpName("Add_2"), add_1, b);
+
+ GraphDef graph;
+ TF_ASSERT_OK(s.ToGraphDef(&graph));
+
+ NodeMap node_map(&graph);
+
+ EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "a", "Add_1", 0));
+ EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "b", "Add_1", 1));
+ EXPECT_FALSE(IsNodesDirectlyConnected(node_map, "a", "Add_2", 0));
+ EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "b", "Add_2", 1));
+}
+
+TEST_F(GrapplerTestTest, CountOpNodes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
+
+ auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+ auto add_bc = ops::Add(s.WithOpName("Add_bc"), b, c);
+
+ auto mul_ab = ops::Mul(s.WithOpName("Mull_ab"), a, b);
+ auto mul_bc = ops::Mul(s.WithOpName("Mull_bc"), a, b);
+
+ InputList inputs{
+ Output(add_ab),
+ Output(add_bc),
+ Output(mul_ab),
+ Output(mul_bc),
+ };
+ auto add_all = ops::AddN(s.WithOpName("Add_all"), inputs);
+
+ GraphDef graph;
+ TF_ASSERT_OK(s.ToGraphDef(&graph));
+
+ EXPECT_EQ(2, CountOpNodes(graph, "Add"));
+ EXPECT_EQ(2, CountOpNodes(graph, "Mul"));
+ EXPECT_EQ(1, CountOpNodes(graph, "AddN"));
+ EXPECT_EQ(0, CountOpNodes(graph, "Transpose"));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow \ No newline at end of file