aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Andrew Harp <andrewharp@google.com>2016-11-11 16:43:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-11 17:02:51 -0800
commit1743ad893859ea342c3c6b04a5292b763ae4aead (patch)
treeb02f61f10035aa0004e338084fe2841dcbbc9b18
parentf27dea2f016baf09040bf5aec705511486a3f205 (diff)
Add python wrapper for StatSummarizer.
Change: 138933733
-rw-r--r--tensorflow/BUILD1
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/__init__.py1
-rw-r--r--tensorflow/contrib/stat_summarizer/BUILD38
-rw-r--r--tensorflow/contrib/stat_summarizer/__init__.py27
-rw-r--r--tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py70
-rw-r--r--tensorflow/core/util/stat_summarizer.cc18
-rw-r--r--tensorflow/core/util/stat_summarizer.h7
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/tensorflow.i1
-rw-r--r--tensorflow/python/util/stat_summarizer.i82
11 files changed, 239 insertions, 8 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index ccef57b5a8..29f87329f8 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -117,6 +117,7 @@ filegroup(
"//tensorflow/contrib/slim/python/slim/data:all_files",
"//tensorflow/contrib/slim/python/slim/nets:all_files",
"//tensorflow/contrib/specs:all_files",
+ "//tensorflow/contrib/stat_summarizer:all_files",
"//tensorflow/contrib/tensor_forest:all_files",
"//tensorflow/contrib/tensor_forest/hybrid:all_files",
"//tensorflow/contrib/tensorboard:all_files",
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 704de2605e..0aa4fb9f68 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -38,6 +38,7 @@ py_library(
"//tensorflow/contrib/slim",
"//tensorflow/contrib/slim:nets",
"//tensorflow/contrib/specs",
+ "//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
"//tensorflow/contrib/tensor_forest:tensor_forest_py",
"//tensorflow/contrib/tensor_forest/hybrid:ops_lib",
"//tensorflow/contrib/tensorboard",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 0ded847cfa..a3c6d62fcd 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -40,6 +40,7 @@ from tensorflow.contrib import quantization
from tensorflow.contrib import rnn
from tensorflow.contrib import seq2seq
from tensorflow.contrib import slim
+from tensorflow.contrib import stat_summarizer
from tensorflow.contrib import tensor_forest
from tensorflow.contrib import tensorboard
from tensorflow.contrib import testing
diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD
new file mode 100644
index 0000000000..dd00db62b7
--- /dev/null
+++ b/tensorflow/contrib/stat_summarizer/BUILD
@@ -0,0 +1,38 @@
+# Description:
+# Contains a Python wrapper for the StatSummarizer C++ class.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+py_library(
+ name = "stat_summarizer_py",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+)
+
+tf_py_test(
+ name = "stat_summarizer_test",
+ size = "small",
+ srcs = ["python/stat_summarizer_test.py"],
+ additional_deps = [
+ ":stat_summarizer_py",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/stat_summarizer/__init__.py b/tensorflow/contrib/stat_summarizer/__init__.py
new file mode 100644
index 0000000000..32feb7edb9
--- /dev/null
+++ b/tensorflow/contrib/stat_summarizer/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Exposes the Python wrapper for StatSummarizer utility class.
+
+The wrapper implementation is in tensorflow/python/util/stat_summarizer.i for
+technical reasons, but it should be accessed via tf.contrib.stat_summarizer.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import, line-too-long
+from tensorflow.python.pywrap_tensorflow import DeleteStatSummarizer
+from tensorflow.python.pywrap_tensorflow import NewStatSummarizer
+from tensorflow.python.pywrap_tensorflow import StatSummarizer
diff --git a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
new file mode 100644
index 0000000000..616be81e27
--- /dev/null
+++ b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
@@ -0,0 +1,70 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for StatSummarizer Python wrapper."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class StatSummarizerTest(tf.test.TestCase):
+
+ def testStatSummarizer(self):
+ with tf.Graph().as_default() as graph:
+ matrix1 = tf.constant([[3., 3.]])
+ matrix2 = tf.constant([[2.], [2.]])
+ product = tf.matmul(matrix1, matrix2)
+
+ graph_def = graph.as_graph_def()
+ ss = tf.contrib.stat_summarizer.NewStatSummarizer(
+ graph_def.SerializeToString())
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+
+ for _ in range(20):
+ run_metadata = tf.RunMetadata()
+ run_options = tf.RunOptions(
+ trace_level=tf.RunOptions.FULL_TRACE)
+ sess.run(product, options=run_options, run_metadata=run_metadata)
+
+ ss.ProcessStepStatsStr(run_metadata.step_stats.SerializeToString())
+
+ output_string = ss.GetOutputString()
+
+ print(output_string)
+
+ # Test that the preliminary summary line was printed.
+ self.assertRegexpMatches(output_string, r"Total time")
+
+ # Test it recorded running the expected number of times.
+ self.assertRegexpMatches(output_string, r"count=20")
+ self.assertRegexpMatches(output_string, r"\n20 runs")
+
+ # Test that a header line got printed.
+ self.assertRegexpMatches(output_string, r"====== .* ======")
+
+ # Test that the MatMul node we added was analyzed.
+ self.assertRegexpMatches(output_string, r"MatMul")
+
+ # Test that a CDF summed to 100%
+ self.assertRegexpMatches(output_string, r"100\.")
+
+ tf.contrib.stat_summarizer.DeleteStatSummarizer(ss)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/core/util/stat_summarizer.cc b/tensorflow/core/util/stat_summarizer.cc
index 33cf5fe961..0b675eaac9 100644
--- a/tensorflow/core/util/stat_summarizer.cc
+++ b/tensorflow/core/util/stat_summarizer.cc
@@ -338,13 +338,19 @@ std::string StatSummarizer::GetStatsByOrderOfNodeDefinitions(
return stream.str();
}
+std::string StatSummarizer::GetOutputString() const {
+ std::stringstream stream;
+ stream << "Total time (us): " << run_total_micros_;
+ stream << GetTimingStatsByRunOrder();
+ stream << GetTimingStatsByTopDurations();
+ stream << "Total Memory (bytes): " << memory_;
+ stream << GetMemoryStatsByRunOrder();
+ stream << GetMemoryStatsByUsage();
+ return stream.str();
+}
+
void StatSummarizer::PrintStepStats() const {
- LOG(INFO) << "Total time (us): " << run_total_micros_;
- LOG(INFO) << GetTimingStatsByRunOrder();
- LOG(INFO) << GetTimingStatsByTopDurations();
- LOG(INFO) << "Total Memory (bytes): " << memory_;
- LOG(INFO) << GetMemoryStatsByRunOrder();
- LOG(INFO) << GetMemoryStatsByUsage();
+ LOG(INFO) << GetOutputString();
LOG(INFO);
}
diff --git a/tensorflow/core/util/stat_summarizer.h b/tensorflow/core/util/stat_summarizer.h
index 2dd3dbd5cd..7fc43d449b 100644
--- a/tensorflow/core/util/stat_summarizer.h
+++ b/tensorflow/core/util/stat_summarizer.h
@@ -111,8 +111,11 @@ class StatSummarizer {
// Adds another run's StepStats output to the aggregate counts.
void ProcessStepStats(const StepStats& step_stats);
- // Prints all the accumulated runtime stats in a tab-separated format which
- // can be pasted into a spreadsheet for further analysis.
+ // Returns a string detailing the accumulated runtime stats in a tab-separated
+ // format which can be pasted into a spreadsheet for further analysis.
+ std::string GetOutputString() const;
+
+ // Prints the string returned by GetOutputString().
void PrintStepStats() const;
// Prints the output tensor sizes and types for each node.
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index c0533de1df..27655dff8e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1925,6 +1925,7 @@ tf_py_wrap_cc(
"util/kernel_registry.i",
"util/port.i",
"util/py_checkpoint_reader.i",
+ "util/stat_summarizer.i",
],
deps = [
":cpp_shape_inference",
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 9115ec891a..97c942320a 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -19,6 +19,7 @@ limitations under the License.
%include "tensorflow/python/util/port.i"
%include "tensorflow/python/util/py_checkpoint_reader.i"
+%include "tensorflow/python/util/stat_summarizer.i"
%include "tensorflow/python/lib/core/py_func.i"
diff --git a/tensorflow/python/util/stat_summarizer.i b/tensorflow/python/util/stat_summarizer.i
new file mode 100644
index 0000000000..aa917cd2c5
--- /dev/null
+++ b/tensorflow/python/util/stat_summarizer.i
@@ -0,0 +1,82 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include <std_string.i>
+%include "tensorflow/python/lib/core/strings.i"
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/stat_summarizer.h"
+#include "tensorflow/python/lib/core/py_func.h"
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+%}
+
+%ignoreall
+
+%unignore NewStatSummarizer;
+%unignore DeleteStatSummarizer;
+%unignore tensorflow;
+%unignore tensorflow::StatSummarizer;
+%unignore tensorflow::StatSummarizer::StatSummarizer;
+%unignore tensorflow::StatSummarizer::~StatSummarizer;
+%unignore tensorflow::StatSummarizer::Initialize;
+%unignore tensorflow::StatSummarizer::InitializeStr;
+%unignore tensorflow::StatSummarizer::ProcessStepStats;
+%unignore tensorflow::StatSummarizer::ProcessStepStatsStr;
+%unignore tensorflow::StatSummarizer::PrintStepStats;
+%unignore tensorflow::StatSummarizer::GetOutputString;
+
+
+%{
+tensorflow::StatSummarizer* NewStatSummarizer(
+ const string& graph_def_str) {
+ tensorflow::GraphDef graph_def;
+ graph_def.ParseFromString(graph_def_str);
+ return new tensorflow::StatSummarizer(graph_def);
+}
+%}
+
+
+%{
+void DeleteStatSummarizer(tensorflow::StatSummarizer* ss) {
+ delete ss;
+}
+%}
+
+tensorflow::StatSummarizer* NewStatSummarizer(const string& graph_def_str);
+void DeleteStatSummarizer(tensorflow::StatSummarizer* ss);
+
+%extend tensorflow::StatSummarizer {
+ void ProcessStepStatsStr(const string& step_stats_str) {
+ tensorflow::StepStats step_stats;
+ step_stats.ParseFromString(step_stats_str);
+ $self->ProcessStepStats(step_stats);
+}
+}
+
+%extend tensorflow::StatSummarizer {
+ StatSummarizer(const string& graph_def_str) {
+ tensorflow::GraphDef graph_def;
+ graph_def.ParseFromString(graph_def_str);
+ tensorflow::StatSummarizer* ss = new tensorflow::StatSummarizer(graph_def);
+ return ss;
+}
+}
+
+%include "tensorflow/core/util/stat_summarizer.h"
+%unignoreall