author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-07-04 10:19:42 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-07-04 10:23:59 -0700
commit     af23ae65db2585f4a18d0bc5f21f15e94805aa4f (patch)
tree       a805f64d0a85fa29ff69d204634379b80cdbcbf1 /tensorflow/contrib/tfprof
parent     11ec8b7cfdec0fd498182d0ad8f550b4a8ddaf13 (diff)
Migrating tfprof python API to tensorflow/python/profiler
Migrating tfprof C++ to tensorflow/core/profiler

API changes: new tf.profiler namespace. Within the tf.profiler namespace:

  tf.profiler.advise        # One-shot advise function
  tf.profiler.profile       # One-shot profile function
  tf.profiler.Profiler      # Multi-step profile/advise class
  tf.profiler.write_op_log  # Write profile for offline analysis

PiperOrigin-RevId: 160901831
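For orientation, a minimal sketch of how the one-shot entry points named above are called after this migration. It is not part of the commit: the argument order follows the compatibility wrappers added later in this diff (which forward positionally to the new implementation), the option dict is one of the presets re-exported by tf.contrib.tfprof, and the tiny graph is a placeholder.

```python
import tensorflow as tf
from tensorflow.contrib.tfprof import model_analyzer  # option presets re-exported by this change

# Placeholder graph; any model graph works here.
w = tf.Variable(tf.random_normal([3, 3]), name='w')
y = tf.matmul(w, w)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    run_meta = tf.RunMetadata()
    sess.run(y,
             options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
             run_metadata=run_meta)

    # One-shot profile: trainable-variable parameters, grouped by name scope.
    # Positional arguments mirror the wrapper: (graph, run_meta, op_log, cmd, options).
    tf.profiler.profile(sess.graph, run_meta, None, 'scope',
                        model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)

    # One-shot advise: auto-detect common performance problems.
    tf.profiler.advise(sess.graph, run_meta)
```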
Diffstat (limited to 'tensorflow/contrib/tfprof')
-rw-r--r--  tensorflow/contrib/tfprof/BUILD                                                        19
-rw-r--r--  tensorflow/contrib/tfprof/README.md                                                    46
-rw-r--r--  tensorflow/contrib/tfprof/__init__.py                                                   5
-rw-r--r--  tensorflow/contrib/tfprof/model_analyzer.py                                            47
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/BUILD                                   116
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/__init__.py                               0
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/internal/BUILD                           80
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/internal/model_analyzer_testlib.py       97
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/internal/print_model_analysis_test.py   475
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py           191
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py                       429
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py                  319
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py                       445
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py                  164
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py                        186
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py                        187
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py                    82
-rw-r--r--  tensorflow/contrib/tfprof/tfprof_logger.py                                             24
18 files changed, 112 insertions(+), 2800 deletions(-)
diff --git a/tensorflow/contrib/tfprof/BUILD b/tensorflow/contrib/tfprof/BUILD
index 944d767e21..9b9215234e 100644
--- a/tensorflow/contrib/tfprof/BUILD
+++ b/tensorflow/contrib/tfprof/BUILD
@@ -8,12 +8,25 @@ py_library(
name = "tfprof",
srcs = [
"__init__.py",
+ "model_analyzer.py",
+ "tfprof_logger.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/tfprof/python/tools/tfprof:model_analyzer",
- "//tensorflow/contrib/tfprof/python/tools/tfprof:tfprof_logger",
- "//tensorflow/python:util",
+ "//tensorflow/python/profiler:model_analyzer",
+ "//tensorflow/python/profiler:tfprof_logger",
],
)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md
index 4fa1ccea69..eefd88793e 100644
--- a/tensorflow/contrib/tfprof/README.md
+++ b/tensorflow/contrib/tfprof/README.md
@@ -1,26 +1,24 @@
# tfprof: TensorFlow Profiler and Beyond
-# Full Document in tensorflow/tools/tfprof/README.md
-
-Author: Xin Pan (xpan@google.com, github: panyx0718), Jon Shlens, Yao Zhang
-
-Consultants: Jon Shlens, Pete Warden
-
-###Major Features
-
-1. Measure model parameters, float operations, tensor shapes.
-2. Profile op execution times, requested memory size and device placement.
-3. Inspect checkpoint tensors' shapes and their values.
-4. Selectively group, filter, account and order ops.
-
-####tfprof supports 3 views to organize TensorFlow model profiles
-
- * code view: Stats are associated your Python codes and organized as call stacks.
- * scope view: Stats are organized as name scope hierarchies.
- * graph view: Stats are organized as Tensorflow Op graph.
-
-####For each view, there are 3 ways to display outputs:
-
- * stdout: Results are written to stdout.
- * timeline: Visualized in chrome browser as time series.
- * file: Results are dumped to file.
+<h1>Please use `tf.profiler.xxx` instead of `tf.contrib.tfprof.xxx`</h1>
+<h1>Full Document in tensorflow/core/profiler/README.md<h1>
+
+###Features
+
+* Profile model architectures
+ * parameters, tensor shapes, float operations, device placement, etc.
+* Profile model performance
+ * execution time, memory consumption
+ * Profile multiple steps.
+* Auto profile and advise.
+ * accelerator utilization check
+ * expensive operation check
+ * operation configuration check
+ * distributed runtime check (Not OSS)
+
+###Interfaces
+
+* Python API
+* Command Line
+* Visualization
+* C++ API (Not public, contact us if needed.)
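The "Profile multiple steps" and "Auto profile and advise" features listed in the new README map onto the multi-step tf.profiler.Profiler class named in the commit message. Below is a minimal sketch, adapted from the docstring of the Profiler class removed further down in this diff; the method names and option presets are assumed to carry over unchanged through the migration, and the graph is a placeholder.

```python
import tensorflow as tf
from tensorflow.contrib.tfprof import model_analyzer  # re-exported option presets

# Placeholder model; replace with a real training graph.
x = tf.Variable(tf.random_normal([128, 128]))
y = tf.matmul(x, x)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Only one Profiler per process is currently allowed.
    profiler = tf.profiler.Profiler(sess.graph)

    for step in range(3):
        run_meta = tf.RunMetadata()
        sess.run(y,
                 options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                 run_metadata=run_meta)
        profiler.add_step(step, run_meta)

    # Parameters by name scope, then timing/memory by operation type.
    profiler.profile_name_scope(
        options=model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
    opts = model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
    opts['order_by'] = 'micros'
    profiler.profile_operations(options=opts)

    # Auto-detect problems and print advice.
    profiler.advise(model_analyzer.ALL_ADVICE)
```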
diff --git a/tensorflow/contrib/tfprof/__init__.py b/tensorflow/contrib/tfprof/__init__.py
index f3952f6cb5..7a023e5d67 100644
--- a/tensorflow/contrib/tfprof/__init__.py
+++ b/tensorflow/contrib/tfprof/__init__.py
@@ -17,5 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
-from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
+# pylint: disable=unused-import
+from tensorflow.contrib.tfprof import model_analyzer
+from tensorflow.contrib.tfprof import tfprof_logger
diff --git a/tensorflow/contrib/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/model_analyzer.py
new file mode 100644
index 0000000000..04b5063218
--- /dev/null
+++ b/tensorflow/contrib/tfprof/model_analyzer.py
@@ -0,0 +1,47 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model Analyzer.
+
+Analyze model, including shape, params, time, memory, structure, etc.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+# Import the names here for existing users.
+# pylint: disable=unused-import
+from tensorflow.python.profiler.model_analyzer import advise as _advise
+from tensorflow.python.profiler.model_analyzer import ALL_ADVICE
+from tensorflow.python.profiler.model_analyzer import FLOAT_OPS_OPTIONS
+from tensorflow.python.profiler.model_analyzer import PRINT_ALL_TIMING_MEMORY
+from tensorflow.python.profiler.model_analyzer import profile as _profile
+from tensorflow.python.profiler.model_analyzer import Profiler
+from tensorflow.python.profiler.model_analyzer import TRAINABLE_VARS_PARAMS_STAT_OPTIONS
+
+_DEFAULT_PROFILE_OPTIONS = 0
+_DEFAULT_ADVISE_OPTIONS = 0
+
+
+def advise(graph, run_meta=None, tfprof_options=_DEFAULT_ADVISE_OPTIONS):
+ return _advise(graph, run_meta, tfprof_options)
+
+
+def print_model_analysis(graph,
+ run_meta=None,
+ op_log=None,
+ tfprof_cmd='scope',
+ tfprof_options=_DEFAULT_PROFILE_OPTIONS):
+ return _profile(graph, run_meta, op_log, tfprof_cmd, tfprof_options)
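The new tensorflow/contrib/tfprof/model_analyzer.py above is a thin compatibility shim: existing contrib call sites keep working while everything forwards to the new implementation. A hedged sketch of the equivalence, which follows directly from the wrapper bodies above (the default graph here is only a stand-in; build a model first for meaningful output):

```python
import tensorflow as tf
from tensorflow.contrib.tfprof import model_analyzer as contrib_analyzer

graph = tf.get_default_graph()
opts = contrib_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS

# Old contrib entry point, kept working by the wrapper above:
contrib_analyzer.print_model_analysis(graph, tfprof_options=opts)

# What the wrapper forwards to, in the same argument order as its body:
tf.profiler.profile(graph, None, None, 'scope', opts)
```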
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
deleted file mode 100644
index b7c79edfca..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
+++ /dev/null
@@ -1,116 +0,0 @@
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"]) # Apache 2.0
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow:tensorflow.bzl", "py_test")
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
-
-py_library(
- name = "model_analyzer",
- srcs = ["model_analyzer.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":tfprof_logger",
- "//tensorflow/python:pywrap_tensorflow",
- "//tensorflow/tools/tfprof:protos_all_py",
- ],
-)
-
-cuda_py_test(
- name = "model_analyzer_test",
- srcs = ["model_analyzer_test.py"],
- additional_deps = [
- ":model_analyzer",
- "//tensorflow/contrib/tfprof/python/tools/tfprof/internal:model_analyzer_testlib",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:platform",
- "//tensorflow/python:variables",
- ],
- tags = ["no_pip"],
-)
-
-cuda_py_test(
- name = "profiler_test",
- srcs = ["profiler_test.py"],
- additional_deps = [
- ":model_analyzer",
- "//tensorflow/contrib/tfprof/python/tools/tfprof/internal:model_analyzer_testlib",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:platform",
- "//tensorflow/python:variables",
- ],
- tags = ["no_pip"],
-)
-
-py_library(
- name = "tfprof_logger",
- srcs = ["tfprof_logger.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:platform",
- "//tensorflow/tools/tfprof:protos_all_py",
- "@six_archive//:six",
- ],
-)
-
-tf_py_test(
- name = "tfprof_logger_test",
- size = "small",
- srcs = ["tfprof_logger_test.py"],
- additional_deps = [
- ":tfprof_logger",
- "//tensorflow/contrib/copy_graph:copy_graph_py",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:math_ops",
- "//tensorflow/tools/tfprof:protos_all_py",
- ],
-)
-
-py_library(
- name = "pprof_profiler",
- srcs = ["pprof_profiler.py"],
- srcs_version = "PY2AND3",
- deps = ["@com_google_pprof//:pprof_proto_py"],
-)
-
-py_test(
- name = "pprof_profiler_test",
- size = "small",
- srcs = ["pprof_profiler_test.py"],
- main = "pprof_profiler_test.py",
- srcs_version = "PY2AND3",
- tags = ["no_pip"], # TODO(annarev): get it working with pip.
- deps = [
- ":pprof_profiler",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:platform_test",
- "@com_google_pprof//:pprof_proto_py",
- ],
-)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/__init__.py b/tensorflow/contrib/tfprof/python/tools/tfprof/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/__init__.py
+++ /dev/null
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/BUILD
deleted file mode 100644
index a498574a95..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/BUILD
+++ /dev/null
@@ -1,80 +0,0 @@
-package(default_visibility = ["//tensorflow/contrib/tfprof/python/tools/tfprof:__subpackages__"])
-
-licenses(["notice"]) # Apache 2.0
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow:tensorflow.bzl", "py_test")
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
-load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
-
-py_library(
- name = "model_analyzer_testlib",
- srcs = ["model_analyzer_testlib.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/rnn:rnn_py",
- "//tensorflow/contrib/tfprof/python/tools/tfprof:model_analyzer",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:nn_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:rnn",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- ],
-)
-
-py_test(
- name = "print_model_analysis_test",
- srcs = ["print_model_analysis_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:nn_ops",
- "//tensorflow/python:platform_test",
- "//tensorflow/python:pywrap_tensorflow",
- "//tensorflow/python:variable_scope",
- "//tensorflow/tools/tfprof:protos_all_py",
- ],
-)
-
-cuda_py_test(
- name = "run_metadata_test",
- srcs = ["run_metadata_test.py"],
- additional_deps = [
- ":model_analyzer_testlib",
- "//tensorflow/contrib/tfprof/python/tools/tfprof:model_analyzer",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/tools/tfprof:protos_all_py",
- ],
- tags = [
- "no_pip",
- ],
-)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/model_analyzer_testlib.py b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/model_analyzer_testlib.py
deleted file mode 100644
index 42b83fde7c..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/model_analyzer_testlib.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""A test lib that defines some models."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import rnn
-from tensorflow.python.ops import rnn_cell
-from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import gradient_descent
-
-
-def BuildSmallModel():
- """Build a small forward conv model."""
- image = array_ops.zeros([2, 6, 6, 3])
- _ = variable_scope.get_variable(
- 'ScalarW', [],
- dtypes.float32,
- initializer=init_ops.random_normal_initializer(stddev=0.001))
- kernel = variable_scope.get_variable(
- 'DW', [3, 3, 3, 6],
- dtypes.float32,
- initializer=init_ops.random_normal_initializer(stddev=0.001))
- x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
- kernel = variable_scope.get_variable(
- 'DW2', [2, 2, 6, 12],
- dtypes.float32,
- initializer=init_ops.random_normal_initializer(stddev=0.001))
- x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
- return x
-
-
-def BuildFullModel():
- """Build the full model with conv,rnn,opt."""
- seq = []
- for i in range(4):
- with variable_scope.variable_scope('inp_%d' % i):
- seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1]))
-
- cell = rnn_cell.BasicRNNCell(16)
- out = rnn.dynamic_rnn(
- cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0]
-
- target = array_ops.ones_like(out)
- loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
- sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
- return sgd_op.minimize(loss)
-
-
-def BuildSplitableModel():
- """Build a small model that can be run partially in each step."""
- image = array_ops.zeros([2, 6, 6, 3])
-
- kernel1 = variable_scope.get_variable(
- 'DW', [3, 3, 3, 6],
- dtypes.float32,
- initializer=init_ops.random_normal_initializer(stddev=0.001))
- r1 = nn_ops.conv2d(image, kernel1, [1, 2, 2, 1], padding='SAME')
-
- kernel2 = variable_scope.get_variable(
- 'DW2', [2, 3, 3, 6],
- dtypes.float32,
- initializer=init_ops.random_normal_initializer(stddev=0.001))
- r2 = nn_ops.conv2d(image, kernel2, [1, 2, 2, 1], padding='SAME')
-
- r3 = r1 + r2
- return r1, r2, r3
-
-
-def SearchTFProfNode(node, name):
- """Search a node in the tree."""
- if node.name == name:
- return node
- for c in node.children:
- r = SearchTFProfNode(c, name)
- if r: return r
- return None
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/print_model_analysis_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/print_model_analysis_test.py
deleted file mode 100644
index 7ded5e890f..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/print_model_analysis_test.py
+++ /dev/null
@@ -1,475 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""print_model_analysis test."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from google.protobuf import text_format
-
-from tensorflow.python.client import session
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-from tensorflow.tools.tfprof import tfprof_options_pb2
-from tensorflow.tools.tfprof import tfprof_output_pb2
-
-# pylint: disable=g-bad-import-order
-# XXX: this depends on pywrap_tensorflow and must come later
-from tensorflow.python import pywrap_tensorflow as print_mdl
-
-# pylint: disable=bad-whitespace
-# pylint: disable=bad-continuation
-TEST_OPTIONS = {
- 'max_depth': 10000,
- 'min_bytes': 0,
- 'min_micros': 0,
- 'min_params': 0,
- 'min_float_ops': 0,
- 'order_by': 'name',
- 'account_type_regexes': ['.*'],
- 'start_name_regexes': ['.*'],
- 'trim_name_regexes': [],
- 'show_name_regexes': ['.*'],
- 'hide_name_regexes': [],
- 'account_displayed_op_only': True,
- 'select': ['params'],
- 'output': 'stdout',
-}
-
-# pylint: enable=bad-whitespace
-# pylint: enable=bad-continuation
-
-
-class PrintModelAnalysisTest(test.TestCase):
-
- def _BuildSmallModel(self):
- image = array_ops.zeros([2, 6, 6, 3])
- kernel = variable_scope.get_variable(
- 'DW', [6, 6, 3, 6],
- dtypes.float32,
- initializer=init_ops.random_normal_initializer(stddev=0.001))
- x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
- return x
-
- def testPrintModelAnalysis(self):
- opts = tfprof_options_pb2.OptionsProto()
- opts.max_depth = TEST_OPTIONS['max_depth']
- opts.min_bytes = TEST_OPTIONS['min_bytes']
- opts.min_micros = TEST_OPTIONS['min_micros']
- opts.min_params = TEST_OPTIONS['min_params']
- opts.min_float_ops = TEST_OPTIONS['min_float_ops']
- opts.order_by = TEST_OPTIONS['order_by']
- opts.step = -1
- for p in TEST_OPTIONS['account_type_regexes']:
- opts.account_type_regexes.append(p)
- for p in TEST_OPTIONS['start_name_regexes']:
- opts.start_name_regexes.append(p)
- for p in TEST_OPTIONS['trim_name_regexes']:
- opts.trim_name_regexes.append(p)
- for p in TEST_OPTIONS['show_name_regexes']:
- opts.show_name_regexes.append(p)
- for p in TEST_OPTIONS['hide_name_regexes']:
- opts.hide_name_regexes.append(p)
- opts.account_displayed_op_only = TEST_OPTIONS['account_displayed_op_only']
- for p in TEST_OPTIONS['select']:
- opts.select.append(p)
- opts.output = TEST_OPTIONS['output']
-
- with session.Session() as sess, ops.device('/cpu:0'):
- _ = self._BuildSmallModel()
- tfprof_pb = tfprof_output_pb2.TFGraphNodeProto()
- tfprof_pb.ParseFromString(
- print_mdl.PrintModelAnalysis(
- sess.graph.as_graph_def(add_shapes=True).SerializeToString(),
- b'',
- b'',
- b'scope',
- opts.SerializeToString()))
-
- expected_pb = tfprof_output_pb2.TFGraphNodeProto()
- text_format.Merge(r"""name: "_TFProfRoot"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 648
- children {
- name: "Conv2D"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 2
- }
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- }
- }
- input_shapes {
- key: 1
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW"
- exec_micros: 0
- requested_bytes: 0
- parameters: 648
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 648
- children {
- name: "DW/Assign"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- input_shapes {
- key: 1
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW/Initializer"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- children {
- name: "DW/Initializer/random_normal"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- children {
- name: "DW/Initializer/random_normal/RandomStandardNormal"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 4
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW/Initializer/random_normal/mean"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW/Initializer/random_normal/mul"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- input_shapes {
- key: 1
- value {
- dim {
- size: 1
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW/Initializer/random_normal/shape"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW/Initializer/random_normal/stddev"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- input_shapes {
- key: 1
- value {
- dim {
- size: 1
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 6
- }
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 7
- }
- children {
- name: "DW/read"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 10
- }
- children {
- name: "zeros"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 13""", expected_pb)
- self.assertEqual(expected_pb, tfprof_pb)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py
deleted file mode 100644
index 8f3351999f..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""test the RunMetadata proto."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from collections import defaultdict
-
-import six
-
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-
-# pylint: disable=g-bad-import-order
-# XXX: this depends on pywrap_tensorflow and must come later
-from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
-from tensorflow.contrib.tfprof.python.tools.tfprof.internal import model_analyzer_testlib as lib
-SIZE = 1300
-
-
-def _extract_node(run_meta, node_name):
- ret = defaultdict(list)
- for dev_stat in run_meta.step_stats.dev_stats:
- dev = dev_stat.device
- for node_stat in dev_stat.node_stats:
- if node_stat.node_name == node_name:
- ret[dev].append(node_stat)
- return ret
-
-
-def _run_model():
- x = random_ops.random_normal(shape=[1, SIZE])
- w = random_ops.random_normal(shape=[SIZE, 2 * SIZE])
- y = math_ops.matmul(x, w)
-
- with session.Session() as sess:
- run_metadata = config_pb2.RunMetadata()
- opts = model_analyzer.PRINT_ALL_TIMING_MEMORY
- opts['min_micros'] = 0
- opts['min_bytes'] = 0
- _ = sess.run(y,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_metadata)
- tfprof_node = model_analyzer.print_model_analysis(
- sess.graph,
- run_meta=run_metadata,
- tfprof_options=opts)
-
- return tfprof_node, run_metadata
-
-
-def _run_loop_model():
- with session.Session() as sess:
- x = lib.BuildFullModel()
-
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- tfprof_node = model_analyzer.print_model_analysis(
- sess.graph, run_meta,
- tfprof_options=model_analyzer.PRINT_ALL_TIMING_MEMORY)
- return tfprof_node, run_meta
-
-
-class RunMetadataTest(test.TestCase):
-
- def testGPU(self):
- if not test.is_gpu_available(cuda_only=True):
- return
-
- ops.reset_default_graph()
- with ops.device('/gpu:0'):
- tfprof_node, run_meta = _run_model()
- self.assertEqual(tfprof_node.children[0].name, 'MatMul')
- self.assertGreater(tfprof_node.children[0].exec_micros, 10)
-
- ret = _extract_node(run_meta, 'MatMul')
- self.assertEqual(len(ret), 1)
- self.assertTrue('/job:localhost/replica:0/task:0/gpu:0' in ret)
-
- ret = _extract_node(run_meta, 'MatMul:MatMul')
- self.assertEqual(len(ret), 2)
- has_all_stream = False
- for k, _ in six.iteritems(ret):
- self.assertTrue('gpu:0/stream' in k)
- if 'gpu:0/stream:all' in k:
- has_all_stream = True
- self.assertTrue(has_all_stream)
-
- def testCPU(self):
- ops.reset_default_graph()
- with ops.device('/cpu:0'):
- tfprof_node, run_meta = _run_model()
- self.assertEqual(tfprof_node.children[0].name, 'MatMul')
- self.assertGreater(tfprof_node.children[0].exec_micros, 0)
-
- ret = _extract_node(run_meta, 'MatMul')
- self.assertEqual(len(ret), 1)
- self.assertTrue('/job:localhost/replica:0/task:0/cpu:0' in ret)
-
- ret = _extract_node(run_meta, 'MatMul:MatMul')
- self.assertEqual(len(ret), 0)
-
- def testLoopCPU(self):
- ops.reset_default_graph()
- with ops.device('/cpu:0'):
- tfprof_node, run_meta = _run_loop_model()
- # The while-loop caused a node to appear 4 times in scheduling.
- ret = _extract_node(run_meta,
- 'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul')
- self.assertEqual(len(ret['/job:localhost/replica:0/task:0/cpu:0']), 4)
-
- total_cpu_execs = 0
- for node in ret['/job:localhost/replica:0/task:0/cpu:0']:
- total_cpu_execs += node.op_end_rel_micros
-
- mm_node = lib.SearchTFProfNode(
- tfprof_node,
- 'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul')
-
- self.assertEqual(mm_node.run_count, 4)
- self.assertEqual(mm_node.cpu_exec_micros, total_cpu_execs)
- self.assertEqual(mm_node.exec_micros, total_cpu_execs)
-
- # pylint: disable=pointless-string-statement
- """
- TODO(xpan): This test is flaky because RunMetadata returned from TensorFlow
- is random. Still being investigated.
- def testLoopGPU(self):
- if not test.is_gpu_available():
- return
-
- ops.reset_default_graph()
- with ops.device('/gpu:0'):
- tfprof_node, run_meta = _run_loop_model()
- # The while-loop caused a node to appear 4 times in scheduling.
- ret = _extract_node(run_meta,
- 'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul')
- self.assertEqual(len(ret['/job:localhost/replica:0/task:0/gpu:0']), 4)
-
- total_cpu_execs = 0
- for node in ret['/job:localhost/replica:0/task:0/gpu:0']:
- total_cpu_execs += node.op_end_rel_micros
-
- ret = _extract_node(
- run_meta,
- 'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul:MatMul')
- self.assertGreaterEqual(len(ret['/gpu:0/stream:all']), 4)
-
- total_accelerator_execs = 0
- for node in ret['/gpu:0/stream:all']:
- total_accelerator_execs += node.op_end_rel_micros
-
- mm_node = lib.SearchTFProfNode(
- tfprof_node,
- 'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul')
-
- self.assertEqual(mm_node.run_count, 4)
- self.assertEqual(mm_node.accelerator_exec_micros, total_accelerator_execs)
- self.assertEqual(mm_node.cpu_exec_micros, total_cpu_execs)
- self.assertEqual(mm_node.exec_micros,
- total_cpu_execs + total_accelerator_execs)
- """
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
deleted file mode 100644
index 5b6111efdf..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
+++ /dev/null
@@ -1,429 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Model Analyzer.
-
-Analyze model, including shape, params, time, memory, structure, etc.
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import six
-
-from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
-from tensorflow.python import pywrap_tensorflow as print_mdl
-from tensorflow.python.framework import errors
-from tensorflow.tools.tfprof import tfprof_options_pb2
-from tensorflow.tools.tfprof import tfprof_output_pb2
-
-# pylint: disable=bad-whitespace
-# pylint: disable=bad-continuation
-# 2 example tfprof_options for print_model_analysis API.
-#
-# Show the parameter statistics of trainable variables.
-TRAINABLE_VARS_PARAMS_STAT_OPTIONS = {
- 'max_depth': 10000,
- 'min_bytes': 0,
- 'min_micros': 0,
- 'min_params': 0,
- 'min_float_ops': 0,
- 'order_by': 'name',
- 'account_type_regexes': [tfprof_logger.TRAINABLE_VARIABLES],
- 'start_name_regexes': ['.*'],
- 'trim_name_regexes': [],
- 'show_name_regexes': ['.*'],
- 'hide_name_regexes': [],
- 'account_displayed_op_only': True,
- 'select': ['params'],
- 'output': 'stdout',
- 'dump_to_file': ''
-}
-
-# Show the number float operations.
-FLOAT_OPS_OPTIONS = {
- 'max_depth': 10000,
- 'min_bytes': 0,
- 'min_micros': 0,
- 'min_params': 0,
- 'min_float_ops': 1,
- 'order_by': 'float_ops',
- 'account_type_regexes': ['.*'],
- 'start_name_regexes': ['.*'],
- 'trim_name_regexes': [],
- 'show_name_regexes': ['.*'],
- 'hide_name_regexes': [],
- 'account_displayed_op_only': True,
- 'select': ['float_ops'],
- 'output': 'stdout',
- 'dump_to_file': ''
-}
-
-# Show number of parameters on parameter server 0.
-# It is recommended to provide`run_meta` argument
-# to have complete device placement info.
-PRINT_PARAMS_ON_DEVICE = {
- 'max_depth': 1,
- 'min_bytes': 0,
- 'min_micros': 0,
- 'min_params': 0,
- 'min_float_ops': 0,
- 'order_by': 'name',
- 'account_type_regexes': ['.*ps.*task:0.*'],
- 'start_name_regexes': ['.*'],
- 'trim_name_regexes': [],
- 'show_name_regexes': ['.*'],
- 'hide_name_regexes': [],
- 'account_displayed_op_only': False,
- 'select': ['device', 'params'],
- 'output': 'stdout',
- 'dump_to_file': ''
-}
-
-# Show the timing stats and memory demands.
-PRINT_ALL_TIMING_MEMORY = {
- 'max_depth': 10000,
- 'min_bytes': 1, # Only >=1
- 'min_micros': 1, # Only >=1
- 'min_params': 0,
- 'min_float_ops': 0,
- 'order_by': 'name',
- 'account_type_regexes': ['.*'],
- 'start_name_regexes': ['.*'],
- 'trim_name_regexes': [],
- 'show_name_regexes': ['.*'],
- 'hide_name_regexes': [],
- 'account_displayed_op_only': True,
- 'select': ['micros', 'bytes'],
- 'output': 'stdout',
- 'dump_to_file': ''
-}
-
-# The following options are for 'advise' tfprof_cmd.
-# Show all advice.
-ALL_ADVICE = {
- 'ExpensiveOperationChecker': {},
- 'AcceleratorUtilizationChecker': {},
- 'JobChecker': {}, # Only available internally.
- 'OperationChecker': {},
-}
-
-# pylint: enable=bad-whitespace
-# pylint: enable=bad-continuation
-
-
-def _build_options(options):
- """Build tfprof.OptionsProto.
-
- Args:
- options: A dictionary of options.
- Returns:
- tfprof.OptionsProto.
- """
- opts = tfprof_options_pb2.OptionsProto()
- opts.max_depth = options.get('max_depth', 10)
- opts.min_bytes = options.get('min_bytes', 0)
- opts.min_micros = options.get('min_micros', 0)
- opts.min_params = options.get('min_params', 0)
- opts.min_float_ops = options.get('min_float_ops', 0)
- opts.min_occurrence = options.get('min_occurrence', 0)
-
- opts.step = options.get('step', -1)
-
- opts.order_by = options.get('order_by', 'name')
-
- for p in options.get('account_type_regexes', []):
- opts.account_type_regexes.append(p)
- for p in options.get('start_name_regexes', []):
- opts.start_name_regexes.append(p)
- for p in options.get('trim_name_regexes', []):
- opts.trim_name_regexes.append(p)
- for p in options.get('show_name_regexes', []):
- opts.show_name_regexes.append(p)
- for p in options.get('hide_name_regexes', []):
- opts.hide_name_regexes.append(p)
- opts.account_displayed_op_only = options.get('account_displayed_op_only',
- False)
-
- for p in options.get('select', []):
- opts.select.append(p)
-
- opts.output = options.get('output', 'stdout')
- opts.dump_to_file = options.get('dump_to_file', '')
-
- return opts
-
-
-def _build_advisor_options(options):
- """Build tfprof.AdvisorOptionsProto.
-
- Args:
- options: A dictionary of options. See ALL_ADVICE example.
- Returns:
- tfprof.AdvisorOptionsProto.
- """
- opts = tfprof_options_pb2.AdvisorOptionsProto()
- if options is None:
- return opts
- for checker, checker_opts in six.iteritems(options):
- checker_ops_pb = tfprof_options_pb2.AdvisorOptionsProto.CheckerOption()
- for k, v in six.iteritems(checker_opts):
- checker_ops_pb[k] = v
- opts.checkers[checker].MergeFrom(checker_ops_pb)
- return opts
-
-
-class Profiler(object):
- """TensorFlow multi-step profiler.
-
- See go/tfprof or README for details.
-
- Typical use case:
- # Currently we are only allowed to create 1 profiler per process.
- profiler = Profile(sess.graph)
-
- for i in xrange(total_steps):
- if i % 10000 == 0:
- run_meta = tf.RunMetadata()
- _ = sess.run(...,
- options=tf.RunOptions(
- trace_level=tf.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
- profiler.add_step(i, run_meta)
-
- # Profile the parameters of your model.
- profiler.profile_name_scope(options=TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
-
- # Or profile the timing of your model operations.
- opts = PRINT_ALL_TIMING_MEMORY.copy()
- opts['order_by'] = 'micros'
- opts['select'] = ['micros', 'occurrence']
- opts['max_depth'] = 20
- profiler.profile_operations(options=opts)
-
- # Or you can generate a timeline:
- opts = PRINT_ALL_TIMING_MEMORY.copy()
- opts['output'] = 'timeline:outfile=' + filename
- opts['step'] = i
- profiler.profile_graph(options=opts)
- else:
- _ = sess.run(...)
- # Auto detect problems and generate advice.
- profiler.advise(model_analyzer.ALL_ADVICE)
- """
-
- def __init__(self, graph, op_log=None):
- """Constructor.
-
- Args:
- graph: tf.Graph.
- op_log: optional. tensorflow::tfprof::OpLog proto. Used to define
- extra op types.
- """
- self._graph = graph
- # pylint: disable=protected-access
- op_log = tfprof_logger._merge_default_with_oplog(
- self._graph, op_log=op_log)
- # pylint: enable=protected-access
-
- print_mdl.NewProfiler(
- self._graph.as_graph_def(add_shapes=True).SerializeToString(),
- op_log.SerializeToString())
-
- def __del__(self):
- print_mdl.DeleteProfiler()
-
- def add_step(self, step, run_meta):
- """Add statistics of a step.
-
- Args:
- step: A step uint64 used to identify the RunMetadata. Must be different
- across different AddStep() calls.
- run_meta: RunMetadata proto that contains statistics of a session run.
- """
- # pylint: disable=protected-access
- op_log = tfprof_logger._merge_default_with_oplog(
- self._graph, run_meta=run_meta, add_trace=False,
- add_trainable_var=False)
- # pylint: enable=protected-access
- print_mdl.AddStep(
- step, run_meta.SerializeToString(), op_log.SerializeToString())
-
- def profile_python_codes(self, options):
- """Profile the statistics of the Python codes.
-
- Hint: set options['show_name_regexes'] = ['.*my_code.py.*']
-
- Args:
- options: A dict of profiler options.
- Returns:
- a TFMultiGraphNodeProto that records the results.
- """
- opts = _build_options(options)
- tfprof_node = tfprof_output_pb2.TFMultiGraphNodeProto()
- tfprof_node.ParseFromString(
- print_mdl.Profile('code'.encode('utf-8'), opts.SerializeToString()))
- return tfprof_node
-
- def profile_operations(self, options):
- """Profile the statistics of the Operation types (e.g. MatMul, Conv2D).
-
- Args:
- options: A dict of profiler options.
- Returns:
- a TFMultiGraphNodeProto that records the results.
- """
- opts = _build_options(options)
- tfprof_node = tfprof_output_pb2.TFMultiGraphNodeProto()
- tfprof_node.ParseFromString(
- print_mdl.Profile('op'.encode('utf-8'), opts.SerializeToString()))
- return tfprof_node
-
- def profile_name_scope(self, options):
- """Profile the statistics of graph nodes, organized by name scope.
-
- Args:
- options: A dict of profiler options.
- Returns:
- a TFGraphNodeProto that records the results.
- """
- opts = _build_options(options)
- tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
- tfprof_node.ParseFromString(
- print_mdl.Profile('scope'.encode('utf-8'), opts.SerializeToString()))
- return tfprof_node
-
- def profile_graph(self, options):
- """Profile the statistics of graph nodes, organized by dataflow graph.
-
- Args:
- options: A dict of profiler options.
- Returns:
- a TFGraphNodeProto that records the results.
- """
- opts = _build_options(options)
- tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
- tfprof_node.ParseFromString(
- print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString()))
- return tfprof_node
-
- def advise(self, options=ALL_ADVICE): # pylint: disable=dangerous-default-value
- """Automatically detect problems and generate reports.
-
- Args:
- options: A dict of options.
- Returns:
- A Advise proto that conains the reports from all checkers.
- """
- advise_pb = tfprof_output_pb2.AdviceProto()
- opts = _build_advisor_options(options)
- advise_pb.ParseFromString(
- print_mdl.Profile('advise'.encode('utf-8'), opts.SerializeToString()))
- return advise_pb
-
-
-def print_model_analysis(graph,
- run_meta=None,
- op_log=None,
- tfprof_cmd='scope',
- tfprof_options=TRAINABLE_VARS_PARAMS_STAT_OPTIONS):
- """Print model statistics.
-
- See go/tfprof or README for examples and tutorials.
- Run tfprof tool for help:
- 'bazel run third_party/tensorflow/tools/tfprof help'
-
- Args:
- graph: tf.Graph.
- run_meta: tensorflow::RunMetadata proto. When provided, also shows valid
- timing and memory information when 'select' option contains
- 'micros' and 'bytes'.
- op_log: tensorflow::tfprof::OpLog proto. users can use this proto to
- group together ops and use a op_type to select the group.
- tfprof_cmd: string. Either 'op', 'scope', 'graph', 'code'.
- 'op' view organize outputs using operation type. (e.g. MatMul)
- 'scope' view organize outputs using graph node name scope.
- 'graph' view organize outputs using graph node inputs/outputs.
- 'code' view organize outputs using Python call stack.
- tfprof_options: See 'tfprof help' for details.
- Returns:
- If tfprof_cmd is 'scope' or 'graph', returns TFGraphNodeProto proto.
- If tfprof_cmd is 'op' or 'code', returns TFMultiGraphNodeProto proto.
- Side effect: stdout/file/timeline.json depending on tfprof_options['output']
- """
- # pylint: disable=protected-access
- op_log = tfprof_logger._merge_default_with_oplog(
- graph, op_log, run_meta, add_trace=tfprof_cmd == 'code')
- # pylint: enable=protected-access
-
- opts = _build_options(tfprof_options)
-
- run_meta_str = run_meta.SerializeToString() if run_meta else b''
-
- if tfprof_cmd == 'code' or tfprof_cmd == 'op':
- tfprof_node = tfprof_output_pb2.TFMultiGraphNodeProto()
- tfprof_node.ParseFromString(
- print_mdl.PrintModelAnalysis(
- graph.as_graph_def(add_shapes=True).SerializeToString(),
- run_meta_str,
- op_log.SerializeToString(),
- tfprof_cmd.encode('utf-8'),
- opts.SerializeToString()))
- elif tfprof_cmd == 'graph' or tfprof_cmd == 'scope':
- tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
- tfprof_node.ParseFromString(
- print_mdl.PrintModelAnalysis(
- graph.as_graph_def(add_shapes=True).SerializeToString(),
- run_meta_str,
- op_log.SerializeToString(),
- tfprof_cmd.encode('utf-8'),
- opts.SerializeToString()))
- else:
- raise errors.InvalidArgumentError(
- None, None, 'unknown tfprof_cmd: %s\n' % tfprof_cmd)
-
- return tfprof_node
-
-
-def advise(graph, run_meta=None, tfprof_options=ALL_ADVICE): # pylint: disable=dangerous-default-value
- """Auto profile and advise.
-
- Builds profiles and automatically check anormalies of various
- aspects. See go/tfprof or README for examples and tutorials.
-
- Args:
- graph: tf.Graph.
- run_meta: tensorflow::RunMetadata proto. Allows auto-profile
- time and memroy.
- tfprof_options: see ALL_ADVICE example above.
- Returns:
- Returns AdviceProto proto
- """
- # pylint: disable=protected-access
- op_log = tfprof_logger._merge_default_with_oplog(
- graph, None, run_meta, add_trace=True)
- # pylint: enable=protected-access
-
- run_meta_str = run_meta.SerializeToString() if run_meta else b''
-
- opts = _build_advisor_options(tfprof_options)
- ret = tfprof_output_pb2.AdviceProto()
- ret.ParseFromString(
- print_mdl.PrintModelAnalysis(
- graph.as_graph_def(add_shapes=True).SerializeToString(),
- run_meta_str,
- op_log.SerializeToString(),
- 'advise'.encode('utf-8'),
- opts.SerializeToString()))
- return ret
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
deleted file mode 100644
index 32a6d5fdb2..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
+++ /dev/null
@@ -1,319 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-
-# XXX: this depends on pywrap_tensorflow and must come later
-from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
-from tensorflow.contrib.tfprof.python.tools.tfprof.internal import model_analyzer_testlib as lib
-
-
-class PrintModelAnalysisTest(test.TestCase):
-
- def testDumpToFile(self):
- ops.reset_default_graph()
- opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- outfile = os.path.join(test.get_temp_dir(), 'dump')
- opts['output'] = 'file:outfile=' + outfile
-
- with session.Session() as sess:
- _ = lib.BuildSmallModel()
- model_analyzer.print_model_analysis(sess.graph, tfprof_options=opts)
-
- with gfile.Open(outfile, 'r') as f:
- self.assertEqual(u'node name | # parameters\n'
- '_TFProfRoot (--/451 params)\n'
- ' DW (3x3x3x6, 162/162 params)\n'
- ' DW2 (2x2x6x12, 288/288 params)\n'
- ' ScalarW (1, 1/1 params)\n',
- f.read())
-
- def testSelectEverything(self):
- ops.reset_default_graph()
- opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- outfile = os.path.join(test.get_temp_dir(), 'dump')
- opts['output'] = 'file:outfile=' + outfile
- opts['account_type_regexes'] = ['.*']
- opts['select'] = [
- 'params', 'float_ops', 'occurrence', 'device', 'op_types',
- 'input_shapes'
- ]
-
- rewriter_config = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True)
- graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
- config = config_pb2.ConfigProto(graph_options=graph_options)
- with session.Session(config=config) as sess, ops.device('/cpu:0'):
- x = lib.BuildSmallModel()
-
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- model_analyzer.print_model_analysis(
- sess.graph, run_meta, tfprof_options=opts)
-
- with gfile.Open(outfile, 'r') as f:
- # pylint: disable=line-too-long
- self.assertEqual(
- 'node name | # parameters | # float_ops | assigned devices | op types | op count (run|defined) | input shapes\n_TFProfRoot (--/451 params, --/10.44k flops, _kTFScopeParent, --/7|--/35, )\n Conv2D (0/0 params, 5.83k/5.83k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 1/1|1/1, 0:2x6x6x3|1:3x3x3x6)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 1/1|1/1, 0:2x3x3x6|1:2x2x6x12)\n DW (3x3x3x6, 162/162 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:3x3x3x6|1:3x3x3x6)\n DW/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, 0/0|1/7, )\n DW/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0/0|1/6, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0/0|1/1, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 1/1|1/1, 0:3x3x3x6)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW2/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:2x2x6x12|1:2x2x6x12)\n DW2/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, 0/0|1/7, )\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0/0|1/6, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0/0|1/1, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 1/1|1/1, 0:2x2x6x12)\n ScalarW (1, 1/1 params, 0/0 flops, VariableV2|_trainable_variables, 0/0|1/10, )\n ScalarW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, 0/0|1/7, )\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0/0|1/6, 0:1|1:1)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:0)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/read (0/0 params, 0/0 flops, Identity, 0/0|1/1, 0:1)\n init (0/0 params, 0/0 flops, NoOp, 0/0|1/1, 0:1|1:3x3x3x6|2:2x2x6x12)\n zeros (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const, 1/1|1/1, )\n',
- f.read())
- # pylint: enable=line-too-long
-
- def testSimpleCodeView(self):
- ops.reset_default_graph()
- opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- outfile = os.path.join(test.get_temp_dir(), 'dump')
- opts['output'] = 'file:outfile=' + outfile
- opts['account_type_regexes'] = ['.*']
- opts['show_name_regexes'] = ['.*model_analyzer_testlib.*']
- opts['account_displayed_op_only'] = False
- # TODO(xpan): Test 'micros'. Since the execution time changes each run,
- # it's a bit difficult to test it now.
- opts['select'] = [
- 'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
- 'input_shapes'
- ]
-
- with session.Session() as sess:
- x = lib.BuildSmallModel()
-
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- model_analyzer.print_model_analysis(
- sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
-
- with gfile.Open(outfile, 'r') as f:
- # pylint: disable=line-too-long
- self.assertEqual(
- 'node name | output bytes | # parameters | # float_ops | assigned devices | input',
- f.read()[0:80])
- # pylint: enable=line-too-long
-
- def testComplexCodeView(self):
- ops.reset_default_graph()
- opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- outfile = os.path.join(test.get_temp_dir(), 'dump')
- opts['output'] = 'file:outfile=' + outfile
- opts['account_type_regexes'] = ['.*']
- opts['show_name_regexes'] = ['.*model_analyzer_testlib.py.*']
- opts['account_displayed_op_only'] = False
- opts['select'] = ['params', 'float_ops']
-
- with session.Session() as sess:
- x = lib.BuildFullModel()
-
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- tfprof_node = model_analyzer.print_model_analysis(
- sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
-
- # pylint: disable=line-too-long
- with gfile.Open(outfile, 'r') as f:
- lines = f.read().split('\n')
- result = '\n'.join([l[:min(len(l), 80)] for l in lines])
- self.assertEqual('node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/91.04k flops)\n model_analyzer_testlib.py:58:BuildFullModel:seq.append(array_... (0/1.80k para\n model_analyzer_testlib.py:35:BuildSmallModel:image = array_ops... (0/0 param\n model_analyzer_testlib.py:39:BuildSmallModel:initializer=init_... (0/4 param\n model_analyzer_testlib.py:43:BuildSmallModel:initializer=init_... (0/648 par\n model_analyzer_testlib.py:44:BuildSmallModel:x = nn_ops.conv2d... (0/0 param\n model_analyzer_testlib.py:48:BuildSmallModel:initializer=init_... (0/1.15k p\n model_analyzer_testlib.py:49:BuildSmallModel:x = nn_ops.conv2d... (0/0 param\n model_analyzer_testlib.py:62:BuildFullModel:cell, array_ops.c... (0/1.04k para\n model_analyzer_testlib.py:64:BuildFullModel:target = array_op... (0/0 params, \n model_analyzer_testlib.py:65:BuildFullModel:loss = nn_ops.l2_... (0/0 params, \n model_analyzer_testlib.py:67:BuildFullModel:return sgd_op.min... (0/0 params, \n',
- result)
-
- self.assertLess(0, tfprof_node.total_exec_micros)
- self.assertEqual(2844, tfprof_node.total_parameters)
- self.assertEqual(91040, tfprof_node.total_float_ops)
- self.assertEqual(5, len(tfprof_node.children))
- self.assertEqual('_TFProfRoot', tfprof_node.name)
- self.assertEqual(
- 'model_analyzer_testlib.py:58:BuildFullModel:seq.append(array_...',
- tfprof_node.children[0].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:62:BuildFullModel:cell, array_ops.c...',
- tfprof_node.children[1].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:64:BuildFullModel:target = array_op...',
- tfprof_node.children[2].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:65:BuildFullModel:loss = nn_ops.l2_...',
- tfprof_node.children[3].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:67:BuildFullModel:return sgd_op.min...',
- tfprof_node.children[4].name)
- # pylint: enable=line-too-long
-
- def testCodeViewLeafGraphNode(self):
- ops.reset_default_graph()
- opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- opts['account_type_regexes'] = ['.*']
- opts['account_displayed_op_only'] = False
- opts['select'] = [
- 'bytes', 'params', 'float_ops', 'device'
- ]
- opts['output'] = 'none'
-
- with session.Session() as sess:
- x = lib.BuildSmallModel()
-
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- tfprof_node = model_analyzer.print_model_analysis(
- sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
-
- leaf = tfprof_node
- while leaf.children:
- self.assertEqual(0, len(leaf.graph_nodes))
- leaf = leaf.children[0]
- self.assertEqual(1, len(leaf.graph_nodes))
-
- def testTimeline(self):
- ops.reset_default_graph()
- opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- outfile = os.path.join(test.get_temp_dir(), 'timeline')
- opts['output'] = 'timeline:outfile=' + outfile
- opts['account_type_regexes'] = ['.*']
- opts['max_depth'] = 100000
- opts['step'] = 0
-
- with session.Session() as sess:
- x = lib.BuildFullModel()
-
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(
- x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- _ = model_analyzer.print_model_analysis(
- sess.graph, run_meta, tfprof_cmd='graph', tfprof_options=opts)
-
- with gfile.Open(outfile, 'r') as f:
- # Test that a json file is created.
- # TODO(xpan): tfprof Timeline isn't quite correct on Windows.
- # Investigate why.
- if os.name != 'nt':
- self.assertLess(1000, len(f.read()))
- else:
- self.assertLess(1, len(f.read()))
-
- def testOpView(self):
- ops.reset_default_graph()
- opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- outfile = os.path.join(test.get_temp_dir(), 'dump')
- opts['output'] = 'file:outfile=' + outfile
- opts['account_type_regexes'] = ['.*']
- opts['min_occurrence'] = 10
- opts['select'] = ['params', 'micros', 'occurrence', 'input_shapes']
- opts['order_by'] = 'occurrence'
-
- with session.Session() as sess:
- x = lib.BuildFullModel()
-
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- tfprof_node = model_analyzer.print_model_analysis(
- sess.graph, run_meta, tfprof_cmd='op', tfprof_options=opts)
-
- with gfile.Open(outfile, 'r') as f:
- # pylint: disable=line-too-long
- self.assertEqual(
- 'nodename|totalexecutiontime|acceleratorexecutiontime|cpuexecutiontime|#parameters|opoccurrence(run|defined)|inputshapes\n',
- f.read().replace('\t', '').replace(' ', '')[0:120])
- # pylint: enable=line-too-long
-
- total_children = 0
- last_occurrence = 1e32
- input_shapes = 0
- last_total_micros = tfprof_node.total_exec_micros
- last_micros = tfprof_node.exec_micros
- while tfprof_node.children:
- for gnode in tfprof_node.graph_nodes:
- input_shapes += len(gnode.input_shapes)
- self.assertEqual(len(tfprof_node.children), 1)
- tfprof_node = tfprof_node.children[0]
-
- self.assertEqual(
- last_total_micros, tfprof_node.total_exec_micros + last_micros)
- last_total_micros = tfprof_node.total_exec_micros
- last_micros = tfprof_node.exec_micros
-
- total_children += 1
- self.assertLessEqual(len(tfprof_node.graph_nodes), last_occurrence)
- last_occurrence = len(tfprof_node.graph_nodes)
-
- self.assertEqual(total_children, 15)
- self.assertGreater(input_shapes, 0)
-
- def testAdvisor(self):
- ops.reset_default_graph()
-
- with session.Session() as sess:
- x = lib.BuildFullModel()
-
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(
- x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- advice_pb = model_analyzer.advise(sess.graph, run_meta)
- self.assertTrue('AcceleratorUtilizationChecker' in advice_pb.checkers)
- self.assertTrue('ExpensiveOperationChecker' in advice_pb.checkers)
- self.assertTrue('OperationChecker' in advice_pb.checkers)
-
- checker = advice_pb.checkers['AcceleratorUtilizationChecker']
- if test.is_gpu_available():
- self.assertGreater(len(checker.reports), 0)
- else:
- self.assertEqual(len(checker.reports), 0)
- checker = advice_pb.checkers['ExpensiveOperationChecker']
- self.assertGreater(len(checker.reports), 0)
-
-
-if __name__ == '__main__':
- test.main()
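For context, testAdvisor above exercises the one-shot advise call. A minimal standalone sketch of that pattern, using the pre-migration import path from the deleted test and a toy matmul graph (both chosen here purely for illustration):

import tensorflow as tf

from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer

with tf.Session() as sess:
  # Any small graph works; a matmul stands in for a real model here.
  x = tf.matmul(tf.random_normal([2, 3]), tf.random_normal([3, 4]))
  run_meta = tf.RunMetadata()
  sess.run(x,
           options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
           run_metadata=run_meta)

  advice_pb = model_analyzer.advise(sess.graph, run_meta)
  for checker in advice_pb.checkers:
    print(checker, advice_pb.checkers[checker])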
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py
deleted file mode 100644
index c57e45748d..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py
+++ /dev/null
@@ -1,445 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Profiler for TensorFlow models that outputs data in pprof format.
-
-See https://github.com/google/pprof/blob/master/proto/profile.proto for pprof
-profile format.
-The following need to be set for the profiler to work:
- * trace_level needs to be set to FULL_TRACE
- * run_metadata object should be passed in to session.run call
-
-Sample usage:
- options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
- run_metadata = tf.RunMetadata()
-
-  with tf.Session() as sess:
- ...
- sess.run(computation, run_metadata=run_metadata, options=options)
- pprof_profiler.profile(sess.graph, run_metadata, output_dir)
-
-
-  The code above would output a separate pprof profile (an output_dir/.*.pb.gz
-  file) for each device. These files can be passed to pprof for formatting.
-  For example:
- pprof -png --nodecount=100 --sample_index=1 output_dir/profile_output.pb.gz
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from collections import defaultdict
-from collections import namedtuple
-import gzip
-import os
-import string
-import sys
-import time
-
-from proto import profile_pb2
-
-
-if sys.version_info < (3,):
- maketrans = string.maketrans
-else:
- maketrans = str.maketrans
-
-
-ProfileDatum = namedtuple('ProfileDatum', [
- 'node_exec_stats', 'op_type', 'traceback'])
-
-
-class StringTable(object):
- """Keeps track of strings to add to string_table in pprof proto."""
-
- def __init__(self):
- # Pprof requires first entry in string_table to be ''.
- self._string_table = ['']
- self._string_to_index = {'': 0}
-
- def index_of(self, value_str):
- """Get index of value_str in the string table.
-
- If value_str is not in the string table, we will add it at the end
- and then return the new index.
- Args:
- value_str: (string) Value to lookup/add in/to the string table.
-
- Returns:
- Index of value_str in the string table.
- """
- if value_str is None:
- value_str = ''
- if value_str in self._string_to_index:
- return self._string_to_index[value_str]
- index = len(self._string_table)
- self._string_table.append(value_str)
- self._string_to_index[value_str] = index
- return index
-
- def next_index(self):
- """Gets index that would be assigned to the next added string.
-
- Returns:
- Index of the next string if it was added.
- """
- return len(self._string_table)
-
- def string_table(self):
- """Returns a list of strings to store in pprof's string_table."""
- return self._string_table
-
-
-class Functions(object):
- """Keeps track of `Function` protos for pprof profile."""
-
- def __init__(self, string_table):
- """Constructor.
-
- Args:
- string_table: A `StringTable` object.
- """
- self._string_table = string_table
- # Maps tuples in the form (file_path, function_name, start_line_number)
- # to `Function` protos.
- self._function_key_to_function = {}
-
- def index_of(self, file_path, function_name, function_start_line):
- """Returns index of the function, adding the function if needed.
-
- Args:
- file_path: (string) Path to file where the function is defined.
- function_name: (string) Function name.
- function_start_line: (integer) Start line number of function definition.
-
- Returns:
- Function index.
- """
- function_key = (file_path, function_name, function_start_line)
- if function_key in self._function_key_to_function:
- return self._function_key_to_function[function_key].id
- else:
- # Function indexes should start from 1
- function_index = len(self._function_key_to_function) + 1
- function = profile_pb2.Function()
- function.id = function_index
- function.name = self._string_table.index_of(function_name)
- function.filename = self._string_table.index_of(file_path)
- function.start_line = function_start_line
- self._function_key_to_function[function_key] = function
- return function_index
-
- def function_protos(self):
- """Returns list of `profile_pb2.Function` protos."""
- return self._function_key_to_function.values()
-
-
-class Locations(object):
- """Keeps track of `Location` protos for pprof profile.
-
- `Locations` store information about function call locations.
- """
-
- def __init__(self, functions):
- """Constructor.
-
- Args:
- functions: A `Functions` object.
- """
- self._functions = functions
- # Maps tuples in the form (file_path, called_function_name, line_number)
- # to `Location` protos.
- self._location_key_to_location = {}
-
- def index_of(
- self, file_path, line_number, called_function_name, called_file_path,
- called_function_start_line):
- """Returns index of the location, adding the location if needed.
-
- Args:
- file_path: (string) Path to file that makes the call.
- line_number: (integer) Call line number.
- called_function_name: (string) Function name of the function called at
- `file_path` and `line_number`.
- called_file_path: (string) Path to file where the called function is
- defined.
- called_function_start_line: (integer) Start line number of called
- function definition in `called_file_path` file.
-
- Returns:
- Index of location.
- """
- location_key = (file_path, called_function_name, line_number)
- if location_key in self._location_key_to_location:
- location = self._location_key_to_location[location_key]
- return location.id
- else:
- # Location indexes should start from 1
- location_index = len(self._location_key_to_location) + 1
- location = profile_pb2.Location()
- location.id = location_index
- self._location_key_to_location[location_key] = location
-
- line = location.line.add()
- line.function_id = self._functions.index_of(
- called_file_path, called_function_name, called_function_start_line)
- line.line = line_number
- return location_index
-
- def location_protos(self):
- """Returns list of `profile_pb2.Location` protos."""
- return self._location_key_to_location.values()
-
-
-class Samples(object):
- """Keeps track of `Sample` protos for pprof profile.
-
- Samples store the following statistics in order:
- count, all_time, op_time
- """
-
- def __init__(self, string_table):
- """Constructor.
-
- Args:
- string_table: A `StringTable` object.
- """
- self._string_table = string_table
- # TODO(annarev): figure out if location is unique for each node name.
- # If not, also key this dictionary based on location ids.
- self._node_name_to_sample = {}
-
- def add(self, datum, location_ids):
- """Adds a sample data point.
-
- Args:
- datum: `ProfileDatum` to add a sample for.
-      location_ids: List of numeric location ids for this
- sample.
- """
- node_name = datum.node_exec_stats.node_name
- if node_name in self._node_name_to_sample:
- sample = self._node_name_to_sample[node_name]
- sample.location_id.extend(location_ids)
- else:
- sample = profile_pb2.Sample()
- # Sample stores 3 values: count, all_time, op_time
- sample.value.extend([0, 0, 0])
-
- label = sample.label.add()
- label.key = self._string_table.index_of('node_name')
- label.str = self._string_table.index_of(node_name)
- label = sample.label.add()
- label.key = self._string_table.index_of('op_type')
- label.str = self._string_table.index_of(datum.op_type)
- self._node_name_to_sample[node_name] = sample
- sample.value[0] += 1
- sample.value[1] += datum.node_exec_stats.all_end_rel_micros
- sample.value[2] += (
- datum.node_exec_stats.op_end_rel_micros -
- datum.node_exec_stats.op_start_rel_micros)
-
- def get_sample_protos(self):
- """Returns list of `Sample` protos for pprof profile."""
- return self._node_name_to_sample.values()
-
-
-class PprofProfiler(object):
- """Creates profiles in pprof format."""
-
- def __init__(self, graph, run_metadata):
- """Constructor.
-
- Args:
- graph: A `Graph` instance.
-      run_metadata: A `RunMetadata` proto.
- """
- self._graph = graph
- self._run_metadata = run_metadata
- self._string_table = StringTable()
- self._functions = Functions(self._string_table)
- self._locations = Locations(self._functions)
-
- def profile(self):
- """Generates pprof profiles.
-
- Returns:
- Dictionary mapping from device name to proto in `profile_pb2.Profile`
- format.
- """
- profiles = {}
- data_generator_func = self._get_profile_data_generator()
- for device_index, device_stats in enumerate(
- self._run_metadata.step_stats.dev_stats):
- # Create profile
- pprof_proto = self._get_pprof_proto(data_generator_func(device_stats))
- if not pprof_proto.sample:
- print(
- 'Not enough data to create profile for device %s. Did you pass '
- 'RunMetadata to session.run call?' % device_stats.device)
- continue
- # Add device name comment
- device_count = len(self._run_metadata.step_stats.dev_stats)
- device_description = (
- 'Device %d of %d: %s' %
- (device_index + 1, device_count, device_stats.device))
- device_description_str_index = self._string_table.next_index()
- pprof_proto.string_table.append(device_description)
- pprof_proto.comment.append(device_description_str_index)
- profiles[device_stats.device] = pprof_proto
- return profiles
-
- def _get_pprof_proto(self, profile_datum_generator):
- """Returns profile data in pprof proto format.
-
- Args:
- profile_datum_generator: Generator outputting `ProfileDatum` objects.
-
- Returns:
- A proto in pprof format.
- """
- pprof_profile = profile_pb2.Profile()
- samples = Samples(self._string_table)
-
- for datum in profile_datum_generator:
- if not datum.traceback:
- continue
-
- stack_frame = datum.traceback[-1]
- after_apply_op = False
- location_ids = []
-
- # We add locations from stack trace in bottom-up order.
- for stack_frame_index in reversed(range(len(datum.traceback) - 1)):
- prev_stack_frame = stack_frame
- stack_frame = datum.traceback[stack_frame_index]
-
- # Call at current frame calls function at previous frame.
- prev_file_path = prev_stack_frame[0]
- prev_function = prev_stack_frame[2]
- prev_function_start_line = prev_stack_frame[4]
- curr_file_path = stack_frame[0]
- curr_line_number = stack_frame[1]
-
- # Skip all calls up to apply_op since they are the same for all ops.
- if not after_apply_op:
- if prev_function == 'apply_op':
- after_apply_op = True
- continue
- location_index = self._locations.index_of(
- curr_file_path, curr_line_number,
- prev_function, prev_file_path, prev_function_start_line)
- location_ids.append(location_index)
- samples.add(datum, location_ids)
-
- sample_type_description = 'count'
- sample_type = pprof_profile.sample_type.add()
- sample_type.type = self._string_table.index_of(sample_type_description)
- sample_type.unit = self._string_table.index_of('count')
- sample_type_description = 'all_time'
- sample_type = pprof_profile.sample_type.add()
- sample_type.type = self._string_table.index_of(sample_type_description)
- sample_type.unit = self._string_table.index_of('nanoseconds')
- sample_type_description = 'op_time'
- sample_type = pprof_profile.sample_type.add()
- sample_type.type = self._string_table.index_of(sample_type_description)
- sample_type.unit = self._string_table.index_of('nanoseconds')
-
- pprof_profile.string_table.extend(self._string_table.string_table())
- pprof_profile.sample.extend(samples.get_sample_protos())
- pprof_profile.function.extend(self._functions.function_protos())
- pprof_profile.location.extend(self._locations.location_protos())
- return pprof_profile
-
- def _get_profile_data_generator(self):
- """Get function that generates `ProfileDatum` objects.
-
- Returns:
- A function that generates `ProfileDatum` objects.
- """
- node_to_traceback = defaultdict(list)
- node_to_op_type = defaultdict(str)
- for op in self._graph.get_operations():
- node_to_traceback[op.name] = op.traceback_with_start_lines
- node_to_op_type[op.name] = op.type
-
- def profile_data_generator(device_step_stats):
- for node_stats in device_step_stats.node_stats:
- if node_stats.node_name == '_SOURCE' or node_stats.node_name == '_SINK':
- continue
- yield ProfileDatum(
- node_stats,
- node_to_op_type[node_stats.node_name],
- node_to_traceback[node_stats.node_name])
-
- return profile_data_generator
-
-
-def get_profiles(graph, run_metadata):
- """Generate profiles in pprof format.
-
- See https://github.com/google/pprof/blob/master/proto/profile.proto
- for pprof proto format.
-
- Args:
- graph: A `Graph` object.
- run_metadata: A `RunMetadata` proto.
-
- Returns:
- A dictionary mapping from device name to pprof proto for that device.
- """
- return PprofProfiler(graph, run_metadata).profile()
-
-
-def profile(graph, run_metadata, output_dir=None):
- """Generate profiles in pprof format.
-
- See https://github.com/google/pprof/blob/master/proto/profile.proto
- for pprof proto format.
-
- Args:
- graph: A `Graph` object.
- run_metadata: A `RunMetadata` proto.
- output_dir: (string) Directory to output pprof profile to.
- Profile files for each device will be stored in compressed
- serialized proto format. If output_dir is None, profile protos
- will be printed to stdout instead.
-
- Returns:
- List of output files created by this profile call.
- (Note: this list will be empty if output_dir is None)
- """
- profiles = get_profiles(graph, run_metadata)
- output_file_template = None
- if output_dir:
- if not os.path.isdir(output_dir):
- os.makedirs(output_dir)
- time_suffix = time.strftime('%Y%m%d%H%M%S')
- output_file_template = os.path.join(
- output_dir, '%s_' + time_suffix + '.pb.gz')
-
- profile_files = []
- for device, pprof_proto in profiles.items():
- if output_file_template is None:
- print('No output directory specified, printing to stdout instead.')
- print(pprof_proto)
- else:
- device_name = str(device).strip('/').translate(
- maketrans('/:', '__'))
- profile_file = output_file_template % device_name
- profile_files.append(profile_file)
- with gzip.open(profile_file, 'w') as output_file:
- print('Writing profile to %s...' % profile_file)
- output_file.write(pprof_proto.SerializeToString())
- return profile_files
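The module docstring above describes the intended workflow; a minimal runnable sketch of it follows, using the pre-migration import path of this file, a toy matmul graph, and a hypothetical output directory:

import tensorflow as tf

from tensorflow.contrib.tfprof.python.tools.tfprof import pprof_profiler

options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

with tf.Session() as sess:
  y = tf.matmul(tf.random_normal([4, 4]), tf.random_normal([4, 4]))
  sess.run(y, options=options, run_metadata=run_metadata)
  # One compressed profile proto is written per device under the output dir.
  profile_files = pprof_profiler.profile(
      sess.graph, run_metadata, '/tmp/pprof_out')  # hypothetical output dir
  print(profile_files)

The resulting .pb.gz files can then be rendered with pprof, for example
pprof -png --nodecount=100 --sample_index=1 <file>, as the docstring notes.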
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py
deleted file mode 100644
index 6487adf992..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py
+++ /dev/null
@@ -1,164 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for pprof_profiler."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-
-from proto import profile_pb2
-from tensorflow.contrib.tfprof.python.tools.tfprof import pprof_profiler
-from tensorflow.core.framework import step_stats_pb2
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.framework import constant_op
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class PprofProfilerTest(test.TestCase):
-
- def testDataEmpty(self):
- output_dir = test.get_temp_dir()
- run_metadata = config_pb2.RunMetadata()
- graph = test.mock.MagicMock()
- graph.get_operations.return_value = []
-
- profiles = pprof_profiler.get_profiles(graph, run_metadata)
- self.assertEquals(0, len(profiles))
- profile_files = pprof_profiler.profile(
- graph, run_metadata, output_dir)
- self.assertEquals(0, len(profile_files))
-
- def testRunMetadataEmpty(self):
- output_dir = test.get_temp_dir()
- run_metadata = config_pb2.RunMetadata()
- graph = test.mock.MagicMock()
- op1 = test.mock.MagicMock()
- op1.name = 'Add/123'
- op1.traceback = [('a/b/file1', 10, 'some_var')]
- op1.type = 'add'
- graph.get_operations.return_value = [op1]
-
- profiles = pprof_profiler.get_profiles(graph, run_metadata)
- self.assertEquals(0, len(profiles))
- profile_files = pprof_profiler.profile(
- graph, run_metadata, output_dir)
- self.assertEquals(0, len(profile_files))
-
- def testValidProfile(self):
- output_dir = test.get_temp_dir()
- run_metadata = config_pb2.RunMetadata()
-
- node1 = step_stats_pb2.NodeExecStats(
- node_name='Add/123',
- op_start_rel_micros=3,
- op_end_rel_micros=5,
- all_end_rel_micros=4)
-
- run_metadata = config_pb2.RunMetadata()
- device1 = run_metadata.step_stats.dev_stats.add()
- device1.device = 'deviceA'
- device1.node_stats.extend([node1])
-
- graph = test.mock.MagicMock()
- op1 = test.mock.MagicMock()
- op1.name = 'Add/123'
- op1.traceback = [
- ('a/b/file1', 10, 'apply_op', 'abc'), ('a/c/file2', 12, 'my_op', 'def')]
- op1.type = 'add'
- graph.get_operations.return_value = [op1]
-
- expected_proto = """sample_type {
- type: 5
- unit: 5
-}
-sample_type {
- type: 6
- unit: 7
-}
-sample_type {
- type: 8
- unit: 7
-}
-sample {
- value: 1
- value: 4
- value: 2
- label {
- key: 1
- str: 2
- }
- label {
- key: 3
- str: 4
- }
-}
-string_table: ""
-string_table: "node_name"
-string_table: "Add/123"
-string_table: "op_type"
-string_table: "add"
-string_table: "count"
-string_table: "all_time"
-string_table: "nanoseconds"
-string_table: "op_time"
-string_table: "Device 1 of 1: deviceA"
-comment: 9
-"""
- # Test with protos
- profiles = pprof_profiler.get_profiles(graph, run_metadata)
- self.assertEquals(1, len(profiles))
- self.assertTrue('deviceA' in profiles)
- self.assertEquals(expected_proto, str(profiles['deviceA']))
- # Test with files
- profile_files = pprof_profiler.profile(
- graph, run_metadata, output_dir)
- self.assertEquals(1, len(profile_files))
- with gzip.open(profile_files[0]) as profile_file:
- profile_contents = profile_file.read()
- profile = profile_pb2.Profile()
- profile.ParseFromString(profile_contents)
- self.assertEquals(expected_proto, str(profile))
-
- def testProfileWithWhileLoop(self):
- options = config_pb2.RunOptions()
- options.trace_level = config_pb2.RunOptions.FULL_TRACE
- run_metadata = config_pb2.RunMetadata()
-
- num_iters = 5
- with self.test_session() as sess:
- i = constant_op.constant(0)
- c = lambda i: math_ops.less(i, num_iters)
- b = lambda i: math_ops.add(i, 1)
- r = control_flow_ops.while_loop(c, b, [i])
- sess.run(r, options=options, run_metadata=run_metadata)
- profiles = pprof_profiler.get_profiles(sess.graph, run_metadata)
- self.assertEquals(1, len(profiles))
- profile = next(iter(profiles.values()))
- add_samples = [] # Samples for the while/Add node
- for sample in profile.sample:
- if profile.string_table[sample.label[0].str] == 'while/Add':
- add_samples.append(sample)
-      # Values for the same node are aggregated.
-      self.assertEquals(1, len(add_samples))
-      # Value of "count" should be equal to the number of iterations.
- self.assertEquals(num_iters, add_samples[0].value[0])
-
-
-if __name__ == '__main__':
- test.main()
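testValidProfile above round-trips a written profile through gzip and profile_pb2; the same read path in isolation looks roughly like the helper below (read_pprof_profile is a hypothetical name, and the path argument is whatever pprof_profiler.profile() returned):

import gzip

from proto import profile_pb2  # pprof proto bindings, same import as the test above


def read_pprof_profile(path):
  # Parses a compressed pprof profile written by pprof_profiler.profile().
  with gzip.open(path) as f:
    profile = profile_pb2.Profile()
    profile.ParseFromString(f.read())
  return profile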
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
deleted file mode 100644
index dd25340564..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-
-# pylint: disable=g-bad-import-order
-from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
-from tensorflow.contrib.tfprof.python.tools.tfprof.internal import model_analyzer_testlib as lib
-
-
-class ProfilerTest(test.TestCase):
-
- def testProfileBasic(self):
- ops.reset_default_graph()
- opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- opts['account_type_regexes'] = ['.*']
- opts['select'] = ['params', 'float_ops', 'micros', 'bytes',
- 'device', 'op_types', 'occurrence']
- outfile = os.path.join(test.get_temp_dir(), 'dump')
- opts['output'] = 'file:outfile=' + outfile
-
- # Test the output without run_meta.
- sess = session.Session()
- r = lib.BuildFullModel()
- sess.run(variables.global_variables_initializer())
-
- profiler = model_analyzer.Profiler(sess.graph)
- profiler.profile_name_scope(opts)
- with gfile.Open(outfile, 'r') as f:
- profiler_str = f.read()
-
- model_analyzer.print_model_analysis(
- sess.graph, tfprof_cmd='scope', tfprof_options=opts)
- with gfile.Open(outfile, 'r') as f:
- pma_str = f.read()
- self.assertEqual(pma_str, profiler_str)
-
- # Test the output with run_meta.
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(r,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- profiler.add_step(1, run_meta)
- profiler.profile_graph(opts)
- with gfile.Open(outfile, 'r') as f:
- profiler_str = f.read()
-
- model_analyzer.print_model_analysis(
- sess.graph, tfprof_cmd='graph', run_meta=run_meta, tfprof_options=opts)
- with gfile.Open(outfile, 'r') as f:
- pma_str = f.read()
- self.assertEqual(pma_str, profiler_str)
-
- profiler.profile_python_codes(opts)
- with gfile.Open(outfile, 'r') as f:
- profiler_str = f.read()
-
- model_analyzer.print_model_analysis(
- sess.graph, tfprof_cmd='code', run_meta=run_meta, tfprof_options=opts)
- with gfile.Open(outfile, 'r') as f:
- pma_str = f.read()
- self.assertEqual(pma_str, profiler_str)
-
- profiler.profile_operations(opts)
- with gfile.Open(outfile, 'r') as f:
- profiler_str = f.read()
-
- model_analyzer.print_model_analysis(
- sess.graph, tfprof_cmd='op', run_meta=run_meta, tfprof_options=opts)
- with gfile.Open(outfile, 'r') as f:
- pma_str = f.read()
- self.assertEqual(pma_str, profiler_str)
-
- model_analyzer.print_model_analysis(
- sess.graph, tfprof_cmd='scope', run_meta=run_meta, tfprof_options=opts)
- with gfile.Open(outfile, 'r') as f:
- pma_str = f.read()
- self.assertNotEqual(pma_str, profiler_str)
-
- opts2 = opts.copy()
- opts2['select'] = ['params', 'float_ops']
- profiler.profile_name_scope(opts2)
- with gfile.Open(outfile, 'r') as f:
- profiler_str = f.read()
-
- model_analyzer.print_model_analysis(
- sess.graph, tfprof_cmd='scope', run_meta=run_meta, tfprof_options=opts2)
- with gfile.Open(outfile, 'r') as f:
- pma_str = f.read()
- self.assertEqual(pma_str, profiler_str)
-
- def testMultiStepProfile(self):
- ops.reset_default_graph()
- opts = model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
- opts['account_type_regexes'] = ['.*']
-
- with session.Session() as sess:
- r1, r2, r3 = lib.BuildSplitableModel()
- sess.run(variables.global_variables_initializer())
-
- profiler = model_analyzer.Profiler(sess.graph)
- pb0 = profiler.profile_name_scope(opts)
-
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(r1,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
- profiler.add_step(1, run_meta)
- pb1 = profiler.profile_name_scope(opts)
-
- self.assertNotEqual(lib.SearchTFProfNode(pb1, 'DW'), None)
- self.assertEqual(lib.SearchTFProfNode(pb1, 'DW2'), None)
- self.assertEqual(lib.SearchTFProfNode(pb1, 'add'), None)
-
- run_meta2 = config_pb2.RunMetadata()
- _ = sess.run(r2,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta2)
- profiler.add_step(2, run_meta2)
- pb2 = profiler.profile_name_scope(opts)
-
- self.assertNotEqual(lib.SearchTFProfNode(pb2, 'DW'), None)
- self.assertNotEqual(lib.SearchTFProfNode(pb2, 'DW2'), None)
- self.assertEqual(lib.SearchTFProfNode(pb2, 'add'), None)
-
- run_meta3 = config_pb2.RunMetadata()
- _ = sess.run(r3,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta3)
- profiler.add_step(3, run_meta3)
- pb3 = profiler.profile_name_scope(opts)
-
- self.assertNotEqual(lib.SearchTFProfNode(pb3, 'DW'), None)
- self.assertNotEqual(lib.SearchTFProfNode(pb3, 'DW2'), None)
- self.assertNotEqual(lib.SearchTFProfNode(pb3, 'add'), None)
-
- self.assertEqual(lib.SearchTFProfNode(pb0, 'Conv2D'), None)
- self.assertGreater(lib.SearchTFProfNode(pb1, 'Conv2D').exec_micros, 0)
- self.assertEqual(lib.SearchTFProfNode(pb1, 'Conv2D_1'), None)
- self.assertGreater(lib.SearchTFProfNode(pb2, 'Conv2D_1').exec_micros, 0)
- self.assertEqual(lib.SearchTFProfNode(pb2, 'add'), None)
- self.assertGreater(lib.SearchTFProfNode(pb3, 'add').exec_micros, 0)
-
- advice_pb = profiler.advise(model_analyzer.ALL_ADVICE)
- self.assertTrue('AcceleratorUtilizationChecker' in advice_pb.checkers)
- self.assertTrue('ExpensiveOperationChecker' in advice_pb.checkers)
- self.assertTrue('OperationChecker' in advice_pb.checkers)
-
- checker = advice_pb.checkers['AcceleratorUtilizationChecker']
- if test.is_gpu_available():
- self.assertGreater(len(checker.reports), 0)
- else:
- self.assertEqual(len(checker.reports), 0)
- checker = advice_pb.checkers['ExpensiveOperationChecker']
- self.assertGreater(len(checker.reports), 0)
-
-
-if __name__ == '__main__':
- test.main()
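testMultiStepProfile above is the core of the multi-step Profiler flow; a condensed sketch of it, again with the pre-migration import path and a toy variable/matmul graph standing in for a real model:

import tensorflow as tf

from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer

opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
opts['account_type_regexes'] = ['.*']

with tf.Session() as sess:
  w = tf.get_variable('w', [3, 3])
  y = tf.matmul(tf.random_normal([2, 3]), w)
  sess.run(tf.global_variables_initializer())

  profiler = model_analyzer.Profiler(sess.graph)
  for step in range(1, 4):
    run_meta = tf.RunMetadata()
    sess.run(y,
             options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
             run_metadata=run_meta)
    profiler.add_step(step, run_meta)   # accumulate stats for this step
    profiler.profile_name_scope(opts)   # scope view over the steps seen so far

  advice_pb = profiler.advise(model_analyzer.ALL_ADVICE)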
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
deleted file mode 100644
index 52febef26c..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
+++ /dev/null
@@ -1,187 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Logging tensorflow::tfprof::OpLog.
-
-OpLog is used to add extra model information for offline analysis by tfprof.
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import sys
-
-import six
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.platform import gfile
-from tensorflow.tools.tfprof import tfprof_log_pb2
-
-TRAINABLE_VARIABLES = '_trainable_variables'
-REGISTERED_FLOP_STATS = 'flops'
-
-
-def _fill_missing_graph_shape(graph, run_meta):
- """Fill Tensor shapes in 'graph' with run time shape from 'run_meta'."""
- for dev_stat in run_meta.step_stats.dev_stats:
- for node_stat in dev_stat.node_stats:
- if not node_stat.output:
- continue
- try:
- op = graph.get_operation_by_name(node_stat.node_name)
- except KeyError as e:
-        # Graph doesn't contain the node_stat, usually RecvTensor.
- continue
- if len(node_stat.output) != len(op.outputs):
- # For example, conditional op has only 1 output at run time.
- continue
- for (i, node_stat_out) in enumerate(node_stat.output):
- if op.outputs[i].get_shape().is_fully_defined():
- continue
- node_stat_dims = node_stat_out.tensor_description.shape.dim
- node_stat_shape = tensor_shape.TensorShape(
- [d.size for d in node_stat_dims])
- try:
- op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with(
- node_stat_shape))
- except ValueError as e:
- sys.stderr.write('Node %s incompatible shapes: %s.\n' %
- (node_stat.node_name, e))
- return graph
-
-
-def _get_logged_ops(graph, run_meta=None, add_trace=True,
- add_trainable_var=True):
- """Extract trainable model parameters and FLOPs for ops from a Graph.
-
- Args:
- graph: tf.Graph.
- run_meta: RunMetadata proto used to complete shape information.
- add_trace: Whether to add op trace information.
-    add_trainable_var: Whether to assign ops in tf.trainable_variables() the
-      op type '_trainable_variables'.
- Returns:
- logged_ops: dict mapping from op_name to OpLogEntry.
- """
- if run_meta:
- graph = _fill_missing_graph_shape(graph, run_meta)
-
- op_missing_shape = 0
- logged_ops = {}
- # TODO(xpan): Work with Profiler more efficiently.
- for op in graph.get_operations():
- try:
- stats = ops.get_stats_for_node_def(
- graph, op.node_def, REGISTERED_FLOP_STATS)
- except ValueError:
-      # Catch the exception when the shape is incomplete. Skip it.
- op_missing_shape += 1
- stats = None
-
- entry = tfprof_log_pb2.OpLogEntry()
- entry.name = op.name
- add_entry = False
- if stats and stats.value:
- entry.float_ops = int(stats.value)
- add_entry = True
-
- if add_trace:
- for tb in op.traceback:
- trace = entry.code_def.traces.add()
- trace.file = tb[0] if tb[0] else 'none'
- trace.lineno = tb[1] if tb[1] else -1
- trace.function = tb[2] if tb[2] else 'none'
- trace.line = tb[3] if tb[3] else 'none'
- add_entry = True
-
- if add_entry:
- logged_ops[entry.name] = entry
-
- if add_trainable_var:
- for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
- if v.op.name not in logged_ops:
- entry = tfprof_log_pb2.OpLogEntry()
- entry.name = v.op.name
- entry.types.append(TRAINABLE_VARIABLES)
- logged_ops[entry.name] = entry
- else:
- logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
-
- if op_missing_shape > 0 and not run_meta:
-    sys.stderr.write('%d ops without flops stats due to incomplete shapes.\n' %
- op_missing_shape)
- return logged_ops
-
-
-def _merge_default_with_oplog(graph, op_log=None, run_meta=None,
- add_trace=True, add_trainable_var=True):
- """Merge the tfprof default extra info with caller's op_log.
-
- Args:
- graph: tf.Graph.
- op_log: OpLog proto.
- run_meta: RunMetadata proto used to complete shape information.
- add_trace: Whether to add op trace information.
-    add_trainable_var: Whether to assign ops in tf.trainable_variables() the
-      op type '_trainable_variables'.
- Returns:
- tmp_op_log: Merged OpLog proto.
- """
- tmp_op_log = tfprof_log_pb2.OpLog()
- logged_ops = _get_logged_ops(
- graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var)
-
- if not op_log:
- tmp_op_log.log_entries.extend(logged_ops.values())
- else:
- all_ops = dict()
- for entry in op_log.log_entries:
- all_ops[entry.name] = entry
- for op_name, entry in six.iteritems(logged_ops):
- if op_name in all_ops:
- all_ops[op_name].types.extend(entry.types)
- if entry.float_ops > 0 and all_ops[op_name].float_ops == 0:
- all_ops[op_name].float_ops = entry.float_ops
- if entry.code_def.traces and not all_ops[op_name].code_def.traces:
- all_ops[op_name].code_def.MergeFrom(entry.code_def)
- else:
- all_ops[op_name] = entry
- tmp_op_log.log_entries.extend(all_ops.values())
- return tmp_op_log
-
-
-def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
- """Log provided 'op_log', and add additional model information below.
-
- The API also assigns ops in tf.trainable_variables() an op type called
- '_trainable_variables'.
- The API also logs 'flops' statistics for ops with op.RegisterStatistics()
- defined. flops calculation depends on Tensor shapes defined in 'graph',
-  which might not be complete. 'run_meta', if provided, completes the shape
-  information on a best-effort basis.
-
- Args:
- graph: tf.Graph.
- log_dir: directory to write the log file.
-    op_log: (Optional) OpLog proto to be written. If not provided, a new
- one is created.
- run_meta: (Optional) RunMetadata proto that helps flops computation using
- run time shape information.
- add_trace: Whether to add op trace information. Used to support "code" view.
- """
- op_log = _merge_default_with_oplog(graph, op_log, run_meta, add_trace)
-
- with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
- log.write(op_log.SerializeToString())
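write_op_log above is the entry point for offline analysis; a minimal call, using the pre-migration import path of this file, a toy graph, and a hypothetical log directory:

import tensorflow as tf

from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger

with tf.Session() as sess:
  y = tf.matmul(tf.random_normal([2, 2]), tf.random_normal([2, 2]))
  run_meta = tf.RunMetadata()
  sess.run(y,
           options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
           run_metadata=run_meta)
  # Writes <log_dir>/tfprof_log with flops and trainable-variable annotations.
  tfprof_logger.write_op_log(sess.graph, '/tmp/tfprof_logdir',  # hypothetical dir
                             run_meta=run_meta)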
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py
deleted file mode 100644
index 87dfdc0fc1..0000000000
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.copy_graph.python.util import copy_elements
-from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class TFProfLoggerTest(test.TestCase):
-
- def _BuildSmallPlaceholderlModel(self):
- a = array_ops.placeholder(dtypes.int32, [2, 2])
- b = array_ops.placeholder(dtypes.int32, [2, 2])
- y = math_ops.matmul(a, b)
- return a, b, y
-
- def _BuildSmallModel(self):
- a = constant_op.constant([[1, 2], [3, 4]])
- b = constant_op.constant([[1, 2], [3, 4]])
- return math_ops.matmul(a, b)
-
- def testFillMissingShape(self):
- a, b, y = self._BuildSmallPlaceholderlModel()
- run_options = config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE)
- run_metadata = config_pb2.RunMetadata()
- sess = session.Session()
- sess.run(y,
- options=run_options,
- run_metadata=run_metadata,
- feed_dict={a: [[1, 2], [2, 3]],
- b: [[1, 2], [2, 3]]})
-
- graph2 = ops.Graph()
- # Use copy_op_to_graph to remove shape information.
- y2 = copy_elements.copy_op_to_graph(y, graph2, [])
- self.assertEquals('<unknown>', str(y2.get_shape()))
-
- tfprof_logger._fill_missing_graph_shape(graph2, run_metadata)
- self.assertEquals('(2, 2)', str(y2.get_shape()))
-
- def testFailedFillMissingShape(self):
- y = self._BuildSmallModel()
- run_options = config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE)
- run_metadata = config_pb2.RunMetadata()
- sess = session.Session()
- sess.run(y, options=run_options, run_metadata=run_metadata)
-
- graph2 = ops.Graph()
- y2 = copy_elements.copy_op_to_graph(y, graph2, [])
- self.assertEquals('<unknown>', str(y2.get_shape()))
-    # run_metadata has a special name for MatMul, hence filling the shape fails.
- tfprof_logger._fill_missing_graph_shape(graph2, run_metadata)
- self.assertEquals('<unknown>', str(y2.get_shape()))
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/tfprof/tfprof_logger.py b/tensorflow/contrib/tfprof/tfprof_logger.py
new file mode 100644
index 0000000000..9588bd2985
--- /dev/null
+++ b/tensorflow/contrib/tfprof/tfprof_logger.py
@@ -0,0 +1,24 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Logging tensorflow::tfprof::OpLog.
+
+OpLog is used to add extra model information for offline analysis by tfprof.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from tensorflow.python.profiler.tfprof_logger import write_op_log
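With this shim in place, existing contrib call sites keep working while the implementation lives under tensorflow/python/profiler; a minimal sketch of such a call site (the log directory is a placeholder):

import tensorflow as tf

from tensorflow.contrib.tfprof import tfprof_logger

# Resolves to tensorflow.python.profiler.tfprof_logger.write_op_log.
tfprof_logger.write_op_log(tf.get_default_graph(), '/tmp/tfprof_logdir')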