aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/optimizer_v2/optimizer_v2_test.py')
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2_test.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
index a44bfd1bfd..dd7f2f4405 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
@@ -61,7 +61,7 @@ class OptimizerTest(test.TestCase):
def testAggregationMethod(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
cost = 5 * var0 + 3 * var1
@@ -86,7 +86,7 @@ class OptimizerTest(test.TestCase):
def testPrecomputedGradient(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
cost = 5 * var0 + 3 * var1
@@ -212,7 +212,7 @@ class OptimizerTest(test.TestCase):
sgd_op.apply_gradients(grads_and_vars)
def testTrainOp(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([3.0, 4.0])
cost = 5 * var0 + 3 * var1
@@ -225,7 +225,7 @@ class OptimizerTest(test.TestCase):
def testConstraint(self):
constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0],
constraint=constraint_01)
var1 = variables.Variable([3.0, 4.0],
@@ -247,7 +247,7 @@ class OptimizerTest(test.TestCase):
self.assertAllClose([0., 0.], var1.eval())
def testStopGradients(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], name='var0')
var1 = variables.Variable([3.0, 4.0], name='var1')
var0_id = array_ops.identity(var0)