diff options
Diffstat (limited to 'tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc')
-rw-r--r-- | tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc | 46 |
1 files changed, 31 insertions, 15 deletions
diff --git a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc b/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc index b41d0770dc..3b40253954 100644 --- a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc +++ b/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc @@ -29,15 +29,16 @@ class TFProfAdvisorTest : public ::testing::Test { nullptr, nullptr)); stats_->AddNodeForTest( - "n1", CreateNode("n1", "Conv2D", {{"data_format", "NHWC"}}, 10, 2)); - stats_->AddNodeForTest("n2", CreateNode("n2", "Conv2D", {}, 20, 2)); + 0, CreateNode("n1", "Conv2D", {{"data_format", "NHWC"}}, 0, 10, 2)); + stats_->AddNodeForTest(0, CreateNode("n2", "Conv2D", {}, 0, 20, 2)); + stats_->BuildAllViews(); advisor_.reset(new Advisor(stats_.get())); } std::unique_ptr<TFGraphNode> CreateNode(const string& name, const string& type, std::map<string, string> attrs, - int64 start_miros, + int64 step, int64 start_miros, int64 end_rel_micros) { node_defs_.push_back(std::unique_ptr<NodeDef>(new NodeDef())); NodeDef* def = node_defs_.back().get(); @@ -52,10 +53,10 @@ class TFProfAdvisorTest : public ::testing::Test { NodeExecStats node_stat; node_stat.set_all_start_micros(start_miros); node_stat.set_op_end_rel_micros(end_rel_micros); - node->AddStepStat(0, "/job:localhost/replica:0/task:0/gpu:0", node_stat); - node->AddStepStat(0, "/job:localhost/replica:0/task:0/gpu:0:stream:all", + node->AddStepStat(step, "/job:localhost/replica:0/task:0/gpu:0", node_stat); + node->AddStepStat(step, "/job:localhost/replica:0/task:0/gpu:0:stream:all", node_stat); - node->AddStepStat(0, "/job:localhost/replica:0/task:0/gpu:0:stream:0", + node->AddStepStat(step, "/job:localhost/replica:0/task:0/gpu:0:stream:0", node_stat); return node; } @@ -66,23 +67,38 @@ class TFProfAdvisorTest : public ::testing::Test { }; TEST_F(TFProfAdvisorTest, Basics) { - std::map<string, std::vector<string>> reports = advisor_->Advise(); - EXPECT_TRUE(reports.find("AcceleratorUtilizationChecker") != reports.end()); - EXPECT_TRUE(reports.find("OperationChecker") != reports.end()); + AdvisorOptionsProto options = Advisor::DefaultOptions(); + AdviceProto advice = advisor_->Advise(options); + EXPECT_TRUE(advice.checkers().find(kCheckers[0]) != advice.checkers().end()); + EXPECT_TRUE(advice.checkers().find(kCheckers[1]) != advice.checkers().end()); + EXPECT_TRUE(advice.checkers().find(kCheckers[2]) != advice.checkers().end()); } TEST_F(TFProfAdvisorTest, OperationChecker) { - std::map<string, std::vector<string>> reports = advisor_->Advise(); - EXPECT_EQ(reports["OperationChecker"].size(), 1); - EXPECT_TRUE(StringPiece(reports["OperationChecker"][0]).contains("NCHW")); + AdvisorOptionsProto options; + (*options.mutable_checkers())[kCheckers[1]]; + AdviceProto advice = advisor_->Advise(options); + EXPECT_EQ(advice.checkers().at(kCheckers[1]).reports_size(), 1); + EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[1]).reports(0)) + .contains("NCHW")); } TEST_F(TFProfAdvisorTest, UtilizationChecker) { - std::map<string, std::vector<string>> reports = advisor_->Advise(); - EXPECT_EQ(reports["AcceleratorUtilizationChecker"].size(), 1); - EXPECT_TRUE(StringPiece(reports["AcceleratorUtilizationChecker"][0]) + AdvisorOptionsProto options; + (*options.mutable_checkers())[kCheckers[0]]; + AdviceProto advice = advisor_->Advise(options); + EXPECT_EQ(advice.checkers().at(kCheckers[0]).reports_size(), 1); + EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[0]).reports(0)) .contains("low utilization")); } +TEST_F(TFProfAdvisorTest, ExpensiveOperationChecker) { + AdvisorOptionsProto options; + (*options.mutable_checkers())[kCheckers[2]]; + AdviceProto advice = advisor_->Advise(options); + EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[2]).reports(0)) + .contains("top 1 operation type: Conv2D")); +} + } // namespace tfprof } // namespace tensorflow |