Diffstat (limited to 'tensorflow/contrib')

 tensorflow/contrib/batching/python/ops/batch_ops_test.py                         | 29
 tensorflow/contrib/compiler/jit_test.py                                          |  2
 tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py  |  4
 tensorflow/contrib/gan/python/train_test.py                                      |  4
 tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py                 | 28
 tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py          |  8
 tensorflow/contrib/kernel_methods/python/losses_test.py                          | 38
 tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py | 12
 tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py                |  2
 tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py | 10
 tensorflow/contrib/tensor_forest/python/tensor_forest_test.py                    |  2
 tensorflow/contrib/text/python/ops/skip_gram_ops_test.py                         | 32
 12 files changed, 86 insertions(+), 85 deletions(-)
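
The diff below is a mechanical migration: every `with self.test_session()` in these contrib tests becomes `with self.cached_session()`. Both context managers come from `tf.test.TestCase`, and despite its name, `test_session()` already returned a session cached for the lifetime of the test; `cached_session()` is the accurately named replacement in whose favor `test_session()` was deprecated. A minimal sketch of the migrated pattern, for illustration only (the test class and tensors here are invented, not part of this diff):

import tensorflow as tf

class ExampleTest(tf.test.TestCase):

  def testAddition(self):
    # cached_session() yields a Session that is created once and reused
    # across calls within this test method (TF 1.x graph mode).
    with self.cached_session() as sess:
      total = tf.constant(1) + tf.constant(2)
      self.assertEqual(3, sess.run(total))

if __name__ == "__main__":
  tf.test.main()
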
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index 7846814546..01ee8703a9 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -43,7 +43,7 @@ class BatchOpsTest(test.TestCase):
def testBasicBatch(self):
"""Tests that a single batched tensor executes together and only once."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -83,7 +83,7 @@ class BatchOpsTest(test.TestCase):
def testBatchWithPadding(self):
"""Test that batching with padding up to an allowed batch size works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -113,7 +113,7 @@ class BatchOpsTest(test.TestCase):
def testMultipleBatch(self):
"""Tests that multiple batched tensors execute together."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, _, _ = batch_ops.batch(
@@ -152,7 +152,7 @@ class BatchOpsTest(test.TestCase):
def testIllegalBatchDifferentDim0Sizes(self):
"""Tests illegally feeding tensors with different dim0 sizes."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
@@ -166,7 +166,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatch(self):
"""Tests that batch and unbatch work together."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -190,7 +190,8 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchV1Decorated(self):
"""Tests that the batch_function_v1 decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
+
@batch_ops.batch_function_v1(1, 10, 100000)
def computation(in_t):
return in_t + 1
@@ -211,7 +212,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# TODO(apassos): Removing this line causes test flakiness! Ideally should
# be investigated.
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
@@ -236,7 +237,7 @@ class BatchOpsTest(test.TestCase):
def testBatchDecoratedWithCapturedInput(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
@@ -260,7 +261,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOp(self):
"""Tests that the batch_function op works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
@@ -289,7 +290,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOpWithCapturedInput(self):
"""Tests that batch_function op works with captured input."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@@ -323,7 +324,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOpWithInputError(self):
"""Tests that batch_function op works with error in the inputs."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32, dtypes.int32)
@@ -346,7 +347,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchDecoratedWithReshape(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
@@ -368,7 +369,7 @@ class BatchOpsTest(test.TestCase):
def testUnbatchTimeout(self):
"""Tests that the unbatch timeout works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -410,7 +411,7 @@ class BatchOpsTest(test.TestCase):
def testUnbatchGrad(self):
"""Tests that batch and unbatch are differentiable."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
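
An aside on the API these tests exercise: `batch_ops.batch_function(num_batch_threads, max_batch_size, batch_timeout_micros)` decorates a function so that inputs from concurrent callers are concatenated along dimension 0, run through the wrapped computation as one batch, and split back per caller. A hedged sketch reusing the decorator arguments seen in the hunks above (the surrounding placeholder plumbing is illustrative):

from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops

# 1 batching thread, batches of at most 10 elements, 100000us timeout,
# mirroring @batch_ops.batch_function(1, 10, 100000) in the tests above.
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
  return in_t + 1

inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)  # feeding [1] yields [2] once the batch fires
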
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 42b3b9f026..3e631b5909 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -173,7 +173,7 @@ class JITTest(test.TestCase):
class CompilationEnabledInGradientTest(test.TestCase):
def testCompilationInGradient(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([[3.]])
y_nc = math_ops.matmul(x, x, name="not_compiled")
with jit.experimental_jit_scope():
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
index 8dad80aa64..c32ea9ade7 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
@@ -93,12 +93,12 @@ class SoftsignBijectorTest(test.TestCase):
bijector.inverse_log_det_jacobian(y, event_ndims=1)))
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softsign(validate_args=True)
assert_scalar_congruency(bijector, lower_x=-20., upper_x=20.)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softsign(validate_args=True)
x = np.linspace(-20., 20., 100).astype(np.float32)
y = np.linspace(-0.99, 0.99, 100).astype(np.float32)
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index 58f348034f..64d6706199 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -399,7 +399,7 @@ class StarGANModelTest(test.TestCase):
target_tensor = train._generate_stargan_random_domain_target(
batch_size, domain_numbers)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
targets = sess.run(target_tensor)
self.assertTupleEqual((batch_size, domain_numbers), targets.shape)
for target in targets:
@@ -676,7 +676,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase):
self.assertIsInstance(model_loss, namedtuples.GANLoss)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index fed8a771cc..27aed091c2 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -233,7 +233,7 @@ class GridRNNCellTest(test.TestCase):
([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
def testGrid2LSTMCellWithRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -261,7 +261,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid2BasicRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
@@ -292,7 +292,7 @@ class GridRNNCellTest(test.TestCase):
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellTied(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
@@ -323,7 +323,7 @@ class GridRNNCellTest(test.TestCase):
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellWithRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -348,7 +348,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid1LSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
x = array_ops.zeros([1, 3])
@@ -410,7 +410,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid3LSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -455,7 +455,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGridRNNEdgeCasesLikeRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([3, 2])
@@ -481,7 +481,7 @@ class GridRNNCellTest(test.TestCase):
self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
def testGridRNNEdgeCasesNoOutput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -541,7 +541,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -581,7 +581,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -623,7 +623,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -663,7 +663,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape(), (3, num_units))
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -700,7 +700,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((3, input_size))
@@ -715,7 +715,7 @@ class GridRNNCellTest(test.TestCase):
def testGrid2LSTMCellLegacy(self):
"""Test for legacy case (when state_is_tuple=False)."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
index 9ed017592a..f44edaa14c 100644
--- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
+++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class InputPipelineOpsTest(test.TestCase):
def testObtainNext(self):
- with self.test_session():
+ with self.cached_session():
var = state_ops.variable_op([], dtypes.int64)
state_ops.assign(var, -1).op.run()
c = constant_op.constant(["a", "b"])
@@ -45,7 +45,7 @@ class InputPipelineOpsTest(test.TestCase):
def testSeekNext(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list)
session.run([variables.global_variables_initializer()])
self.assertEqual(b"a", session.run(elem))
@@ -65,7 +65,7 @@ class InputPipelineOpsTest(test.TestCase):
def testSeekNextLimitEpochs(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list, num_epochs=1)
session.run([
variables.local_variables_initializer(),
@@ -75,7 +75,7 @@ class InputPipelineOpsTest(test.TestCase):
def testSeekNextLimitEpochsThree(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list, num_epochs=3)
session.run([
variables.local_variables_initializer(),
diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py
index 72507539f8..4d5cc24ce0 100644
--- a/tensorflow/contrib/kernel_methods/python/losses_test.py
+++ b/tensorflow/contrib/kernel_methods/python/losses_test.py
@@ -32,7 +32,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLogitsShape(self):
"""An error is raised when logits have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2,))
labels = constant_op.constant([0, 1])
with self.assertRaises(ValueError):
@@ -40,7 +40,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLabelsShape(self):
"""An error is raised when labels have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(1, 1, 2))
with self.assertRaises(ValueError):
@@ -48,7 +48,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidWeightsShape(self):
"""An error is raised when weights have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(2,))
weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1))
@@ -57,7 +57,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLabelsDtype(self):
"""An error is raised when labels have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], dtype=dtypes.float32)
with self.assertRaises(ValueError):
@@ -65,7 +65,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNoneWeightRaisesValueError(self):
"""An error is raised when weights are None."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0])
with self.assertRaises(ValueError):
@@ -73,7 +73,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInconsistentLabelsAndWeightsShapesSameRank(self):
"""Error raised when weights and labels have same ranks, different sizes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1))
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
weights = constant_op.constant([1.1, 2.0], shape=(2, 1))
@@ -82,7 +82,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInconsistentLabelsAndWeightsShapesDifferentRank(self):
"""Error raised when weights and labels have different ranks and sizes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(2, 1))
weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,))
@@ -91,7 +91,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testOutOfRangeLabels(self):
"""An error is raised when labels are not in [0, num_classes)."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
[0.5, 1.8, -1.0]])
labels = constant_op.constant([1, 0, 4])
@@ -101,7 +101,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testZeroLossInt32Labels(self):
"""Loss is 0 if true class logits sufficiently higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
[0.5, 1.8, -1.0]])
labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
@@ -110,7 +110,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testZeroLossInt64Labels(self):
"""Loss is 0 if true class logits sufficiently higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0],
[-0.5, 0.8, -1.0]])
labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64)
@@ -130,7 +130,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
]
for batch_size, num_classes in logits_shapes:
- with self.test_session():
+ with self.cached_session():
logits = array_ops.placeholder(
dtypes.float32, shape=(batch_size, num_classes))
labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,))
@@ -140,7 +140,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testCorrectPredictionsSomeClassesInsideMargin(self):
"""Loss is > 0 even if true class logits are higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0],
[1.5, 1.8, -1.0]])
labels = constant_op.constant([0, 2, 1])
@@ -150,7 +150,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictions(self):
"""Loss is >0 when an incorrect class has higher logits than true class."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0],
[0.5, -1.8, 2.0]])
labels = constant_op.constant([1, 0, 2])
@@ -162,7 +162,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictionsColumnLabels(self):
"""Same as above but labels is a rank-2 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -174,7 +174,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictionsZeroWeights(self):
"""Loss is 0 when all weights are missing even if predictions are wrong."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -185,7 +185,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWithPythonScalarWeights(self):
"""Weighted loss is correctly computed when weights is a python scalar."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -195,7 +195,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWithScalarTensorWeights(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -205,7 +205,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWith1DTensorWeightsColumnLabels(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -216,7 +216,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0], [1.6, 1.8, -4.0]])
labels = constant_op.constant([1, 0, 2, 1])
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
index 2ff4d41d75..bad0a596a7 100644
--- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
@@ -58,7 +58,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
def testInvalidInputShape(self):
x = constant_op.constant([[2.0, 1.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 10)
with self.assertRaisesWithPredicateMatch(
dense_kernel_mapper.InvalidShapeError,
@@ -70,7 +70,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
x2 = constant_op.constant([[1.0, -1.0, 2.0], [-1.0, 10.0, 1.0],
[4.0, -2.0, -1.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 10, 1.0)
mapped_x1 = rffm.map(x1)
mapped_x2 = rffm.map(x2)
@@ -80,7 +80,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
def testSameOmegaReused(self):
x = constant_op.constant([[2.0, 1.0, 0.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 100)
mapped_x = rffm.map(x)
mapped_x_copy = rffm.map(x)
@@ -93,7 +93,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
y = constant_op.constant([[1.0, -1.0, 2.0]])
stddev = 3.0
- with self.test_session():
+ with self.cached_session():
# The mapped dimension is fairly small, so the kernel approximation is
# very rough.
rffm1 = RandomFourierFeatureMapper(3, 100, stddev)
@@ -113,7 +113,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
y = constant_op.constant([[1.0, -1.0, 2.0]])
stddev = 3.0
- with self.test_session():
+ with self.cached_session():
# The mapped dimension is fairly small, so the kernel approximation is
# very rough.
rffm = RandomFourierFeatureMapper(3, 100, stddev, seed=0)
@@ -139,7 +139,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
normalized_points = [nn.l2_normalize(point, dim=1) for point in points]
total_absolute_error = 0.0
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(input_dim, mapped_dim, stddev, seed=0)
# Cache mappings so that they are not computed multiple times.
cached_mappings = dict((point, rffm.map(point))
diff --git a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
index e4db5f2e3c..e6a0b30567 100644
--- a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
+++ b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
@@ -38,7 +38,7 @@ class StatSummarizerTest(test.TestCase):
graph_def = graph.as_graph_def()
ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
for _ in range(20):
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
index e429d12e96..1c4e18dbda 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
@@ -32,7 +32,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
indices = [[1], [10]]
updates = [100., 200.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual(
@@ -45,7 +45,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
indices = [[0, 0, 1], [1, 1, 2]]
updates = [100., 200.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual([[[1., 102., 3.], [4., 5., 6.]],
@@ -57,7 +57,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
indices = []
updates = []
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual(init_val, input_data.eval())
@@ -67,7 +67,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
input_data = variables.Variable(init_val)
indices = [[0, 0, 1], [1, 1, 2]]
updates = [100.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
with self.assertRaisesOpError(
'Number of updates should be same as number of indices.'):
@@ -80,7 +80,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
indices = [[0, 0], [1, 1]]
updates = [[100., 200., 300.], [400., 500., 600.]]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual([[[101., 202., 303.], [4., 5., 6.]],
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index 1c9c81827e..e0f0c0d4ff 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -149,7 +149,7 @@ class TensorForestTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
resources.initialize_resources(resources.shared_resources()).run()
self.assertEquals(probs.eval().shape, (4, 2))
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
index 84e36146d5..832d34d60d 100644
--- a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
@@ -63,7 +63,7 @@ class SkipGramOpsTest(test.TestCase):
(b"jumps", b"brown"),
(b"jumps", b"fox"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -94,7 +94,7 @@ class SkipGramOpsTest(test.TestCase):
(b"jumps", b"fox"),
(b"jumps", b"jumps"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -105,7 +105,7 @@ class SkipGramOpsTest(test.TestCase):
# If emit_self_as_target is False (default), output will be empty.
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, tokens.eval().size)
self.assertEqual(0, labels.eval().size)
@@ -117,7 +117,7 @@ class SkipGramOpsTest(test.TestCase):
(b"quick", b"quick"),
(b"brown", b"brown"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -134,7 +134,7 @@ class SkipGramOpsTest(test.TestCase):
(b"brown", b"the"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -150,7 +150,7 @@ class SkipGramOpsTest(test.TestCase):
(b"quick", b"brown"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -165,7 +165,7 @@ class SkipGramOpsTest(test.TestCase):
(b"quick", b"brown"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -196,7 +196,7 @@ class SkipGramOpsTest(test.TestCase):
(b"over", b"fox"),
(b"over", b"jumps"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens, tokens_eval)
self.assertAllEqual(expected_labels, labels_eval)
@@ -222,7 +222,7 @@ class SkipGramOpsTest(test.TestCase):
tokens_2, labels_2 = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=5)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
[tokens_1, labels_1, tokens_2, labels_2])
@@ -244,7 +244,7 @@ class SkipGramOpsTest(test.TestCase):
(b"brown", b"fox"),
(b"fox", b"brown"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
@@ -269,7 +269,7 @@ class SkipGramOpsTest(test.TestCase):
(2, 3),
(3, 2),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -286,7 +286,7 @@ class SkipGramOpsTest(test.TestCase):
for min_skips, max_skips in invalid_skips:
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=min_skips, max_skips=max_skips)
- with self.test_session() as sess, self.assertRaises(
+ with self.cached_session() as sess, self.assertRaises(
errors.InvalidArgumentError):
sess.run([tokens, labels])
@@ -338,7 +338,7 @@ class SkipGramOpsTest(test.TestCase):
vocab_freq_table = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), -1)
- with self.test_session():
+ with self.cached_session():
vocab_freq_table.init.run()
# No vocab_freq_table specified - output should be the same as input.
@@ -395,7 +395,7 @@ class SkipGramOpsTest(test.TestCase):
vocab_freq_table = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), -1)
- with self.test_session():
+ with self.cached_session():
vocab_freq_table.init.run()
output = skip_gram_ops._filter_input(
input_tensor=input_tensor,
@@ -464,7 +464,7 @@ class SkipGramOpsTest(test.TestCase):
(b"life", b"and"),
(b"and", b"life"),
])
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -510,7 +510,7 @@ class SkipGramOpsTest(test.TestCase):
(b"to", b"life"),
(b"life", b"to"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
lookup_ops.tables_initializer().run()
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens, tokens_eval)