aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/momentum_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/momentum_test.py')
-rw-r--r--tensorflow/python/training/momentum_test.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index f7e78071d8..8a21c39d32 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -123,7 +123,7 @@ class MomentumOptimizerTest(test.TestCase):
]), self.evaluate(var1))
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -162,7 +162,7 @@ class MomentumOptimizerTest(test.TestCase):
def testNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -188,7 +188,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSparseNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
@@ -282,7 +282,7 @@ class MomentumOptimizerTest(test.TestCase):
def testTensorLearningRateAndMomentum(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -435,7 +435,7 @@ class MomentumOptimizerTest(test.TestCase):
return db_grad, db_out
def testLikeDistBeliefMom01(self):
- with self.test_session():
+ with self.cached_session():
db_grad, db_out = self._dbParamsMom01()
num_samples = len(db_grad)
var0 = variables.Variable([0.0] * num_samples)
@@ -449,7 +449,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSparse(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
grads0 = ops.IndexedSlices(
@@ -518,7 +518,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)