Replace deprecated `self.test_session()` with `self.cached_session()` across
TensorFlow test files. (`test_session()` is deprecated; `cached_session()`
returns a session reused across test methods with the same semantics here.)
-rw-r--r--tensorflow/compiler/tests/adam_test.py6
-rw-r--r--tensorflow/compiler/tests/reshape_op_test.py2
-rw-r--r--tensorflow/compiler/tests/xla_ops_test.py2
-rw-r--r--tensorflow/contrib/autograph/utils/misc_test.py4
-rw-r--r--tensorflow/contrib/autograph/utils/py_func_test.py8
-rw-r--r--tensorflow/contrib/autograph/utils/tensor_list_test.py8
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py26
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py18
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py6
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py4
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py2
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py4
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py56
-rw-r--r--tensorflow/python/eager/function_test.py28
-rw-r--r--tensorflow/python/eager/graph_only_ops_test.py4
-rw-r--r--tensorflow/python/eager/tape_test.py4
-rw-r--r--tensorflow/python/keras/layers/gru_test.py8
-rw-r--r--tensorflow/python/keras/layers/lstm_test.py22
-rw-r--r--tensorflow/python/keras/layers/simplernn_test.py8
20 files changed, 112 insertions, 112 deletions
diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py
index df0f21471a..058576b3d4 100644
--- a/tensorflow/compiler/tests/adam_test.py
+++ b/tensorflow/compiler/tests/adam_test.py
@@ -56,7 +56,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
# TODO: test fails for float16 due to excessive precision requirements.
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
@@ -98,7 +98,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
# TODO: test fails for float16 due to excessive precision requirements.
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
@@ -140,7 +140,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
# TODO: test fails for float16 due to excessive precision requirements.
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py
index 84c6777940..96e0b07475 100644
--- a/tensorflow/compiler/tests/reshape_op_test.py
+++ b/tensorflow/compiler/tests/reshape_op_test.py
@@ -33,7 +33,7 @@ class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase):
('64_bit_index', dtypes.int64))
def testBasic(self, index_dtype):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[2, 3])
with self.test_scope():
shape = constant_op.constant([3, 2], dtype=index_dtype)
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 3f928a1bea..0f3843dc1e 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -34,7 +34,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected,
equality_fn=None):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
diff --git a/tensorflow/contrib/autograph/utils/misc_test.py b/tensorflow/contrib/autograph/utils/misc_test.py
index 71e358c33e..968ea03df6 100644
--- a/tensorflow/contrib/autograph/utils/misc_test.py
+++ b/tensorflow/contrib/autograph/utils/misc_test.py
@@ -31,7 +31,7 @@ class MiscTest(test.TestCase):
new_a = alias_tensors(a)
self.assertFalse(new_a is a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(1, sess.run(new_a))
def test_alias_tensors(self):
@@ -46,7 +46,7 @@ class MiscTest(test.TestCase):
self.assertTrue(new_v is v)
self.assertTrue(new_s is s)
self.assertTrue(new_l is l)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(1, sess.run(new_a))
diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/contrib/autograph/utils/py_func_test.py
index 2468263142..f60b57bcce 100644
--- a/tensorflow/contrib/autograph/utils/py_func_test.py
+++ b/tensorflow/contrib/autograph/utils/py_func_test.py
@@ -31,7 +31,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b, c):
return a + b + c
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64,
(1, constant_op.constant(1), 1))
self.assertEqual(3, sess.run(result))
@@ -52,7 +52,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b):
return a * b.foo
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
self.assertEqual(35, sess.run(result))
result = py_func.wrap_py_func(test_fn, dtypes.int64,
@@ -69,7 +69,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b, c, d):
return a * b.foo + c * d.foo
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), {
'c': 11,
'd': TestClass(13)
@@ -89,7 +89,7 @@ class PyFuncTest(test.TestCase):
def test_fn(_):
side_counter[0] += 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
self.assertEqual(1, sess.run(result))
self.assertEqual([1], side_counter)
diff --git a/tensorflow/contrib/autograph/utils/tensor_list_test.py b/tensorflow/contrib/autograph/utils/tensor_list_test.py
index d58489eb68..faaf7b7877 100644
--- a/tensorflow/contrib/autograph/utils/tensor_list_test.py
+++ b/tensorflow/contrib/autograph/utils/tensor_list_test.py
@@ -42,18 +42,18 @@ class TensorListTest(test.TestCase):
l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32)
l = tl.dynamic_list_append(l, 1)
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(s), [1])
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
l = tl.dynamic_list_append(l, 1)
s = l.stack()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(s), [1])
l = tl.TensorList(self._shape(()), dtypes.int32)
l = tl.dynamic_list_append(l, 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(l[0]), 1)
def test_list_append_python(self):
@@ -107,7 +107,7 @@ class TensorListTest(test.TestCase):
l0 = l[0]
l[0] = b
l1 = l[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
l0, l1, a, b = sess.run([l0, l1, a, b])
self.assertEqual(l0, a)
self.assertEqual(l1, b)
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
index 5e07b9313f..284a4f45f6 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
@@ -147,7 +147,7 @@ class DataFeederTest(test.TestCase):
def test_unsupervised(self):
def func(feeder):
- with self.test_session():
+ with self.cached_session():
inp, _ = feeder.input_builder()
feed_dict_fn = feeder.get_feed_dict_fn()
feed_dict = feed_dict_fn()
@@ -181,7 +181,7 @@ class DataFeederTest(test.TestCase):
def test_epoch(self):
def func(feeder):
- with self.test_session():
+ with self.cached_session():
feeder.input_builder()
epoch = feeder.make_epoch_variable()
feed_dict_fn = feeder.get_feed_dict_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
index 7e81f2b7d9..5e90d1fa20 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
@@ -38,7 +38,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key='label',
@@ -68,7 +68,7 @@ class GeneratorIoTest(test.TestCase):
for index in range(2):
yield {'a': np.ones(1) * index}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
@@ -97,7 +97,7 @@ class GeneratorIoTest(test.TestCase):
'label2': np.ones(1) * index - 64,
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key=['label', 'label2'],
@@ -134,7 +134,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones((3, 3)) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key='label',
@@ -162,7 +162,7 @@ class GeneratorIoTest(test.TestCase):
def testGeneratorInputFnWithXAsNonGeneratorFunction(self):
x = np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x must be generator function'):
failing_input_fn = generator_io.generator_input_fn(
x, batch_size=2, shuffle=False, num_epochs=1)
@@ -173,7 +173,7 @@ class GeneratorIoTest(test.TestCase):
def generator():
return np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -184,7 +184,7 @@ class GeneratorIoTest(test.TestCase):
def generator():
yield np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -201,7 +201,7 @@ class GeneratorIoTest(test.TestCase):
}
y = np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
@@ -219,7 +219,7 @@ class GeneratorIoTest(test.TestCase):
}
y = ['label', np.arange(10)]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
@@ -237,7 +237,7 @@ class GeneratorIoTest(test.TestCase):
}
y = ['label', 'target']
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(KeyError, 'target_key not in yielded dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
@@ -253,7 +253,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
@@ -283,7 +283,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=4, shuffle=False, num_epochs=1)
features = input_fn()
@@ -319,7 +319,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
index c738f0e8f3..396539a76a 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
@@ -65,7 +65,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesExpectedOutputs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -79,7 +79,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 102)
a = np.arange(2)
b = np.arange(32, 34)
@@ -107,7 +107,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 105)
a = np.arange(5)
b = np.arange(32, 37)
@@ -146,7 +146,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_OnlyX(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, _ = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y=None, batch_size=2, shuffle=False, num_epochs=1)
@@ -159,7 +159,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ExcludesIndex(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -182,7 +182,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_NoShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=False, num_epochs=1)
@@ -192,7 +192,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=True, num_epochs=1)
@@ -202,7 +202,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
@@ -213,7 +213,7 @@ class PandasIoTest(test.TestCase):
if not HAS_PANDAS:
return
x, y = self.makeTestDataFrame()
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=3, shuffle=False, num_epochs=1)
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
index a2d82cf800..553b116a3b 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
@@ -30,7 +30,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
def testShardedMutableHashTable(self):
for num_shards in [1, 3, 10]:
- with self.test_session():
+ with self.cached_session():
default_val = -1
empty_key = 0
keys = constant_op.constant([11, 12, 13], dtypes.int64)
@@ -53,7 +53,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
def testShardedMutableHashTableVectors(self):
for num_shards in [1, 3, 10]:
- with self.test_session():
+ with self.cached_session():
default_val = [-0.1, 0.2]
empty_key = [0, 1]
keys = constant_op.constant([[11, 12], [13, 14], [15, 16]],
@@ -79,7 +79,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
output.eval())
def testExportSharded(self):
- with self.test_session():
+ with self.cached_session():
empty_key = -2
default_val = -1
num_shards = 2
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
index 237a6812b7..51c4f68543 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
@@ -36,13 +36,13 @@ class SparseFeatureColumnTest(TensorFlowTestCase):
self.assertTrue(isinstance(sfc.example_indices, ops.Tensor))
self.assertTrue(isinstance(sfc.feature_indices, ops.Tensor))
self.assertEqual(sfc.feature_values, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_example_indices, sfc.example_indices.eval())
self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval())
expected_feature_values = [1.0, 2.0, 3.0, 4.0]
sfc = SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0],
expected_feature_values)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_feature_values, sfc.feature_values.eval())
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index aa4562be7c..bf699db3ed 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -1906,7 +1906,7 @@ class StateSaverRNNTest(test.TestCase):
state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
sess.run(variables_lib.local_variables_initializer())
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
index f2a032e41e..8d34b9e852 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
@@ -38,7 +38,7 @@ class FusedRnnCellTest(test.TestCase):
def testBasicRNNFusedWrapper(self):
"""This test checks that using a wrapper for BasicRNN works as expected."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890212)
cell = rnn_cell.BasicRNNCell(10)
@@ -106,7 +106,7 @@ class FusedRnnCellTest(test.TestCase):
self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)
def testTimeReversedFusedRNN(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890213)
fw_cell = rnn_cell.BasicRNNCell(10)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 2df8f0ec05..6689664fb9 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -47,7 +47,7 @@ from tensorflow.python.util import nest
class RNNCellTest(test.TestCase):
def testCoupledInputForgetGateLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
state_size = num_units * 2
batch_size = 3
@@ -81,7 +81,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1], expected_state)
def testTimeFreqLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
state_size = num_units * 2
batch_size = 3
@@ -120,7 +120,7 @@ class RNNCellTest(test.TestCase):
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
def testGridLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
batch_size = 3
input_size = 4
@@ -166,7 +166,7 @@ class RNNCellTest(test.TestCase):
.state_f00_b00_c[i, :]))) > 1e-6)
def testGridLSTMCellWithFrequencyBlocks(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
batch_size = 3
feature_size = 2
@@ -248,7 +248,7 @@ class RNNCellTest(test.TestCase):
]],
dtype=np.float32)
for state_is_tuple in [False, True]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple" + str(state_is_tuple),
initializer=init_ops.constant_initializer(0.5)):
@@ -294,7 +294,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testBidirectionGridLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -374,7 +374,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testBidirectionGridLSTMCellWithSliceOffset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -487,7 +487,7 @@ class RNNCellTest(test.TestCase):
input_size = 4
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
@@ -538,7 +538,7 @@ class RNNCellTest(test.TestCase):
batch_size = 3
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
@@ -677,7 +677,7 @@ class RNNCellTest(test.TestCase):
0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"nas_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units)
@@ -725,7 +725,7 @@ class RNNCellTest(test.TestCase):
0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
@@ -765,7 +765,7 @@ class RNNCellTest(test.TestCase):
[[0.13752282, 0.13752282], [0.10545051, 0.10545051],
[0.10074195, 0.10074195]],
dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
@@ -796,7 +796,7 @@ class RNNCellTest(test.TestCase):
[[2.00431061, 2.00431061], [4.00060606, 4.00060606],
[6.00008249, 6.00008249]],
dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"intersection_rnn_cell_test",
initializer=init_ops.constant_initializer(0.5)):
@@ -837,7 +837,7 @@ class RNNCellTest(test.TestCase):
cell(inputs, init_state)
def testPhasedLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -874,7 +874,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv1DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 1]
filter_size = [3]
num_features = 1
@@ -907,7 +907,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv2DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 2, 1]
filter_size = [3, 3]
num_features = 1
@@ -948,7 +948,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv3DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 2, 2, 1]
filter_size = [3, 3, 3]
num_features = 1
@@ -999,7 +999,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testHighwayWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"base_cell", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -1030,7 +1030,7 @@ class RNNCellTest(test.TestCase):
# Try with input dimension equal to num_units or not.
for num_inputs in [num_units, num_units + number_of_groups]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root1_%d" % num_inputs,
initializer=init_ops.constant_initializer(0.5)):
@@ -1059,7 +1059,7 @@ class RNNCellTest(test.TestCase):
# Try with num_inputs equal to or not equal to num_units.
for num_inputs in [num_units, num_units + number_of_groups]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root2_%d" % num_inputs,
initializer=init_ops.constant_initializer(0.5)):
@@ -1092,7 +1092,7 @@ class RNNCellTest(test.TestCase):
batch_size = 2
num_units = 4
number_of_groups = 2
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"glstm_failure", initializer=init_ops.constant_initializer(0.5)):
gcell = contrib_rnn_cell.GLSTMCell(
@@ -1121,7 +1121,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
# NOTE: all the values in the current test case have been calculated.
def testBasicLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1189,7 +1189,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
def testBasicLSTMCellWithoutNorm(self):
"""Tests that BasicLSTMCell with layer_norm=False."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1256,7 +1256,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_h, 1e-5)
def testBasicLSTMCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1294,7 +1294,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
def testBasicLSTMCellWithStateTupleLayerNorm(self):
"""The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1353,7 +1353,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
num_units = 5
allowed_low = [1, 2, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"other", initializer=init_ops.constant_initializer(1)):
x = array_ops.zeros([1, 5])
@@ -1479,7 +1479,7 @@ class CompiledWrapperTest(test.TestCase):
self.assertAllClose(xla_g, non_xla_g, atol=atol)
def testMultiRNNCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1583,7 +1583,7 @@ class WeightNormLSTMCellTest(test.TestCase):
def _cell_output(self, cell):
"""Calculates cell output."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init = init_ops.constant_initializer(0.5)
with variable_scope.variable_scope("root",
initializer=init):
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 37a9957cea..92254a2c00 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -104,7 +104,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(step(), 2.0)
def testGraphGradientVariable(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
@@ -211,7 +211,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(f(), x)
def testSymGradGatherNd(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
@function.defun
def f(x):
@@ -481,7 +481,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
def testGraphModeCaptureVariable(self):
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
class HasAVar(object):
@@ -509,12 +509,12 @@ class FunctionTest(test.TestCase):
x = constant_op.constant(1.0)
l = f(x, v)
_, dv = gradients_impl.gradients(l, [x, v])
- with self.test_session():
+ with self.cached_session():
v.initializer.run()
self.assertAllEqual(dv.eval(), 0.0)
def testGraphModeManyFunctions(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
@function.defun
def f(x):
@@ -934,7 +934,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(1, int(read()))
def testReturnCapturedGraphTensor(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
t = constant_op.constant(1)
@function.defun
@@ -1497,7 +1497,7 @@ class FunctionTest(test.TestCase):
class AutomaticControlDependenciesTest(test.TestCase):
def testBasic(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
with function.AutomaticControlDependencies() as c:
@@ -1508,7 +1508,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(), 4.0)
def testCondMustRun(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1529,7 +1529,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0)
def testCondMustRunSeparateRead(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1552,7 +1552,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(v.read_value().eval(), 6.0)
def testCondNested(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1586,7 +1586,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
def testCondOneBranch(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1606,7 +1606,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
def testCondOneBranchUpdateBefore(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1627,7 +1627,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0)
def testCondOneBranchUpdateAfter(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1663,7 +1663,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(out, [3, 4, 5])
def testDecorator(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/eager/graph_only_ops_test.py b/tensorflow/python/eager/graph_only_ops_test.py
index d2a2b4e223..3cf3a61a62 100644
--- a/tensorflow/python/eager/graph_only_ops_test.py
+++ b/tensorflow/python/eager/graph_only_ops_test.py
@@ -32,13 +32,13 @@ class GraphOnlyOpsTest(test_util.TensorFlowTestCase):
def testGraphZerosLike(self):
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
z_tf = graph_only_ops.graph_zeros_like(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(np.zeros((2, 3)), z_tf.eval())
def testGraphPlaceholder(self):
x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,))
y_tf = math_ops.square(x_tf)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = np.array([42])
y = sess.run(y_tf, feed_dict={x_tf: np.array([42])})
self.assertAllClose(np.square(x), y)
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index 4326d5efa3..acd0e569f1 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -72,7 +72,7 @@ class TapeTest(test.TestCase):
a = constant_op.constant([[1., 0.], [0., 1.]])
b = constant_op.constant([[1., 2.], [3., 4.]])
da, db = backprop.gradients_function(fn, [0, 1])(a, b)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
tf_c = tf_a + tf_b
@@ -135,7 +135,7 @@ class TapeTest(test.TestCase):
a = constant_op.constant([[1., 0.], [0., 1.]])
b = constant_op.constant([[1., 2.], [3., 4.]])
da, db = backprop.gradients_function(fn, [0, 1])(a, b)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
tf_mm = math_ops.matmul(tf_a, tf_b)
diff --git a/tensorflow/python/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py
index afef997b00..9988c9fae5 100644
--- a/tensorflow/python/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/layers/gru_test.py
@@ -87,7 +87,7 @@ class GRULayerTest(test.TestCase):
embedding_dim = 4
units = 2
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Embedding(
@@ -146,7 +146,7 @@ class GRULayerTest(test.TestCase):
def test_regularizers_GRU(self):
embedding_dim = 4
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
layer = layer_class(
5,
return_sequences=False,
@@ -166,7 +166,7 @@ class GRULayerTest(test.TestCase):
def test_constraints_GRU(self):
embedding_dim = 4
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
r_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
@@ -186,7 +186,7 @@ class GRULayerTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_GRU(self):
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
inputs = np.random.random((2, 3, 4))
targets = np.abs(np.random.random((2, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)
diff --git a/tensorflow/python/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py
index 9802820fd0..f536915324 100644
--- a/tensorflow/python/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/layers/lstm_test.py
@@ -102,7 +102,7 @@ class LSTMLayerTest(test.TestCase):
embedding_dim = 4
units = 2
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Embedding(
@@ -161,7 +161,7 @@ class LSTMLayerTest(test.TestCase):
def test_regularizers_LSTM(self):
embedding_dim = 4
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
layer = layer_class(
5,
return_sequences=False,
@@ -180,7 +180,7 @@ class LSTMLayerTest(test.TestCase):
def test_constraints_LSTM(self):
embedding_dim = 4
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
r_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
@@ -200,7 +200,7 @@ class LSTMLayerTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_LSTM(self):
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
inputs = np.random.random((2, 3, 4))
targets = np.abs(np.random.random((2, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)
@@ -225,7 +225,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
# Test with Keras tensor
inputs = keras.Input((timesteps, embedding_dim))
initial_state = [keras.Input((units,)) for _ in range(num_states)]
@@ -252,7 +252,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
# Test with non-Keras tensor
inputs = keras.Input((timesteps, embedding_dim))
initial_state = [keras.backend.random_normal_variable(
@@ -275,7 +275,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LSTM(units, stateful=True)
layer.build((num_samples, timesteps, embedding_dim))
layer.reset_states()
@@ -306,7 +306,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input((timesteps, embedding_dim))
_ = keras.layers.Masking()(inputs)
initial_state = [keras.Input((units,)) for _ in range(num_states)]
@@ -329,7 +329,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
layer = keras.layers.LSTM(units, return_state=True, stateful=True)
outputs = layer(inputs)
@@ -347,7 +347,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
layer = keras.layers.LSTM(units, return_state=True, return_sequences=True)
outputs = layer(inputs)
@@ -366,7 +366,7 @@ class LSTMLayerTest(test.TestCase):
num_states = 2
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
# Test with Keras tensor
main_inputs = keras.Input((timesteps, embedding_dim))
initial_state = [keras.Input((units,)) for _ in range(num_states)]
diff --git a/tensorflow/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py
index 1429537648..2f2295a793 100644
--- a/tensorflow/python/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/layers/simplernn_test.py
@@ -87,7 +87,7 @@ class SimpleRNNLayerTest(test.TestCase):
embedding_dim = 4
units = 2
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Embedding(
@@ -146,7 +146,7 @@ class SimpleRNNLayerTest(test.TestCase):
def test_regularizers_SimpleRNN(self):
embedding_dim = 4
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
layer = layer_class(
5,
return_sequences=False,
@@ -166,7 +166,7 @@ class SimpleRNNLayerTest(test.TestCase):
def test_constraints_SimpleRNN(self):
embedding_dim = 4
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
r_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
@@ -186,7 +186,7 @@ class SimpleRNNLayerTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_SimpleRNN(self):
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
inputs = np.random.random((2, 3, 4))
targets = np.abs(np.random.random((2, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)