aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kernel_methods
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-21 00:02:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 00:07:39 -0700
commit2952f5134905af795ba90ae1eb97e39091ba9843 (patch)
treef73bc5cd0342d9449114bd933863c2aa55610aa2 /tensorflow/contrib/kernel_methods
parentcf047f7755f3400ee128db2571042091fe9f8314 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 213944355
Diffstat (limited to 'tensorflow/contrib/kernel_methods')
-rw-r--r--tensorflow/contrib/kernel_methods/python/losses_test.py38
-rw-r--r--tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py12
2 files changed, 25 insertions, 25 deletions
diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py
index 72507539f8..4d5cc24ce0 100644
--- a/tensorflow/contrib/kernel_methods/python/losses_test.py
+++ b/tensorflow/contrib/kernel_methods/python/losses_test.py
@@ -32,7 +32,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLogitsShape(self):
"""An error is raised when logits have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2,))
labels = constant_op.constant([0, 1])
with self.assertRaises(ValueError):
@@ -40,7 +40,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLabelsShape(self):
"""An error is raised when labels have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(1, 1, 2))
with self.assertRaises(ValueError):
@@ -48,7 +48,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidWeightsShape(self):
"""An error is raised when weights have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(2,))
weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1))
@@ -57,7 +57,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLabelsDtype(self):
"""An error is raised when labels have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], dtype=dtypes.float32)
with self.assertRaises(ValueError):
@@ -65,7 +65,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNoneWeightRaisesValueError(self):
"""An error is raised when weights are None."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0])
with self.assertRaises(ValueError):
@@ -73,7 +73,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInconsistentLabelsAndWeightsShapesSameRank(self):
"""Error raised when weights and labels have same ranks, different sizes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1))
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
weights = constant_op.constant([1.1, 2.0], shape=(2, 1))
@@ -82,7 +82,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInconsistentLabelsAndWeightsShapesDifferentRank(self):
"""Error raised when weights and labels have different ranks and sizes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(2, 1))
weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,))
@@ -91,7 +91,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testOutOfRangeLabels(self):
"""An error is raised when labels are not in [0, num_classes)."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
[0.5, 1.8, -1.0]])
labels = constant_op.constant([1, 0, 4])
@@ -101,7 +101,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testZeroLossInt32Labels(self):
"""Loss is 0 if true class logits sufficiently higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
[0.5, 1.8, -1.0]])
labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
@@ -110,7 +110,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testZeroLossInt64Labels(self):
"""Loss is 0 if true class logits sufficiently higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0],
[-0.5, 0.8, -1.0]])
labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64)
@@ -130,7 +130,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
]
for batch_size, num_classes in logits_shapes:
- with self.test_session():
+ with self.cached_session():
logits = array_ops.placeholder(
dtypes.float32, shape=(batch_size, num_classes))
labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,))
@@ -140,7 +140,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testCorrectPredictionsSomeClassesInsideMargin(self):
"""Loss is > 0 even if true class logits are higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0],
[1.5, 1.8, -1.0]])
labels = constant_op.constant([0, 2, 1])
@@ -150,7 +150,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictions(self):
"""Loss is >0 when an incorrect class has higher logits than true class."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0],
[0.5, -1.8, 2.0]])
labels = constant_op.constant([1, 0, 2])
@@ -162,7 +162,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictionsColumnLabels(self):
"""Same as above but labels is a rank-2 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -174,7 +174,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictionsZeroWeights(self):
"""Loss is 0 when all weights are missing even if predictions are wrong."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -185,7 +185,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWithPythonScalarWeights(self):
"""Weighted loss is correctly computed when weights is a python scalar."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -195,7 +195,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWithScalarTensorWeights(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -205,7 +205,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWith1DTensorWeightsColumnLabels(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -216,7 +216,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0], [1.6, 1.8, -4.0]])
labels = constant_op.constant([1, 0, 2, 1])
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
index 2ff4d41d75..bad0a596a7 100644
--- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
@@ -58,7 +58,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
def testInvalidInputShape(self):
x = constant_op.constant([[2.0, 1.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 10)
with self.assertRaisesWithPredicateMatch(
dense_kernel_mapper.InvalidShapeError,
@@ -70,7 +70,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
x2 = constant_op.constant([[1.0, -1.0, 2.0], [-1.0, 10.0, 1.0],
[4.0, -2.0, -1.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 10, 1.0)
mapped_x1 = rffm.map(x1)
mapped_x2 = rffm.map(x2)
@@ -80,7 +80,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
def testSameOmegaReused(self):
x = constant_op.constant([[2.0, 1.0, 0.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 100)
mapped_x = rffm.map(x)
mapped_x_copy = rffm.map(x)
@@ -93,7 +93,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
y = constant_op.constant([[1.0, -1.0, 2.0]])
stddev = 3.0
- with self.test_session():
+ with self.cached_session():
# The mapped dimension is fairly small, so the kernel approximation is
# very rough.
rffm1 = RandomFourierFeatureMapper(3, 100, stddev)
@@ -113,7 +113,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
y = constant_op.constant([[1.0, -1.0, 2.0]])
stddev = 3.0
- with self.test_session():
+ with self.cached_session():
# The mapped dimension is fairly small, so the kernel approximation is
# very rough.
rffm = RandomFourierFeatureMapper(3, 100, stddev, seed=0)
@@ -139,7 +139,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
normalized_points = [nn.l2_normalize(point, dim=1) for point in points]
total_absolute_error = 0.0
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(input_dim, mapped_dim, stddev, seed=0)
# Cache mappings so that they are not computed multiple times.
cached_mappings = dict((point, rffm.map(point))