aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tfprof
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tfprof')
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
index 9f95074022..32a6d5fdb2 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import os
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
@@ -61,7 +62,11 @@ class PrintModelAnalysisTest(test.TestCase):
'input_shapes'
]
- with session.Session() as sess, ops.device('/cpu:0'):
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+ with session.Session(config=config) as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())