diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/debug_stripper_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/debug_stripper_test.cc | 116 |
1 files changed, 89 insertions, 27 deletions
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc index aacd55f136..3f11febc64 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc +++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/debug_stripper.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -29,14 +29,13 @@ namespace { class DebugStripperTest : public GrapplerTest {}; TEST_F(DebugStripperTest, OutputEqualToInput) { - constexpr char device[] = "/device:CPU:0"; + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({})); + Output y = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({})); + Output add = ops::Add(s, x, y); + Output result = ops::Identity(s, add); GrapplerItem item; - item.graph = test::function::GDef( - {test::function::NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, - device), - test::function::NDef("y", "XTimesTwo", {"x"}, {}, device), - test::function::NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, device)}, - {}); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); DebugStripper optimizer; GraphDef output; @@ -45,19 +44,17 @@ TEST_F(DebugStripperTest, OutputEqualToInput) { } TEST_F(DebugStripperTest, StripAssertFromGraph) { - constexpr char device[] = "/device:CPU:0"; + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape({})); + Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, + ops::Placeholder::Shape({})); + auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y); + auto assert = ops::Assert(s.WithOpName("Assert"), greaterequal, {x, y}); + Output add = ops::Add( + s.WithOpName("z").WithControlDependencies({assert.operation}), x, y); GrapplerItem item; - item.graph = test::function::GDef( - {test::function::NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, - device), - test::function::NDef("y", "Placeholder", {}, {{"dtype", DT_FLOAT}}, - device), - test::function::NDef("GreaterEqual", "GreaterEqual", {"x", "y"}, - {{"T", DT_FLOAT}}, device), - test::function::NDef("Assert", "Assert", {"GreaterEqual"}, - {{"T", DT_FLOAT}}, device), - test::function::NDef("z", "Add", {"x", "y", "^Assert"}, {}, device)}, - {}); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); DebugStripper optimizer; GraphDef output; @@ -68,31 +65,27 @@ TEST_F(DebugStripperTest, StripAssertFromGraph) { if (node.name() == "x") { count++; EXPECT_EQ("Placeholder", node.op()); - EXPECT_EQ(device, node.device()); EXPECT_EQ(0, node.input_size()); } else if (node.name() == "y") { count++; EXPECT_EQ("Placeholder", node.op()); - EXPECT_EQ(device, node.device()); EXPECT_EQ(0, node.input_size()); } else if (node.name() == "GreaterEqual") { count++; EXPECT_EQ("GreaterEqual", node.op()); - EXPECT_EQ(device, node.device()); EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); } else if (node.name() == "Assert") { count++; EXPECT_EQ("NoOp", node.op()); - EXPECT_EQ(device, node.device()); - EXPECT_EQ(1, node.input_size()); + EXPECT_EQ(3, node.input_size()); EXPECT_EQ("^GreaterEqual", node.input(0)); - EXPECT_EQ(0, node.attr_size()); + EXPECT_EQ("^x", node.input(1)); + EXPECT_EQ("^y", node.input(2)); } else if (node.name() == "z") { count++; EXPECT_EQ("Add", node.op()); - EXPECT_EQ(device, node.device()); EXPECT_EQ(3, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); @@ -100,6 +93,75 @@ TEST_F(DebugStripperTest, StripAssertFromGraph) { } } EXPECT_EQ(5, count); + + Tensor x_t(DT_FLOAT, TensorShape({})); + Tensor y_t(DT_FLOAT, TensorShape({})); + x_t.flat<float>()(0) = 1.0f; + y_t.flat<float>()(0) = 0.5f; + std::vector<Tensor> expected = + EvaluateNodes(item.graph, {"z"}, {{"x", x_t}, {"y", y_t}}); + std::vector<Tensor> optimized = + EvaluateNodes(output, {"z"}, {{"x", x_t}, {"y", y_t}}); + test::ExpectTensorEqual<float>(expected[0], optimized[0]); +} + +TEST_F(DebugStripperTest, StripCheckNumericsFromGraph) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape({})); + Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, + ops::Placeholder::Shape({})); + auto check1 = ops::CheckNumerics(s.WithOpName("CheckNumerics1"), x, "foo"); + auto check2 = ops::CheckNumerics(s.WithOpName("CheckNumerics2"), y, "foo"); + Output add = ops::Add(s.WithOpName("z"), check1, check2); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + DebugStripper optimizer; + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + int count = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "x") { + count++; + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "y") { + count++; + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "CheckNumerics1") { + count++; + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ(1, node.attr_size()); + } else if (node.name() == "CheckNumerics2") { + count++; + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ(1, node.attr_size()); + } else if (node.name() == "z") { + count++; + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("CheckNumerics1", node.input(0)); + EXPECT_EQ("CheckNumerics2", node.input(1)); + } + } + EXPECT_EQ(5, count); + + Tensor x_t(DT_FLOAT, TensorShape({})); + Tensor y_t(DT_FLOAT, TensorShape({})); + x_t.flat<float>()(0) = 1.0f; + y_t.flat<float>()(0) = 0.5f; + std::vector<Tensor> expected = + EvaluateNodes(item.graph, {"z"}, {{"x", x_t}, {"y", y_t}}); + std::vector<Tensor> optimized = + EvaluateNodes(output, {"z"}, {{"x", x_t}, {"y", y_t}}); + test::ExpectTensorEqual<float>(expected[0], optimized[0]); } } // namespace |