diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-21 13:48:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-21 13:51:07 -0700 |
commit | 4173a7ac400f95ce128b83bded2db2742beb60c8 (patch) | |
tree | dd274db7097a16749c0d1ca930ea4809507a9c1a /tensorflow/contrib/lite/profiling | |
parent | 8d9ff7f792267bed942684091da215a84eae8065 (diff) |
Allow ops to annotate their own profile info.
PiperOrigin-RevId: 201579919
Diffstat (limited to 'tensorflow/contrib/lite/profiling')
-rw-r--r-- | tensorflow/contrib/lite/profiling/profile_summarizer.cc | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/profiling/profile_summarizer_test.cc | 50 |
2 files changed, 50 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc index 45388b500c..c37a096588 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc @@ -78,8 +78,13 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, } else { op_name = tflite::EnumNamesBuiltinOperator()[code]; } + const char* profiling_string = + interpreter.OpProfilingString(node_reg->second, &node_reg->first); OperatorDetails details; details.name = op_name; + if (profiling_string) { + details.name += ":" + string(profiling_string); + } details.inputs = GetTensorNames(interpreter, inputs); details.outputs = GetTensorNames(interpreter, outputs); return details; diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc index 35cf780713..67a5eecfa0 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc +++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc @@ -31,6 +31,7 @@ namespace profiling { namespace { +#ifdef TFLITE_PROFILING_ENABLED TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0); const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1); @@ -42,20 +43,35 @@ TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +const char* SimpleOpProfilingString(const TfLiteContext* context, + const TfLiteNode* node) { + return "Profile"; +} + TfLiteRegistration* RegisterSimpleOp() { + static TfLiteRegistration registration = { + nullptr, nullptr, nullptr, + SimpleOpEval, nullptr, tflite::BuiltinOperator_CUSTOM, + "SimpleOpEval", 1}; + return ®istration; +} + +TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() { static TfLiteRegistration registration = {nullptr, nullptr, nullptr, SimpleOpEval, + SimpleOpProfilingString, tflite::BuiltinOperator_CUSTOM, "SimpleOpEval", 1}; return ®istration; } +#endif class SimpleOpModel : public SingleOpModel { public: - void Init(); + void Init(const std::function<TfLiteRegistration*()>& registration); tflite::Interpreter* GetInterpreter() { return interpreter_.get(); } void SetInputs(int32_t x, int32_t y) { PopulateTensor(inputs_[0], {x}); @@ -68,11 +84,12 @@ class SimpleOpModel : public SingleOpModel { int output_; }; -void SimpleOpModel::Init() { +void SimpleOpModel::Init( + const std::function<TfLiteRegistration*()>& registration) { inputs_[0] = AddInput({TensorType_INT32, {1}}); inputs_[1] = AddInput({TensorType_INT32, {1}}); output_ = AddOutput({TensorType_INT32, {}}); - SetCustomOp("SimpleAdd", {}, RegisterSimpleOp); + SetCustomOp("SimpleAdd", {}, registration); BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])}); } @@ -86,7 +103,28 @@ TEST(ProfileSummarizerTest, Empty) { TEST(ProfileSummarizerTest, Interpreter) { Profiler profiler; SimpleOpModel m; - m.Init(); + m.Init(RegisterSimpleOp); + auto interpreter = m.GetInterpreter(); + interpreter->SetProfiler(&profiler); + profiler.StartProfiling(); + m.SetInputs(1, 2); + m.Invoke(); + // 3 = 1 + 2 + EXPECT_EQ(m.GetOutput(), 3); + profiler.StopProfiling(); + ProfileSummarizer summarizer; + auto events = profiler.GetProfileEvents(); + EXPECT_EQ(1, events.size()); + summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); + auto output = summarizer.GetOutputString(); + // TODO(shashishekhar): Add a better test here. + ASSERT_TRUE(output.find("SimpleOpEval") != std::string::npos) << output; +} + +TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) { + Profiler profiler; + SimpleOpModel m; + m.Init(RegisterSimpleOpWithProfilingDetails); auto interpreter = m.GetInterpreter(); interpreter->SetProfiler(&profiler); profiler.StartProfiling(); @@ -101,8 +139,10 @@ TEST(ProfileSummarizerTest, Interpreter) { summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); auto output = summarizer.GetOutputString(); // TODO(shashishekhar): Add a better test here. - ASSERT_TRUE(output.find("SimpleOp") != std::string::npos) << output; + ASSERT_TRUE(output.find("SimpleOpEval:Profile") != std::string::npos) + << output; } + #endif } // namespace |