aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/benchmark
diff options
context:
space:
mode:
authorGravatar Shashi Shekhar <shashishekhar@google.com>2018-05-05 11:55:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 15:45:26 -0700
commit62ed0aa37099e07720880a72a285304d34512cba (patch)
treed6553662652f7abfc81ae4cac2b89a92d8c908c2 /tensorflow/tools/benchmark
parentf3c21911bca9c1ef01560dfd7609020d7f85f52b (diff)
Allow benchmark model graph to be specified in text proto format.
PiperOrigin-RevId: 195547670
Diffstat (limited to 'tensorflow/tools/benchmark')
-rw-r--r--tensorflow/tools/benchmark/benchmark_model.cc4
-rw-r--r--tensorflow/tools/benchmark/benchmark_model_test.cc55
2 files changed, 47 insertions, 12 deletions
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc
index 15523028c7..eeb1fab40c 100644
--- a/tensorflow/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/tools/benchmark/benchmark_model.cc
@@ -262,6 +262,10 @@ Status InitializeSession(int num_threads, const string& graph,
tensorflow::GraphDef tensorflow_graph;
Status s = ReadBinaryProto(Env::Default(), graph, graph_def->get());
if (!s.ok()) {
+ s = ReadTextProto(Env::Default(), graph, graph_def->get());
+ }
+
+ if (!s.ok()) {
LOG(ERROR) << "Could not create TensorFlow Graph: " << s;
return s;
}
diff --git a/tensorflow/tools/benchmark/benchmark_model_test.cc b/tensorflow/tools/benchmark/benchmark_model_test.cc
index 16ab2ff66e..6813045d63 100644
--- a/tensorflow/tools/benchmark/benchmark_model_test.cc
+++ b/tensorflow/tools/benchmark/benchmark_model_test.cc
@@ -26,30 +26,36 @@ limitations under the License.
namespace tensorflow {
namespace {
-TEST(BenchmarkModelTest, InitializeAndRun) {
- const string dir = testing::TmpDir();
- const string filename_pb = io::JoinPath(dir, "graphdef.pb");
-
+void CreateTestGraph(const ::tensorflow::Scope& root,
+ benchmark_model::InputLayerInfo* input,
+ string* output_name, GraphDef* graph_def) {
// Create a simple graph and write it to filename_pb.
const int input_width = 400;
const int input_height = 10;
- benchmark_model::InputLayerInfo input;
- input.shape = TensorShape({input_width, input_height});
- input.data_type = DT_FLOAT;
+ input->shape = TensorShape({input_width, input_height});
+ input->data_type = DT_FLOAT;
const TensorShape constant_shape({input_height, input_width});
Tensor constant_tensor(DT_FLOAT, constant_shape);
test::FillFn<float>(&constant_tensor, [](int) -> float { return 3.0; });
- auto root = Scope::NewRootScope().ExitOnError();
auto placeholder =
- ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input.shape));
- input.name = placeholder.node()->name();
+ ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input->shape));
+ input->name = placeholder.node()->name();
auto m = ops::MatMul(root, placeholder, constant_tensor);
- const string output_name = m.node()->name();
+ *output_name = m.node()->name();
+ TF_ASSERT_OK(root.ToGraphDef(graph_def));
+}
+
+TEST(BenchmarkModelTest, InitializeAndRun) {
+ const string dir = testing::TmpDir();
+ const string filename_pb = io::JoinPath(dir, "graphdef.pb");
+ auto root = Scope::NewRootScope().ExitOnError();
+ benchmark_model::InputLayerInfo input;
+ string output_name;
GraphDef graph_def;
- TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ CreateTestGraph(root, &input, &output_name, &graph_def);
string graph_def_serialized;
graph_def.SerializeToString(&graph_def_serialized);
TF_ASSERT_OK(
@@ -69,5 +75,30 @@ TEST(BenchmarkModelTest, InitializeAndRun) {
ASSERT_EQ(num_runs, 10);
}
+TEST(BenchmarkModeTest, TextProto) {
+ const string dir = testing::TmpDir();
+ const string filename_txt = io::JoinPath(dir, "graphdef.pb.txt");
+ auto root = Scope::NewRootScope().ExitOnError();
+
+ benchmark_model::InputLayerInfo input;
+ string output_name;
+ GraphDef graph_def;
+ CreateTestGraph(root, &input, &output_name, &graph_def);
+ TF_ASSERT_OK(WriteTextProto(Env::Default(), filename_txt, graph_def));
+
+ std::unique_ptr<Session> session;
+ std::unique_ptr<GraphDef> loaded_graph_def;
+ TF_ASSERT_OK(benchmark_model::InitializeSession(1, filename_txt, &session,
+ &loaded_graph_def));
+ std::unique_ptr<StatSummarizer> stats;
+ stats.reset(new tensorflow::StatSummarizer(*(loaded_graph_def.get())));
+ int64 time;
+ int64 num_runs = 0;
+ TF_ASSERT_OK(benchmark_model::TimeMultipleRuns(
+ 0.0, 10, 0.0, {input}, {output_name}, {}, session.get(), stats.get(),
+ &time, &num_runs));
+ ASSERT_EQ(num_runs, 10);
+}
+
} // namespace
} // namespace tensorflow