aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/profiling
diff options
context:
space:
mode:
authorGravatar Shashi Shekhar <shashishekhar@google.com>2018-05-23 17:14:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-23 17:17:17 -0700
commit2307db76a2a07c7af6581e0ef4c6a5a0b83921f4 (patch)
treea056eb11e2a8698dd0a5c8eb6aa3587c0ec71ca7 /tensorflow/contrib/lite/profiling
parentdac1f124020234fe24e8893a981b15395d0c6de8 (diff)
Refactor StatSummarizer extract common functionality without proto dependencies.
PiperOrigin-RevId: 197816405
Diffstat (limited to 'tensorflow/contrib/lite/profiling')
-rw-r--r--tensorflow/contrib/lite/profiling/BUILD27
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer.cc140
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer.h58
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer_test.cc116
4 files changed, 341 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD
index 15999e5d41..c86be65ca7 100644
--- a/tensorflow/contrib/lite/profiling/BUILD
+++ b/tensorflow/contrib/lite/profiling/BUILD
@@ -31,6 +31,33 @@ cc_library(
copts = common_copts,
)
+cc_library(
+ name = "profile_summarizer",
+ srcs = ["profile_summarizer.cc"],
+ hdrs = ["profile_summarizer.h"],
+ deps = [
+ ":profiler",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/core:stats_calculator_portable",
+ ],
+)
+
+cc_test(
+ name = "profile_summarizer_test",
+ srcs = ["profile_summarizer_test.cc"],
+ deps = [
+ ":profile_summarizer",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/kernels:kernel_util",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
cc_test(
name = "profile_buffer_test",
srcs = ["profile_buffer_test.cc"],
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc
new file mode 100644
index 0000000000..788f6922d2
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
+
+#include <sstream>
+
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace profiling {
+namespace {
+
+using Detail = tensorflow::StatsCalculator::Detail;
+
+struct OperatorDetails {
+ string name;
+ std::vector<string> inputs;
+ std::vector<string> outputs;
+};
+
+string GetTensorName(const tflite::Interpreter& interpreter, int tensor_index) {
+ const auto tensor = interpreter.tensor(tensor_index);
+ if (tensor == nullptr || tensor->name == nullptr) {
+ return "Unknown";
+ }
+ return tensor->name;
+}
+std::vector<string> GetTensorNames(const tflite::Interpreter& interpreter,
+ const TfLiteIntArray* tensor_indices) {
+ std::vector<string> tensors;
+ tensors.reserve(tensor_indices->size);
+ for (int i = 0; i < tensor_indices->size; i++) {
+ tensors.push_back(GetTensorName(interpreter, tensor_indices->data[i]));
+ }
+ return tensors;
+}
+
+string ToString(const std::vector<string>& str_vector) {
+ std::stringstream stream;
+ stream << "[";
+ bool first = true;
+ for (const auto& s : str_vector) {
+ if (!first) {
+ stream << ", ";
+ } else {
+ first = false;
+ }
+ stream << s;
+ }
+ stream << "]";
+ return stream.str();
+}
+
+OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
+ int node_index) {
+ auto node_reg = interpreter.node_and_registration(node_index);
+ auto inputs = node_reg->first.inputs;
+ auto outputs = node_reg->first.outputs;
+ int code = node_reg->second.builtin_code;
+ const char* op_name = nullptr;
+ if (code == tflite::BuiltinOperator_CUSTOM) {
+ const char* custom_name = node_reg->second.custom_name;
+ op_name = custom_name ? custom_name : "UnknownCustomOp";
+ } else {
+ op_name = tflite::EnumNamesBuiltinOperator()[code];
+ }
+ OperatorDetails details;
+ details.name = op_name;
+ details.inputs = GetTensorNames(interpreter, inputs);
+ details.outputs = GetTensorNames(interpreter, outputs);
+ return details;
+}
+
+} // namespace
+
+ProfileSummarizer::ProfileSummarizer()
+ : stats_calculator_(new ::tensorflow::StatsCalculator(
+ tensorflow::StatSummarizerOptions())) {}
+
+void ProfileSummarizer::ProcessProfiles(
+ const std::vector<const ProfileEvent*>& profile_stats,
+ const tflite::Interpreter& interpreter) {
+ std::vector<const ProfileEvent*> events;
+ std::copy_if(profile_stats.begin(), profile_stats.end(),
+ std::back_inserter(events), [](const ProfileEvent* e) {
+ return e->event_type ==
+ ProfileEvent::EventType::OPERATOR_INVOKE_EVENT &&
+ e->end_timestamp_us >= e->begin_timestamp_us;
+ });
+ // Sort with begin_time.
+ std::sort(events.begin(), events.end(),
+ [](const ProfileEvent* const& a, const ProfileEvent* const& b) {
+ return a->begin_timestamp_us < b->begin_timestamp_us;
+ });
+ if (events.empty()) {
+ return;
+ }
+
+ int64_t base_start_us = events[0]->begin_timestamp_us;
+ int node_num = 0;
+ int64_t curr_total_us = 0;
+ std::map<std::string, Detail> details;
+ for (auto event : events) {
+ auto op_details = GetOperatorDetails(interpreter, event->event_metadata);
+ auto node_name = ToString(op_details.outputs);
+ auto result = details.emplace(node_name, Detail());
+ Detail* detail = &(result.first->second);
+ detail->start_us.UpdateStat(event->begin_timestamp_us - base_start_us);
+ int64_t node_exec_time =
+ event->end_timestamp_us - event->begin_timestamp_us;
+ detail->rel_end_us.UpdateStat(node_exec_time);
+ curr_total_us += node_exec_time;
+ ++node_num;
+
+ if (result.second) {
+ detail->name = node_name;
+ detail->type = op_details.name;
+ detail->run_order = node_num;
+ detail->times_called = 0;
+ }
+ ++detail->times_called;
+ }
+ stats_calculator_->UpdateDetails(details);
+ stats_calculator_->UpdateRunTotalUs(curr_total_us);
+}
+} // namespace profiling
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.h b/tensorflow/contrib/lite/profiling/profile_summarizer.h
new file mode 100644
index 0000000000..6fe6ca04f5
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
+#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/profiling/profiler.h"
+#include "tensorflow/core/util/stats_calculator.h"
+
+namespace tflite {
+namespace profiling {
+
+// Creates a summary of operator invocations in the interpreter.
+class ProfileSummarizer {
+ public:
+ ProfileSummarizer();
+ virtual ~ProfileSummarizer() {}
+
+ // Process profile events to update statistics for operator invocations.
+ void ProcessProfiles(const std::vector<const ProfileEvent*>& profile_stats,
+ const tflite::Interpreter& interpreter);
+
+ // Returns a string detailing the accumulated runtime stats in a tab-separated
+ // format which can be pasted into a spreadsheet for further analysis.
+ std::string GetOutputString() const {
+ return stats_calculator_->GetOutputString();
+ }
+
+ std::string GetShortSummary() const {
+ return stats_calculator_->GetShortSummary();
+ }
+
+ // Prints the string returned by GetOutputString().
+ void PrintStepStats() const { stats_calculator_->PrintStepStats(); }
+
+ private:
+ std::unique_ptr<tensorflow::StatsCalculator> stats_calculator_;
+};
+
+} // namespace profiling
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
new file mode 100644
index 0000000000..35cf780713
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
@@ -0,0 +1,116 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+namespace profiling {
+
+namespace {
+
+TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0);
+ const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1);
+
+ TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
+
+ int32_t* output_data = output->data.i32;
+ *output_data = *(input1->data.i32) + *(input2->data.i32);
+ return kTfLiteOk;
+}
+
+TfLiteRegistration* RegisterSimpleOp() {
+ static TfLiteRegistration registration = {nullptr,
+ nullptr,
+ nullptr,
+ SimpleOpEval,
+ tflite::BuiltinOperator_CUSTOM,
+ "SimpleOpEval",
+ 1};
+ return &registration;
+}
+
+class SimpleOpModel : public SingleOpModel {
+ public:
+ void Init();
+ tflite::Interpreter* GetInterpreter() { return interpreter_.get(); }
+ void SetInputs(int32_t x, int32_t y) {
+ PopulateTensor(inputs_[0], {x});
+ PopulateTensor(inputs_[1], {y});
+ }
+ int32_t GetOutput() { return ExtractVector<int32_t>(output_)[0]; }
+
+ private:
+ int inputs_[2];
+ int output_;
+};
+
+void SimpleOpModel::Init() {
+ inputs_[0] = AddInput({TensorType_INT32, {1}});
+ inputs_[1] = AddInput({TensorType_INT32, {1}});
+ output_ = AddOutput({TensorType_INT32, {}});
+ SetCustomOp("SimpleAdd", {}, RegisterSimpleOp);
+ BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])});
+}
+
+TEST(ProfileSummarizerTest, Empty) {
+ ProfileSummarizer summarizer;
+ std::string output = summarizer.GetOutputString();
+ EXPECT_GT(output.size(), 0);
+}
+
+#ifdef TFLITE_PROFILING_ENABLED
+TEST(ProfileSummarizerTest, Interpreter) {
+ Profiler profiler;
+ SimpleOpModel m;
+ m.Init();
+ auto interpreter = m.GetInterpreter();
+ interpreter->SetProfiler(&profiler);
+ profiler.StartProfiling();
+ m.SetInputs(1, 2);
+ m.Invoke();
+ // 3 = 1 + 2
+ EXPECT_EQ(m.GetOutput(), 3);
+ profiler.StopProfiling();
+ ProfileSummarizer summarizer;
+ auto events = profiler.GetProfileEvents();
+ EXPECT_EQ(1, events.size());
+ summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter);
+ auto output = summarizer.GetOutputString();
+ // TODO(shashishekhar): Add a better test here.
+ ASSERT_TRUE(output.find("SimpleOp") != std::string::npos) << output;
+}
+#endif
+
+} // namespace
+} // namespace profiling
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}