aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/profiling
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-21 13:48:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 13:51:07 -0700
commit4173a7ac400f95ce128b83bded2db2742beb60c8 (patch)
treedd274db7097a16749c0d1ca930ea4809507a9c1a /tensorflow/contrib/lite/profiling
parent8d9ff7f792267bed942684091da215a84eae8065 (diff)
Allow ops to annotate their own profile info.
PiperOrigin-RevId: 201579919
Diffstat (limited to 'tensorflow/contrib/lite/profiling')
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer.cc5
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer_test.cc50
2 files changed, 50 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc
index 45388b500c..c37a096588 100644
--- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc
@@ -78,8 +78,13 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
} else {
op_name = tflite::EnumNamesBuiltinOperator()[code];
}
+ const char* profiling_string =
+ interpreter.OpProfilingString(node_reg->second, &node_reg->first);
OperatorDetails details;
details.name = op_name;
+ if (profiling_string) {
+ details.name += ":" + string(profiling_string);
+ }
details.inputs = GetTensorNames(interpreter, inputs);
details.outputs = GetTensorNames(interpreter, outputs);
return details;
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
index 35cf780713..67a5eecfa0 100644
--- a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
@@ -31,6 +31,7 @@ namespace profiling {
namespace {
+#ifdef TFLITE_PROFILING_ENABLED
TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0);
const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1);
@@ -42,20 +43,35 @@ TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+const char* SimpleOpProfilingString(const TfLiteContext* context,
+ const TfLiteNode* node) {
+ return "Profile";
+}
+
TfLiteRegistration* RegisterSimpleOp() {
+ static TfLiteRegistration registration = {
+ nullptr, nullptr, nullptr,
+ SimpleOpEval, nullptr, tflite::BuiltinOperator_CUSTOM,
+ "SimpleOpEval", 1};
+ return &registration;
+}
+
+TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() {
static TfLiteRegistration registration = {nullptr,
nullptr,
nullptr,
SimpleOpEval,
+ SimpleOpProfilingString,
tflite::BuiltinOperator_CUSTOM,
"SimpleOpEval",
1};
return &registration;
}
+#endif
class SimpleOpModel : public SingleOpModel {
public:
- void Init();
+ void Init(const std::function<TfLiteRegistration*()>& registration);
tflite::Interpreter* GetInterpreter() { return interpreter_.get(); }
void SetInputs(int32_t x, int32_t y) {
PopulateTensor(inputs_[0], {x});
@@ -68,11 +84,12 @@ class SimpleOpModel : public SingleOpModel {
int output_;
};
-void SimpleOpModel::Init() {
+void SimpleOpModel::Init(
+ const std::function<TfLiteRegistration*()>& registration) {
inputs_[0] = AddInput({TensorType_INT32, {1}});
inputs_[1] = AddInput({TensorType_INT32, {1}});
output_ = AddOutput({TensorType_INT32, {}});
- SetCustomOp("SimpleAdd", {}, RegisterSimpleOp);
+ SetCustomOp("SimpleAdd", {}, registration);
BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])});
}
@@ -86,7 +103,28 @@ TEST(ProfileSummarizerTest, Empty) {
TEST(ProfileSummarizerTest, Interpreter) {
Profiler profiler;
SimpleOpModel m;
- m.Init();
+ m.Init(RegisterSimpleOp);
+ auto interpreter = m.GetInterpreter();
+ interpreter->SetProfiler(&profiler);
+ profiler.StartProfiling();
+ m.SetInputs(1, 2);
+ m.Invoke();
+ // 3 = 1 + 2
+ EXPECT_EQ(m.GetOutput(), 3);
+ profiler.StopProfiling();
+ ProfileSummarizer summarizer;
+ auto events = profiler.GetProfileEvents();
+ EXPECT_EQ(1, events.size());
+ summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter);
+ auto output = summarizer.GetOutputString();
+ // TODO(shashishekhar): Add a better test here.
+ ASSERT_TRUE(output.find("SimpleOpEval") != std::string::npos) << output;
+}
+
+TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) {
+ Profiler profiler;
+ SimpleOpModel m;
+ m.Init(RegisterSimpleOpWithProfilingDetails);
auto interpreter = m.GetInterpreter();
interpreter->SetProfiler(&profiler);
profiler.StartProfiling();
@@ -101,8 +139,10 @@ TEST(ProfileSummarizerTest, Interpreter) {
summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter);
auto output = summarizer.GetOutputString();
// TODO(shashishekhar): Add a better test here.
- ASSERT_TRUE(output.find("SimpleOp") != std::string::npos) << output;
+ ASSERT_TRUE(output.find("SimpleOpEval:Profile") != std::string::npos)
+ << output;
}
+
#endif
} // namespace