aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/debug_stripper_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper_test.cc116
1 files changed, 89 insertions, 27 deletions
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
index aacd55f136..3f11febc64 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -29,14 +29,13 @@ namespace {
class DebugStripperTest : public GrapplerTest {};
TEST_F(DebugStripperTest, OutputEqualToInput) {
- constexpr char device[] = "/device:CPU:0";
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({}));
+ Output y = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({}));
+ Output add = ops::Add(s, x, y);
+ Output result = ops::Identity(s, add);
GrapplerItem item;
- item.graph = test::function::GDef(
- {test::function::NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}},
- device),
- test::function::NDef("y", "XTimesTwo", {"x"}, {}, device),
- test::function::NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, device)},
- {});
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
DebugStripper optimizer;
GraphDef output;
@@ -45,19 +44,17 @@ TEST_F(DebugStripperTest, OutputEqualToInput) {
}
TEST_F(DebugStripperTest, StripAssertFromGraph) {
- constexpr char device[] = "/device:CPU:0";
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
+ auto assert = ops::Assert(s.WithOpName("Assert"), greaterequal, {x, y});
+ Output add = ops::Add(
+ s.WithOpName("z").WithControlDependencies({assert.operation}), x, y);
GrapplerItem item;
- item.graph = test::function::GDef(
- {test::function::NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}},
- device),
- test::function::NDef("y", "Placeholder", {}, {{"dtype", DT_FLOAT}},
- device),
- test::function::NDef("GreaterEqual", "GreaterEqual", {"x", "y"},
- {{"T", DT_FLOAT}}, device),
- test::function::NDef("Assert", "Assert", {"GreaterEqual"},
- {{"T", DT_FLOAT}}, device),
- test::function::NDef("z", "Add", {"x", "y", "^Assert"}, {}, device)},
- {});
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
DebugStripper optimizer;
GraphDef output;
@@ -68,31 +65,27 @@ TEST_F(DebugStripperTest, StripAssertFromGraph) {
if (node.name() == "x") {
count++;
EXPECT_EQ("Placeholder", node.op());
- EXPECT_EQ(device, node.device());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "y") {
count++;
EXPECT_EQ("Placeholder", node.op());
- EXPECT_EQ(device, node.device());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "GreaterEqual") {
count++;
EXPECT_EQ("GreaterEqual", node.op());
- EXPECT_EQ(device, node.device());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
} else if (node.name() == "Assert") {
count++;
EXPECT_EQ("NoOp", node.op());
- EXPECT_EQ(device, node.device());
- EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ(3, node.input_size());
EXPECT_EQ("^GreaterEqual", node.input(0));
- EXPECT_EQ(0, node.attr_size());
+ EXPECT_EQ("^x", node.input(1));
+ EXPECT_EQ("^y", node.input(2));
} else if (node.name() == "z") {
count++;
EXPECT_EQ("Add", node.op());
- EXPECT_EQ(device, node.device());
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
@@ -100,6 +93,75 @@ TEST_F(DebugStripperTest, StripAssertFromGraph) {
}
}
EXPECT_EQ(5, count);
+
+ Tensor x_t(DT_FLOAT, TensorShape({}));
+ Tensor y_t(DT_FLOAT, TensorShape({}));
+ x_t.flat<float>()(0) = 1.0f;
+ y_t.flat<float>()(0) = 0.5f;
+ std::vector<Tensor> expected =
+ EvaluateNodes(item.graph, {"z"}, {{"x", x_t}, {"y", y_t}});
+ std::vector<Tensor> optimized =
+ EvaluateNodes(output, {"z"}, {{"x", x_t}, {"y", y_t}});
+ test::ExpectTensorEqual<float>(expected[0], optimized[0]);
+}
+
+TEST_F(DebugStripperTest, StripCheckNumericsFromGraph) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ auto check1 = ops::CheckNumerics(s.WithOpName("CheckNumerics1"), x, "foo");
+ auto check2 = ops::CheckNumerics(s.WithOpName("CheckNumerics2"), y, "foo");
+ Output add = ops::Add(s.WithOpName("z"), check1, check2);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "x") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "y") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "CheckNumerics1") {
+ count++;
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ(1, node.attr_size());
+ } else if (node.name() == "CheckNumerics2") {
+ count++;
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ(1, node.attr_size());
+ } else if (node.name() == "z") {
+ count++;
+ EXPECT_EQ("Add", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("CheckNumerics1", node.input(0));
+ EXPECT_EQ("CheckNumerics2", node.input(1));
+ }
+ }
+ EXPECT_EQ(5, count);
+
+ Tensor x_t(DT_FLOAT, TensorShape({}));
+ Tensor y_t(DT_FLOAT, TensorShape({}));
+ x_t.flat<float>()(0) = 1.0f;
+ y_t.flat<float>()(0) = 0.5f;
+ std::vector<Tensor> expected =
+ EvaluateNodes(item.graph, {"z"}, {{"x", x_t}, {"y", y_t}});
+ std::vector<Tensor> optimized =
+ EvaluateNodes(output, {"z"}, {{"x", x_t}, {"y", y_t}});
+ test::ExpectTensorEqual<float>(expected[0], optimized[0]);
}
} // namespace