diff options
Diffstat (limited to 'tensorflow/contrib/optimizer_v2/optimizer_v2_test.py')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/optimizer_v2_test.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py index a44bfd1bfd..dd7f2f4405 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py @@ -61,7 +61,7 @@ class OptimizerTest(test.TestCase): def testAggregationMethod(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) cost = 5 * var0 + 3 * var1 @@ -86,7 +86,7 @@ class OptimizerTest(test.TestCase): def testPrecomputedGradient(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) cost = 5 * var0 + 3 * var1 @@ -212,7 +212,7 @@ class OptimizerTest(test.TestCase): sgd_op.apply_gradients(grads_and_vars) def testTrainOp(self): - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0]) var1 = variables.Variable([3.0, 4.0]) cost = 5 * var0 + 3 * var1 @@ -225,7 +225,7 @@ class OptimizerTest(test.TestCase): def testConstraint(self): constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.) constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.) - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], constraint=constraint_01) var1 = variables.Variable([3.0, 4.0], @@ -247,7 +247,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([0., 0.], var1.eval()) def testStopGradients(self): - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], name='var0') var1 = variables.Variable([3.0, 4.0], name='var1') var0_id = array_ops.identity(var0) |