aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc')
-rw-r--r--tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc46
1 files changed, 31 insertions, 15 deletions
diff --git a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc b/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc
index b41d0770dc..3b40253954 100644
--- a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc
+++ b/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc
@@ -29,15 +29,16 @@ class TFProfAdvisorTest : public ::testing::Test {
nullptr, nullptr));
stats_->AddNodeForTest(
- "n1", CreateNode("n1", "Conv2D", {{"data_format", "NHWC"}}, 10, 2));
- stats_->AddNodeForTest("n2", CreateNode("n2", "Conv2D", {}, 20, 2));
+ 0, CreateNode("n1", "Conv2D", {{"data_format", "NHWC"}}, 0, 10, 2));
+ stats_->AddNodeForTest(0, CreateNode("n2", "Conv2D", {}, 0, 20, 2));
+ stats_->BuildAllViews();
advisor_.reset(new Advisor(stats_.get()));
}
std::unique_ptr<TFGraphNode> CreateNode(const string& name,
const string& type,
std::map<string, string> attrs,
- int64 start_miros,
+ int64 step, int64 start_miros,
int64 end_rel_micros) {
node_defs_.push_back(std::unique_ptr<NodeDef>(new NodeDef()));
NodeDef* def = node_defs_.back().get();
@@ -52,10 +53,10 @@ class TFProfAdvisorTest : public ::testing::Test {
NodeExecStats node_stat;
node_stat.set_all_start_micros(start_miros);
node_stat.set_op_end_rel_micros(end_rel_micros);
- node->AddStepStat(0, "/job:localhost/replica:0/task:0/gpu:0", node_stat);
- node->AddStepStat(0, "/job:localhost/replica:0/task:0/gpu:0:stream:all",
+ node->AddStepStat(step, "/job:localhost/replica:0/task:0/gpu:0", node_stat);
+ node->AddStepStat(step, "/job:localhost/replica:0/task:0/gpu:0:stream:all",
node_stat);
- node->AddStepStat(0, "/job:localhost/replica:0/task:0/gpu:0:stream:0",
+ node->AddStepStat(step, "/job:localhost/replica:0/task:0/gpu:0:stream:0",
node_stat);
return node;
}
@@ -66,23 +67,38 @@ class TFProfAdvisorTest : public ::testing::Test {
};
TEST_F(TFProfAdvisorTest, Basics) {
- std::map<string, std::vector<string>> reports = advisor_->Advise();
- EXPECT_TRUE(reports.find("AcceleratorUtilizationChecker") != reports.end());
- EXPECT_TRUE(reports.find("OperationChecker") != reports.end());
+ AdvisorOptionsProto options = Advisor::DefaultOptions();
+ AdviceProto advice = advisor_->Advise(options);
+ EXPECT_TRUE(advice.checkers().find(kCheckers[0]) != advice.checkers().end());
+ EXPECT_TRUE(advice.checkers().find(kCheckers[1]) != advice.checkers().end());
+ EXPECT_TRUE(advice.checkers().find(kCheckers[2]) != advice.checkers().end());
}
TEST_F(TFProfAdvisorTest, OperationChecker) {
- std::map<string, std::vector<string>> reports = advisor_->Advise();
- EXPECT_EQ(reports["OperationChecker"].size(), 1);
- EXPECT_TRUE(StringPiece(reports["OperationChecker"][0]).contains("NCHW"));
+ AdvisorOptionsProto options;
+ (*options.mutable_checkers())[kCheckers[1]];
+ AdviceProto advice = advisor_->Advise(options);
+ EXPECT_EQ(advice.checkers().at(kCheckers[1]).reports_size(), 1);
+ EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[1]).reports(0))
+ .contains("NCHW"));
}
TEST_F(TFProfAdvisorTest, UtilizationChecker) {
- std::map<string, std::vector<string>> reports = advisor_->Advise();
- EXPECT_EQ(reports["AcceleratorUtilizationChecker"].size(), 1);
- EXPECT_TRUE(StringPiece(reports["AcceleratorUtilizationChecker"][0])
+ AdvisorOptionsProto options;
+ (*options.mutable_checkers())[kCheckers[0]];
+ AdviceProto advice = advisor_->Advise(options);
+ EXPECT_EQ(advice.checkers().at(kCheckers[0]).reports_size(), 1);
+ EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[0]).reports(0))
.contains("low utilization"));
}
+TEST_F(TFProfAdvisorTest, ExpensiveOperationChecker) {
+ AdvisorOptionsProto options;
+ (*options.mutable_checkers())[kCheckers[2]];
+ AdviceProto advice = advisor_->Advise(options);
+ EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[2]).reports(0))
+ .contains("top 1 operation type: Conv2D"));
+}
+
} // namespace tfprof
} // namespace tensorflow