aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py7
-rw-r--r--tensorflow/python/debug/cli/stepper_cli_test.py8
-rw-r--r--tensorflow/python/debug/lib/stepper_test.py20
3 files changed, 30 insertions, 5 deletions
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
index 9f95074022..32a6d5fdb2 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import os
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
@@ -61,7 +62,11 @@ class PrintModelAnalysisTest(test.TestCase):
'input_shapes'
]
- with session.Session() as sess, ops.device('/cpu:0'):
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+ with session.Session(config=config) as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py
index 06e1228b95..ee8cabca0d 100644
--- a/tensorflow/python/debug/cli/stepper_cli_test.py
+++ b/tensorflow/python/debug/cli/stepper_cli_test.py
@@ -22,6 +22,8 @@ import re
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import stepper_cli
from tensorflow.python.debug.lib import stepper
@@ -143,7 +145,11 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
self.opt = gradient_descent.GradientDescentOptimizer(0.1).minimize(
self.e, name="opt")
- self.sess = session.Session()
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+ self.sess = session.Session(config=config)
self.sess.run(self.a.initializer)
self.sess.run(self.b.initializer)
diff --git a/tensorflow/python/debug/lib/stepper_test.py b/tensorflow/python/debug/lib/stepper_test.py
index 78e7b3b5eb..686fb45238 100644
--- a/tensorflow/python/debug/lib/stepper_test.py
+++ b/tensorflow/python/debug/lib/stepper_test.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib.stepper import NodeStepper
from tensorflow.python.framework import constant_op
@@ -52,7 +54,11 @@ class StepperTest(test_util.TensorFlowTestCase):
self.z = math_ops.multiply(self.x, self.y, name="z") # Should be -4.0.
- self.sess = session.Session()
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+ self.sess = session.Session(config=config)
self.sess.run(variables.global_variables_initializer())
def tearDown(self):
@@ -581,7 +587,11 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase):
1.0,
name="v_add_plus_one")
- self.sess = session.Session()
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+ self.sess = session.Session(config=config)
self.sess.run(self.v.initializer)
def tearDown(self):
@@ -708,7 +718,11 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
gradient_descent.GradientDescentOptimizer(0.01).minimize(
self.f, name="optim")
- self.sess = session.Session()
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+ self.sess = session.Session(config=config)
self.sess.run(variables.global_variables_initializer())
def tearDown(self):