aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 18:22:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 18:25:59 -0700
commit708b30f4cb82271bb28cb70a1e0c89a1933f5b64 (patch)
tree22470a9314f7f4225b6d08170a3d7ea91b0216a1 /tensorflow/contrib/rnn
parentd0cac47a767dd972516f75ce57f0d6185e3b6514 (diff)
Move from deprecated self.test_session() to self.session() when a graph is set.
self.test_session() has been deprecated in cl/208545396 as its behavior confuses readers of the test. Moving to self.session() instead. PiperOrigin-RevId: 209696110
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py4
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py26
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py4
3 files changed, 17 insertions, 17 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 85f0f8ced9..15ce9d1ce7 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -225,7 +225,7 @@ class RNNCellTest(test.TestCase):
def testBasicLSTMCell(self):
for dtype in [dtypes.float16, dtypes.float32]:
np_dtype = dtype.as_numpy_dtype
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2], dtype=dtype)
@@ -395,7 +395,7 @@ class RNNCellTest(test.TestCase):
def testIndyLSTMCell(self):
for dtype in [dtypes.float16, dtypes.float32]:
np_dtype = dtype.as_numpy_dtype
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2], dtype=dtype)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index d62ec45d18..aa4562be7c 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -457,7 +457,7 @@ class LSTMTest(test.TestCase):
input_size = 5
batch_size = 2
max_length = 8
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(batch_size, num_units)
@@ -491,7 +491,7 @@ class LSTMTest(test.TestCase):
input_size = 5
batch_size = 2
max_length = 8
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(
@@ -588,7 +588,7 @@ class LSTMTest(test.TestCase):
num_proj = 4
max_length = 8
sequence_length = [4, 6]
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
inputs = max_length * [
@@ -834,7 +834,7 @@ class LSTMTest(test.TestCase):
batch_size = 2
num_proj = 4
max_length = 8
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
initializer_d = init_ops.random_uniform_initializer(
-1, 1, seed=self._seed + 1)
@@ -884,7 +884,7 @@ class LSTMTest(test.TestCase):
batch_size = 2
num_proj = 4
max_length = 8
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
inputs = max_length * [
array_ops.placeholder(dtypes.float32, shape=(None, input_size))
@@ -930,7 +930,7 @@ class LSTMTest(test.TestCase):
max_length = 8
sequence_length = [4, 6]
in_graph_mode = not context.executing_eagerly()
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
if in_graph_mode:
@@ -1006,7 +1006,7 @@ class LSTMTest(test.TestCase):
max_length = 8
sequence_length = [4, 6]
in_graph_mode = not context.executing_eagerly()
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
if in_graph_mode:
@@ -1612,7 +1612,7 @@ class MultiDimensionalLSTMTest(test.TestCase):
batch_size = 2
max_length = 8
sequence_length = [4, 6]
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
inputs = max_length * [
array_ops.placeholder(dtypes.float32, shape=(None,) + input_size)
]
@@ -1723,7 +1723,7 @@ class NestedLSTMTest(test.TestCase):
state_size = 6
max_length = 8
sequence_length = [4, 6]
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
state_saver = TestStateSaver(batch_size, state_size)
single_input = (array_ops.placeholder(
dtypes.float32, shape=(None, input_size)),
@@ -2017,7 +2017,7 @@ class RawRNNTest(test.TestCase):
np.random.seed(self._seed)
def _testRawRNN(self, max_time):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
batch_size = 16
input_depth = 4
num_units = 3
@@ -2126,7 +2126,7 @@ class RawRNNTest(test.TestCase):
self._testRawRNN(max_time=10)
def testLoopState(self):
- with self.test_session(graph=ops_lib.Graph()):
+ with self.session(graph=ops_lib.Graph()):
max_time = 10
batch_size = 16
input_depth = 4
@@ -2162,7 +2162,7 @@ class RawRNNTest(test.TestCase):
self.assertEqual([10], loop_state.eval())
def testLoopStateWithTensorArray(self):
- with self.test_session(graph=ops_lib.Graph()):
+ with self.session(graph=ops_lib.Graph()):
max_time = 4
batch_size = 16
input_depth = 4
@@ -2205,7 +2205,7 @@ class RawRNNTest(test.TestCase):
self.assertAllEqual([1, 2, 2 + 2, 4 + 3, 7 + 4], loop_state.eval())
def testEmitDifferentStructureThanCellOutput(self):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
max_time = 10
batch_size = 16
input_depth = 4
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index c7d85862f6..2df8f0ec05 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -1440,7 +1440,7 @@ class CompiledWrapperTest(test.TestCase):
atol = 1e-5
random_seed.set_random_seed(1234)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
xla_ops = _create_multi_lstm_cell_ops(
batch_size=batch_size,
num_units=num_units,
@@ -1452,7 +1452,7 @@ class CompiledWrapperTest(test.TestCase):
xla_results = sess.run(xla_ops)
random_seed.set_random_seed(1234)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
non_xla_ops = _create_multi_lstm_cell_ops(
batch_size=batch_size,
num_units=num_units,