aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/benchmark
diff options
context:
space:
mode:
authorGravatar Andrew Harp <andrewharp@google.com>2016-06-13 17:25:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-13 18:33:26 -0700
commitace8e1b60847fe7771176fec2c8cb9c19a54ea03 (patch)
tree6fbb398b047cab3a82e2a35dbf5a625bbf244aa1 /tensorflow/tools/benchmark
parent36814258fe915d8e8617d8a9768b3166853021ce (diff)
Allow benchmark_model to optionally output benchmark proto.
Change: 124793729
Diffstat (limited to 'tensorflow/tools/benchmark')
-rw-r--r--tensorflow/tools/benchmark/BUILD4
-rw-r--r--tensorflow/tools/benchmark/benchmark_model.cc28
2 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/tools/benchmark/BUILD b/tensorflow/tools/benchmark/BUILD
index e926a45673..5a7981e112 100644
--- a/tensorflow/tools/benchmark/BUILD
+++ b/tensorflow/tools/benchmark/BUILD
@@ -15,6 +15,7 @@ exports_files(["LICENSE"])
cc_library(
name = "benchmark_model_lib",
+ testonly = 1,
srcs = [
"benchmark_model.cc",
],
@@ -26,6 +27,7 @@ cc_library(
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
@@ -34,6 +36,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
],
}),
)
@@ -66,6 +69,7 @@ tf_cc_test(
# uses.
cc_binary(
name = "benchmark_model",
+ testonly = 1,
srcs = ["benchmark_model_main.cc"],
copts = tf_copts(),
linkopts = select({
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc
index cc6e50f43f..cff30453a9 100644
--- a/tensorflow/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/tools/benchmark/benchmark_model.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
+#include "tensorflow/core/util/reporter.h"
#include "tensorflow/core/util/stat_summarizer.h"
namespace tensorflow {
@@ -167,6 +168,7 @@ int Main(int argc, char** argv) {
int num_runs = 50;
string run_delay = "-1.0";
int num_threads = -1;
+ string benchmark_name = "";
const bool parse_result = ParseFlags(
&argc, argv, {
@@ -178,6 +180,7 @@ int Main(int argc, char** argv) {
Flag("num_runs", &num_runs), //
Flag("run_delay", &run_delay), //
Flag("num_threads", &num_threads), //
+ Flag("benchmark_name", &benchmark_name), //
});
if (!parse_result) {
@@ -199,6 +202,7 @@ int Main(int argc, char** argv) {
LOG(INFO) << "Num runs: [" << num_runs << "]";
LOG(INFO) << "Inter-run delay (seconds): [" << run_delay << "]";
LOG(INFO) << "Num threads: [" << num_threads << "]";
+ LOG(INFO) << "Benchmark name: [" << benchmark_name << "]";
std::unique_ptr<Session> session;
std::unique_ptr<StatSummarizer> stats;
@@ -219,9 +223,14 @@ int Main(int argc, char** argv) {
for (int i = 0; i < sizes.size(); ++i) {
input_shape.AddDim(sizes[i]);
}
+
+ const int64 start_time = Env::Default()->NowMicros();
Status time_status =
TimeMultipleRuns(sleep_seconds, num_runs, input_data_type, input_shape,
input_layer, output_layer, session.get(), stats.get());
+ const int64 end_time = Env::Default()->NowMicros();
+ const double wall_time = (end_time - start_time) / 1000000.0;
+
if (!time_status.ok()) {
LOG(ERROR) << "Timing failed with " << time_status;
return -1;
@@ -229,6 +238,25 @@ int Main(int argc, char** argv) {
stats->PrintStepStats();
+ if (!benchmark_name.empty()) {
+ // Compute the total number of values per input.
+ int64 total_size = 1;
+ for (int32 size : sizes) {
+ total_size *= size;
+ }
+
+ // Throughput in MB/s
+ const double throughput = DataTypeSize(input_data_type) * total_size *
+ num_runs / static_cast<double>(wall_time) /
+ (1024 * 1024);
+
+ // Report the stats.
+ TestReporter reporter(benchmark_name);
+ reporter.Initialize();
+ reporter.Benchmark(num_runs, -1.0, wall_time, throughput);
+ reporter.Close();
+ }
+
return 0;
}