aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/profiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-23 17:21:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-23 17:25:26 -0700
commit1f8db608007ae60f89bf38c4c6af98a0248f214e (patch)
tree69f31cad90804663dad9ae9e44e8a13d02dd22a8 /tensorflow/python/profiler
parent2862f65fd6e6966ebf8af7cb4fa754b319202b0f (diff)
Add blacklist ops to PinToHostOptimizer. Fix test.
PiperOrigin-RevId: 214195020
Diffstat (limited to 'tensorflow/python/profiler')
-rw-r--r--tensorflow/python/profiler/model_analyzer_test.py42
1 files changed, 23 insertions, 19 deletions
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index c0e16ca536..94c685274a 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -52,13 +52,19 @@ builder = option_builder.ProfileOptionBuilder
class PrintModelAnalysisTest(test.TestCase):
+ def _no_rewrite_session_config(self):
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ return config_pb2.ConfigProto(graph_options=graph_options)
+
def testDumpToFile(self):
ops.reset_default_graph()
outfile = os.path.join(test.get_temp_dir(), 'dump')
opts = builder(builder.trainable_variables_parameter()
).with_file_output(outfile).build()
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
_ = lib.BuildSmallModel()
model_analyzer.profile(sess.graph, options=opts)
@@ -83,7 +89,8 @@ class PrintModelAnalysisTest(test.TestCase):
with profile_context.ProfileContext(test.get_temp_dir(),
trace_steps=[],
dump_steps=[]) as pctx:
- with session.Session() as sess, ops.device(dev):
+ with session.Session(
+ config=self._no_rewrite_session_config()) as sess, ops.device(dev):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -149,11 +156,8 @@ class PrintModelAnalysisTest(test.TestCase):
.select(['params', 'float_ops', 'occurrence', 'device', 'op_types',
'input_shapes']).build())
- rewriter_config = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True)
- graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
- config = config_pb2.ConfigProto(graph_options=graph_options)
- with session.Session(config=config) as sess, ops.device('/device:CPU:0'):
+ with session.Session(config=self._no_rewrite_session_config()
+ ) as sess, ops.device('/device:CPU:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -179,7 +183,7 @@ class PrintModelAnalysisTest(test.TestCase):
.select(['bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
'input_shapes']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -213,7 +217,7 @@ class PrintModelAnalysisTest(test.TestCase):
with profile_context.ProfileContext(test.get_temp_dir(),
trace_steps=[],
dump_steps=[]) as pctx:
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -274,7 +278,7 @@ class PrintModelAnalysisTest(test.TestCase):
.account_displayed_op_only(False)
.select(['bytes', 'params', 'float_ops', 'device']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -302,7 +306,7 @@ class PrintModelAnalysisTest(test.TestCase):
.with_timeline_output(outfile)
.with_accounted_types(['.*']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -338,7 +342,7 @@ class PrintModelAnalysisTest(test.TestCase):
'peak_bytes', 'residual_bytes',
'output_bytes', 'occurrence', 'input_shapes']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -384,7 +388,7 @@ class PrintModelAnalysisTest(test.TestCase):
def testAdvisor(self):
ops.reset_default_graph()
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -417,7 +421,7 @@ class PrintModelAnalysisTest(test.TestCase):
.with_node_names(trim_name_regexes=['ops.py.*'])
.with_pprof_output(outfile).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -484,7 +488,7 @@ class PrintModelAnalysisTest(test.TestCase):
self.assertGreaterEqual(n.output_bytes, mob)
check_min(n.children, mm, mam, mcm, mb, mpb, mrb, mob)
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
@@ -549,7 +553,7 @@ class PrintModelAnalysisTest(test.TestCase):
for attr in not_selected:
self.assertFalse(s.find(attr) > 0, s)
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
@@ -582,7 +586,7 @@ class PrintModelAnalysisTest(test.TestCase):
def _trainLoop(self, train_op, train_steps, time_dir, time_step,
memory_dir, memory_step, profile_dir, dump_step):
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
sess.run(variables.global_variables_initializer())
# start from 1 because variable_initializer took one step.
for i in range(1, train_steps + 1):
@@ -655,7 +659,7 @@ class PrintModelAnalysisTest(test.TestCase):
c = a * b
try:
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
sess.run(c, options=config_pb2.RunOptions(
report_tensor_allocations_upon_oom=True))
except Exception as e: # pylint: disable=broad-except
@@ -758,7 +762,7 @@ class PrintModelAnalysisTest(test.TestCase):
grad = gradients.gradients(y, [x1])
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()