author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-06-26 12:54:12 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-06-26 12:57:46 -0700
commit: f3c89936e97c99dead1ca3310246691c1b221adf
tree: 3c99b66936ed59028b32609115a239f52798907d /tensorflow/contrib/tfprof
parent: 0b9b09a8531004b44b133a52c3fcc67bc6759bd8
Merge changes from github.
END_PUBLIC
Note: this CL will break builds. cl/159887762 to follow to fix all the breakages.

---
Commit 2336cdf7f authored by Maxwell Paul Brickner<mbrickn@users.noreply.github.com>
Committed by gunan<gunan@google.com>:
Updated link to use HTTPS (#10998)

Howdy! I just updated a link to use https instead of http. Thanks!

---
Commit ad0892df1 authored by Luke Iwanski<luke@codeplay.com>
Committed by Luke Iwanski<luke@codeplay.com>:
[OpenCL] Fixes run_metadata_test for SYCL

This test is designed to test CUDA specific behavior

---
Commit 6b37a0725 authored by Todd Wang<toddwang@gmail.com>
Committed by GitHub<noreply@github.com>:
Update comments

---
Commit 1699d904a authored by John Lawson<john@codeplay.com>
Committed by Luke Iwanski<luke@codeplay.com>:
[OpenCL] Fixes CUDA specific test run on SYCL (#56)

The testBadParentValuesOnGPU should only be run on CUDA devices, as the test checks for particular CUDA behaviour. We don't actually provide a SYCL kernel for GatherTree and so it's not a problem that the tests don't target SYCL.

---
Commit 3c1946230 authored by myPrecious<Moriadry@users.noreply.github.com>
Committed by Shanqing Cai<cais@google.com>:
Java API to get the size of specified input list of operations. (#10865)

* Java API to get the size of specified input list of operations
* remove unnecessary explain to avoid bring a new term to users.

---
Commit e911c7480 authored by Luke Iwanski<luke@codeplay.com>
Committed by Luke Iwanski<luke@codeplay.com>:
[OpenCL] REGISTER -> REGISTER6

---
Commit fbf6c4cec authored by superryanguo<superryanguo@gmail.com>
Committed by superryanguo<superryanguo@gmail.com>:
Simplify the Quickstart section with the weblink is better

---
Commit 72e2918cc authored by Taehoon Lee<taehoonlee@snu.ac.kr>
Committed by Taehoon Lee<taehoonlee@snu.ac.kr>:
Fix typos

---
Commit 90c4406b7 authored by Rishabh Patel<patelrishabh@users.noreply.github.com>
Committed by GitHub<noreply@github.com>:
Correct the learning rate as per the code snippet

---
Commit 03da61134 authored by Todd Wang<toddwang@gmail.com>
Committed by GitHub<noreply@github.com>:
Update ir_array.cc

---
Commit 2df6cd3ac authored by Todd Wang<toddwang@gmail.com>
Committed by GitHub<noreply@github.com>:
Another try

---
Commit af0cbace1 authored by Luke Iwanski<luke@codeplay.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
[OpenCL] Transpose to go through Eigen (#10321)

---
Commit fc7361081 authored by Luke Iwanski<luke@codeplay.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
[OpenCL] Registers RGBToHSV and HSVToRGB (#91) (#10848)

* [OpenCL] Added RGBToHSV and HSVToRGB
* Aligning '\'

---
Commit 832894ef8 authored by Luke Iwanski<luke@codeplay.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
[OpenCL] Registers AdjustContrastv2 (#10949)

* [OpenCL] Registers AdjustContrastv2 (#93)
* [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL (#96)
* [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL
* simplified to #ifndef
* Changed to "#if GOOGLE_CUDA"
* Update adjust_contrast_op_benchmark_test.cc
* Added comments

---
Commit cb4c2f8d1 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Make TransferBufferToInFeed not virual so it compiles.

---
Commit e89f04d80 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Fix calling Literal member functions.

---
Commit 15a8df724 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Fix mac build

clone from meheff's change:
[XLA] Change return type of DeviceAssignment::Deserialize to fix build breakage on mac.
The mac build had the following error:

error: incomplete type 'xla::DeviceAssignment' used in type trait expression

This was due to a static method returning a StatusOr<DeviceAssignment> inside of the definition of DeviceAssignment.

---
Commit a54d43fa4 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Replace LiteralUtil to Literal in compiler/plugin/executor

---
Commit 88a6bb80c authored by Guenther Schmuelling<guschmue@microsoft.com>
Committed by Guenther Schmuelling<guschmue@microsoft.com>:
expand inline for debug builds to limit number of symbols

---
Commit 62fb49d31 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Fix visibility error for contrib/remote_fused_graph/pylib/BUILD.

---
Commit 4c75252f2 authored by Mark Neumann<markn@allenai.org>
Committed by Mark Neumann<markn@allenai.org>:
fix initial test values to avoid numerical instability

---
Commit b58d98353 authored by sj6077<epik03sj@gmail.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
Fixes of AutoParallel bug (#10368)

* Fix the bug that auto_parallel could replicate variable snapshot name
* Use NodeName in grappler:utils instead of substr, convert variables->variable_def of grappler item
* remove variable_def from grappler item, exclude snapshot nodes from dont_replicate_nodes in auto_parallel

---
Commit a286b7db8 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Make debug_test slice integer.

---
Commit 97fcfdfa6 authored by Toby Boyd<tobyboyd@google.com>
Committed by GitHub<noreply@github.com>:
Fixed path to seq2seq.py and minor formatting

---
Commit 63c1befb8 authored by Anish Shah<shah.anish07@gmail.com>
Committed by Anish Shah<shah.anish07@gmail.com>:
Improve docs for tf.nn.depthwise_conv2d_native

---
Commit 8d42202b2 authored by Yong Tang<yong.tang.github@outlook.com>
Committed by Yong Tang<yong.tang.github@outlook.com>:
Fix mismatched delete in mkl_tfconv_op.cc

This fix fixes mismatched new[]-delete in mkl_tfconv_op.cc

(the file went through clang-format so there are some additional changes)

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

---
Commit 26301bd55 authored by Danny Goodman<goodman.danny@gmail.com>
Committed by Danny Goodman<goodman.danny@gmail.com>:
fix error format

---
Commit b3f33ad46 authored by Yao Zhang<yaozhang@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Make changes to prepare for the fused option of batch norm to be set to None (None means using fused batch norm if possible).

PiperOrigin-RevId: 159649743

---
Commit a4a469832 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Add tests for select ops and while loops that produce tuples that contain predicates.

PiperOrigin-RevId: 159645900

---
Commit 980d3f2be authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Use C API to implement Operation.name property

This name property is used in many existing tests including those that already run with C API enabled (math_ops_test, framework_ops_test, session_test, session_partial_run_test, math_ops_test_gpu, etc).

PiperOrigin-RevId: 159645767

---
Commit 26239c706 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Previously we didn't have an implementation of BatchNormInference and BatchNormTraining, which gives a linker error if anyone ever tries to call that. A dummy implementation is friendlier than a linker error.

PiperOrigin-RevId: 159645612

---
Commit f671c5caa authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
BEGIN_PUBLIC
Automated g4 rollback of changelist 159570549

PiperOrigin-RevId: 160182040
Diffstat (limited to 'tensorflow/contrib/tfprof')
-rw-r--r-- tensorflow/contrib/tfprof/README.md | 23
-rw-r--r-- tensorflow/contrib/tfprof/python/tools/tfprof/internal/pywrap_tensorflow_print_model_analysis.i | 1
-rw-r--r-- tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py | 2
-rw-r--r-- tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py | 121
-rw-r--r-- tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py | 30
-rw-r--r-- tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py | 16
6 files changed, 53 insertions, 140 deletions
diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md
index 824ba4c09b..4fa1ccea69 100644
--- a/tensorflow/contrib/tfprof/README.md
+++ b/tensorflow/contrib/tfprof/README.md
@@ -1,3 +1,26 @@
# tfprof: TensorFlow Profiler and Beyond
# Full Document in tensorflow/tools/tfprof/README.md
+
+Author: Xin Pan (xpan@google.com, github: panyx0718), Jon Shlens, Yao Zhang
+
+Consultants: Jon Shlens, Pete Warden
+
+###Major Features
+
+1. Measure model parameters, float operations, tensor shapes.
+2. Profile op execution times, requested memory size and device placement.
+3. Inspect checkpoint tensors' shapes and their values.
+4. Selectively group, filter, account and order ops.
+
+####tfprof supports 3 views to organize TensorFlow model profiles
+
+ * code view: Stats are associated your Python codes and organized as call stacks.
+ * scope view: Stats are organized as name scope hierarchies.
+ * graph view: Stats are organized as Tensorflow Op graph.
+
+####For each view, there are 3 ways to display outputs:
+
+ * stdout: Results are written to stdout.
+ * timeline: Visualized in chrome browser as time series.
+ * file: Results are dumped to file.
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/pywrap_tensorflow_print_model_analysis.i b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/pywrap_tensorflow_print_model_analysis.i
index 582c36e339..40f29ae8a2 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/pywrap_tensorflow_print_model_analysis.i
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/pywrap_tensorflow_print_model_analysis.i
@@ -43,6 +43,7 @@ using tensorflow::int64;
%unignore tensorflow::tfprof::DeleteProfiler;
%unignore tensorflow::tfprof::AddStep;
%unignore tensorflow::tfprof::Profile;
+%unignore tensorflow::tfprof::Advise;
%include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py
index 9c59df3117..71468dde37 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py
@@ -89,7 +89,7 @@ def _run_loop_model():
class RunMetadataTest(test.TestCase):
def testGPU(self):
- if not test.is_gpu_available():
+ if not test.is_gpu_available(cuda_only=True):
return
ops.reset_default_graph()
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
index c781d2af4e..419beac0b9 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
@@ -20,8 +20,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import six
-
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
from tensorflow.contrib.tfprof.python.tools.tfprof.internal import pywrap_tensorflow_print_model_analysis_lib as print_mdl
from tensorflow.python.framework import errors
@@ -110,77 +108,49 @@ PRINT_ALL_TIMING_MEMORY = {
'dump_to_file': ''
}
-# The following options are for 'advise' tfprof_cmd.
-# Show all advice.
-ALL_ADVICE = {
- 'ExpensiveOperationChecker': {},
- 'AcceleratorUtilizationChecker': {},
- 'JobChecker': {}, # Only available internally.
- 'OperationChecker': {},
-}
-
# pylint: enable=bad-whitespace
# pylint: enable=bad-continuation
-def _build_options(options):
+def _build_options(tfprof_options):
"""Build tfprof.OptionsProto.
Args:
- options: A dictionary of options.
+ tfprof_options: A dictionary of options.
Returns:
tfprof.OptionsProto.
"""
opts = tfprof_options_pb2.OptionsProto()
- opts.max_depth = options.get('max_depth', 10)
- opts.min_bytes = options.get('min_bytes', 0)
- opts.min_micros = options.get('min_micros', 0)
- opts.min_params = options.get('min_params', 0)
- opts.min_float_ops = options.get('min_float_ops', 0)
- opts.min_occurrence = options.get('min_occurrence', 0)
+ opts.max_depth = tfprof_options.get('max_depth', 10)
+ opts.min_bytes = tfprof_options.get('min_bytes', 0)
+ opts.min_micros = tfprof_options.get('min_micros', 0)
+ opts.min_params = tfprof_options.get('min_params', 0)
+ opts.min_float_ops = tfprof_options.get('min_float_ops', 0)
+ opts.min_occurrence = tfprof_options.get('min_occurrence', 0)
- opts.step = options.get('step', -1)
+ opts.step = tfprof_options.get('step', -1)
- opts.order_by = options.get('order_by', 'name')
+ opts.order_by = tfprof_options.get('order_by', 'name')
- for p in options.get('account_type_regexes', []):
+ for p in tfprof_options.get('account_type_regexes', []):
opts.account_type_regexes.append(p)
- for p in options.get('start_name_regexes', []):
+ for p in tfprof_options.get('start_name_regexes', []):
opts.start_name_regexes.append(p)
- for p in options.get('trim_name_regexes', []):
+ for p in tfprof_options.get('trim_name_regexes', []):
opts.trim_name_regexes.append(p)
- for p in options.get('show_name_regexes', []):
+ for p in tfprof_options.get('show_name_regexes', []):
opts.show_name_regexes.append(p)
- for p in options.get('hide_name_regexes', []):
+ for p in tfprof_options.get('hide_name_regexes', []):
opts.hide_name_regexes.append(p)
- opts.account_displayed_op_only = options.get('account_displayed_op_only',
- False)
+ opts.account_displayed_op_only = tfprof_options.get(
+ 'account_displayed_op_only', False)
- for p in options.get('select', []):
+ for p in tfprof_options.get('select', []):
opts.select.append(p)
- opts.output = options.get('output', 'stdout')
- opts.dump_to_file = options.get('dump_to_file', '')
-
- return opts
-
-
-def _build_advisor_options(options):
- """Build tfprof.AdvisorOptionsProto.
+ opts.output = tfprof_options.get('output', 'stdout')
+ opts.dump_to_file = tfprof_options.get('dump_to_file', '')
- Args:
- options: A dictionary of options. See ALL_ADVICE example.
- Returns:
- tfprof.AdvisorOptionsProto.
- """
- opts = tfprof_options_pb2.AdvisorOptionsProto()
- if options is None:
- return opts
- for checker, checker_opts in six.iteritems(options):
- checker_ops_pb = tfprof_options_pb2.AdvisorOptionsProto.CheckerOption()
- for k, v in six.iteritems(checker_opts):
- checker_ops_pb[k] = v
- opts.checkers[checker].MergeFrom(checker_ops_pb)
return opts
@@ -220,7 +190,7 @@ class Profiler(object):
else:
_ = sess.run(...)
# Auto detect problems and generate advice.
- profiler.advise(model_analyzer.ALL_ADVICE)
+ profiler.advise()
"""
def __init__(self, graph, op_log=None):
@@ -318,19 +288,9 @@ class Profiler(object):
print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString()))
return tfprof_node
- def advise(self, options=ALL_ADVICE): # pylint: disable=dangerous-default-value
- """Automatically detect problems and generate reports.
-
- Args:
- options: A dict of options.
- Returns:
- A Advise proto that conains the reports from all checkers.
- """
- advise_pb = tfprof_output_pb2.AdviceProto()
- opts = _build_advisor_options(options)
- advise_pb.ParseFromString(
- print_mdl.Profile('advise'.encode('utf-8'), opts.SerializeToString()))
- return advise_pb
+ def advise(self):
+ """Automatically detect problems and generate reports."""
+ print_mdl.Advise()
def print_model_analysis(graph,
@@ -394,36 +354,3 @@ def print_model_analysis(graph,
None, None, 'unknown tfprof_cmd: %s\n' % tfprof_cmd)
return tfprof_node
-
-
-def advise(graph, run_meta=None, tfprof_options=ALL_ADVICE): # pylint: disable=dangerous-default-value
- """Auto profile and advise.
-
- Builds profiles and automatically check anormalies of various
- aspects. See go/tfprof or README for examples and tutorials.
-
- Args:
- graph: tf.Graph.
- run_meta: tensorflow::RunMetadata proto. Allows auto-profile
- time and memroy.
- tfprof_options: see ALL_ADVICE example above.
- Returns:
- Returns AdviceProto proto
- """
- # pylint: disable=protected-access
- op_log = tfprof_logger._merge_default_with_oplog(
- graph, None, run_meta, add_trace=True)
- # pylint: enable=protected-access
-
- run_meta_str = run_meta.SerializeToString() if run_meta else b''
-
- opts = _build_advisor_options(tfprof_options)
- ret = tfprof_output_pb2.AdviceProto()
- ret.ParseFromString(
- print_mdl.PrintModelAnalysis(
- graph.as_graph_def(add_shapes=True).SerializeToString(),
- run_meta_str,
- op_log.SerializeToString(),
- 'advise'.encode('utf-8'),
- opts.SerializeToString()))
- return ret
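
Reading aid for the `_build_options` hunk in `model_analyzer.py`: the function copies entries from a plain dictionary into an options message, falling back to a default when a key is absent. The following is a minimal sketch of that pattern only. `OptionsSketch` and `build_options_sketch` are hypothetical names, and the dataclass is an assumed stand-in for `tfprof_options_pb2.OptionsProto`, not the real proto; only a subset of the fields from the diff is shown.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class OptionsSketch:
    """Hypothetical stand-in for tfprof_options_pb2.OptionsProto (assumption)."""
    max_depth: int = 0
    min_bytes: int = 0
    step: int = 0
    order_by: str = ''
    select: List[str] = field(default_factory=list)
    output: str = ''


def build_options_sketch(tfprof_options: Dict[str, Any]) -> OptionsSketch:
    """Copy known keys from the dict, using the default values seen in the diff."""
    opts = OptionsSketch()
    opts.max_depth = tfprof_options.get('max_depth', 10)
    opts.min_bytes = tfprof_options.get('min_bytes', 0)
    opts.step = tfprof_options.get('step', -1)
    opts.order_by = tfprof_options.get('order_by', 'name')
    # List-valued options are appended one element at a time, as in the diff.
    for p in tfprof_options.get('select', []):
        opts.select.append(p)
    opts.output = tfprof_options.get('output', 'stdout')
    return opts


# Keys that are not supplied fall back to their defaults.
opts = build_options_sketch({'max_depth': 4, 'select': ['params', 'float_ops']})
print(opts)
```

Because every lookup goes through `dict.get` with a default, a caller can pass a partial dictionary (or a modified copy of a preset such as `PRINT_ALL_TIMING_MEMORY`) without supplying every key.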
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
index fea27a82a5..9db752c577 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
@@ -126,7 +126,7 @@ class PrintModelAnalysisTest(test.TestCase):
opts['account_displayed_op_only'] = False
opts['select'] = ['params', 'float_ops']
- with session.Session() as sess:
+ with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -176,7 +176,6 @@ class PrintModelAnalysisTest(test.TestCase):
opts['select'] = [
'bytes', 'params', 'float_ops', 'device'
]
- opts['output'] = 'none'
with session.Session() as sess:
x = lib.BuildSmallModel()
@@ -277,33 +276,6 @@ class PrintModelAnalysisTest(test.TestCase):
self.assertEqual(total_children, 15)
self.assertGreater(input_shapes, 0)
- def testAdvisor(self):
- ops.reset_default_graph()
-
- with session.Session() as sess:
- x = lib.BuildFullModel()
-
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(
- x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- advice_pb = model_analyzer.advise(sess.graph, run_meta)
- self.assertTrue('AcceleratorUtilizationChecker' in advice_pb.checkers)
- self.assertTrue('ExpensiveOperationChecker' in advice_pb.checkers)
- self.assertTrue('OperationChecker' in advice_pb.checkers)
-
- checker = advice_pb.checkers['AcceleratorUtilizationChecker']
- if test.is_gpu_available():
- self.assertGreater(len(checker.reports), 0)
- else:
- self.assertEqual(len(checker.reports), 0)
- checker = advice_pb.checkers['ExpensiveOperationChecker']
- self.assertGreater(len(checker.reports), 0)
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
index c7113b6a57..5daaafd7c8 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
@@ -129,7 +129,7 @@ class ProfilerTest(test.TestCase):
opts = model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
opts['account_type_regexes'] = ['.*']
- with session.Session() as sess:
+ with session.Session() as sess, ops.device('/cpu:0'):
r1, r2, r3 = lib.BuildSplitableModel()
sess.run(variables.global_variables_initializer())
@@ -179,18 +179,8 @@ class ProfilerTest(test.TestCase):
self.assertEqual(lib.SearchTFProfNode(pb2, 'add'), None)
self.assertGreater(lib.SearchTFProfNode(pb3, 'add').exec_micros, 0)
- advice_pb = profiler.advise(model_analyzer.ALL_ADVICE)
- self.assertTrue('AcceleratorUtilizationChecker' in advice_pb.checkers)
- self.assertTrue('ExpensiveOperationChecker' in advice_pb.checkers)
- self.assertTrue('OperationChecker' in advice_pb.checkers)
-
- checker = advice_pb.checkers['AcceleratorUtilizationChecker']
- if test.is_gpu_available():
- self.assertGreater(len(checker.reports), 0)
- else:
- self.assertEqual(len(checker.reports), 0)
- checker = advice_pb.checkers['ExpensiveOperationChecker']
- self.assertGreater(len(checker.reports), 0)
+ # TODO(xpan): Better test of advisor.
+ profiler.advise()
if __name__ == '__main__':