aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/losses
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 14:36:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 14:49:41 -0700
commit890e16594a005fe703a5556530b0dc3e6527fa47 (patch)
tree99140efb13f392ae13a58f08c08754c61bf66f13 /tensorflow/contrib/losses
parent132babebf5b1026cb33cad7c4eb7e03810c2acdf (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 212336321
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops_test.py214
1 files changed, 107 insertions, 107 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 2a442a8fc8..c0aec09778 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -43,68 +43,68 @@ class AbsoluteDifferenceLossTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.absolute_difference(
self._predictions, self._predictions, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.absolute_difference(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.absolute_difference(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = loss_ops.absolute_difference(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 0.0], shape=[2,])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 0.0], shape=[2, 1])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(16.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(6.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -117,12 +117,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
labels = constant_op.constant([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.softmax_cross_entropy(logits, labels, weights=None)
def testAllCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -141,7 +141,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -154,7 +154,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -166,7 +166,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels,
constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -179,7 +179,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -191,7 +191,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([0, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -203,12 +203,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([1.2, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -223,7 +223,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
loss_ops.softmax_cross_entropy(logits, labels, weights=weights).eval()
def testSoftmaxLabelSmoothing(self):
- with self.test_session():
+ with self.cached_session():
# Softmax Cross Entropy Loss is:
# -\sum_i p_i \log q_i
# where for a softmax activation
@@ -253,7 +253,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
weights = [2.3, 2.4, 2.5]
weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -268,7 +268,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
weights_placeholder = array_ops.placeholder(
dtypes.float32, shape=[None, None])
loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -280,12 +280,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0], [1], [2]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.sparse_softmax_cross_entropy(logits, labels, weights=None)
def testAllCorrectInt32Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -295,7 +295,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllCorrectInt64Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -305,7 +305,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllCorrectNonColumnLabels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -320,7 +320,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32)
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -331,7 +331,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64)
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -342,7 +342,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([2, 0, 1])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -353,7 +353,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -363,7 +363,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -374,7 +374,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -384,7 +384,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([[1.2], [3.4], [5.6]])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -394,7 +394,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([0, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -404,12 +404,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -422,7 +422,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightSizeRaisesException(self):
"""The weight tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -435,7 +435,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelSizeRaisesException(self):
"""The label tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -448,7 +448,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightShapeRaisesException(self):
"""The weight tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -462,7 +462,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelShapeRaisesException(self):
"""The label tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -484,7 +484,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
dtypes.float32, shape=[None])
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -498,7 +498,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
dtypes.float32, shape=[None, None])
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -506,7 +506,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
class SigmoidCrossEntropyLossTest(test.TestCase):
def testAllCorrectSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -522,7 +522,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 1)),
@@ -537,7 +537,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 2)),
@@ -546,7 +546,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(0.313, loss, 3)
def testAllWrongSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -558,7 +558,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -582,11 +582,11 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testSigmoidLabelSmoothingCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0]])
labels = constant_op.constant([[1, 0, 1]])
# Sigmoid cross entropy loss is:
@@ -608,7 +608,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), expected_value, 3)
def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
- with self.test_session():
+ with self.cached_session():
label_smoothing = 0.1
sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]])
sigmoid_labels = constant_op.constant([[1, 0, 1]])
@@ -641,33 +641,33 @@ class LogLossTest(test.TestCase):
self._labels = constant_op.constant(labels)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.log_loss(self._labels, self._labels, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.log_loss(self._labels, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testAllCorrectNoLossWeightWithPlaceholder(self):
tf_predictions = array_ops.placeholder(
dtypes.float32, shape=self._np_labels.shape)
loss = loss_ops.log_loss(tf_predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(
0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
def testNonZeroLoss(self):
loss = loss_ops.log_loss(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -675,7 +675,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -685,7 +685,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(tf_predictions, self._labels,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -695,7 +695,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(tf_predictions, self._labels,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -706,7 +706,7 @@ class LogLossTest(test.TestCase):
self._expected_losses,
np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
@@ -715,7 +715,7 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
@@ -724,12 +724,12 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.log_loss(self._predictions, self._labels, weights)
@@ -742,7 +742,7 @@ class LogLossTest(test.TestCase):
self._labels,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3)
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
@@ -756,7 +756,7 @@ class LogLossTest(test.TestCase):
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)
@@ -769,7 +769,7 @@ class LogLossTest(test.TestCase):
self._labels,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
@@ -780,35 +780,35 @@ class LogLossTest(test.TestCase):
tf_weights = constant_op.constant(weights, shape=(2, 3))
loss = loss_ops.log_loss(tf_predictions, self._labels, tf_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)
def testLossWithSampleSpecificWeightsAllZero(self):
tf_weights = array_ops.zeros(shape=(2, 3))
loss = loss_ops.log_loss(self._predictions, self._labels, tf_weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
class HingeLossTest(test.TestCase):
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-1.0], [2.1]])
labels = constant_op.constant([0.0, 1.0])
with self.assertRaises(ValueError):
_ = loss_ops.hinge_loss(logits, labels).eval()
def testAllOutsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([1.2, -1.4, -1.0, 2.1])
labels = constant_op.constant([1.0, 0.0, 0.0, 1.0])
loss = loss_ops.hinge_loss(logits, labels)
self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3)
def testSomeInsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]])
labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]])
loss = loss_ops.hinge_loss(logits, labels)
@@ -817,7 +817,7 @@ class HingeLossTest(test.TestCase):
self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3)
def testSomeMisclassified(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]])
loss = loss_ops.hinge_loss(logits, labels)
@@ -834,62 +834,62 @@ class MeanSquaredErrorTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.mean_squared_error(
self._predictions, self._predictions, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.mean_squared_error(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.mean_squared_error(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = loss_ops.mean_squared_error(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=[2,])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=[2, 1])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(18.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -914,7 +914,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
self._expected_losses = np.divide(total, 9.0)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._labels),
@@ -925,14 +925,14 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
loss = loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._labels),
labels=constant_op.constant(self._labels))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3)
def testGradientWithZeroWeight(self):
@@ -954,7 +954,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for grad, _ in gradients_to_variables:
np_grad = sess.run(grad)
@@ -966,7 +966,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
loss.eval(), 3)
@@ -976,7 +976,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
loss.eval(), 3)
@@ -986,7 +986,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self):
@@ -998,7 +998,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=tf_predictions,
labels=tf_labels,
weights=constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
tf_predictions: self._predictions,
@@ -1015,7 +1015,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3)
def testZeroLossWithOneDimBatchZeroWeights(self):
@@ -1025,7 +1025,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self):
@@ -1041,7 +1041,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
tf_predictions: self._predictions,
@@ -1056,7 +1056,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testLossIsAssociativeAcrossBatchElements(self):
@@ -1087,7 +1087,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=array_ops.concat([predictions0, predictions1], 0),
labels=array_ops.concat([labels0, labels1], 0))
- with self.test_session() as session:
+ with self.cached_session() as session:
loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1])
self.assertTrue(loss0 > 0)
@@ -1115,7 +1115,7 @@ class CosineDistanceLossTest(test.TestCase):
[0, 1, 0]]).reshape((3, 2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.cosine_distance(
predictions=constant_op.constant(self._labels),
@@ -1128,7 +1128,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._labels),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 5)
def testPartiallyCorrectWithIntegerValues(self):
@@ -1136,7 +1136,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1, loss.eval(), 5)
def testPartiallyCorrectFloatingPointValues(self):
@@ -1154,7 +1154,7 @@ class CosineDistanceLossTest(test.TestCase):
labels, shape=(3, 1, 3), dtype=dtypes.float32)
loss = loss_ops.cosine_distance(tf_preds, tf_labels, dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1.0, loss.eval(), 5)
def testSampleSpecificWeights(self):
@@ -1163,7 +1163,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=constant_op.constant([1, 0, 0]))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, loss.eval())
def testMeasurementSpecificWeights(self):
@@ -1173,12 +1173,12 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(3.0 / 4.0, loss.eval())
def testValueErrorThrownWithShapelessPlaceholder(self):
tf_predictions = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.cosine_distance(
predictions=tf_predictions,
@@ -1196,7 +1196,7 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
self.assertEqual(3.0 / 4.0, loss)
@@ -1206,7 +1206,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3,)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
@@ -1215,7 +1215,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3, 2)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
@@ -1228,7 +1228,7 @@ class ComputeWeightedLossTest(test.TestCase):
self.assertFalse(loss_ops.get_losses())
loss = loss_ops.compute_weighted_loss(losses)
self.assertTrue(loss_ops.get_losses())
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3)
self.assertAllClose(loss.eval(), 3.5 / 4.0, atol=1e-3)
@@ -1243,7 +1243,7 @@ class AddLossTest(test.TestCase):
loss_ops.add_loss(math_ops.reduce_mean(losses))
self.assertTrue(loss_ops.get_losses())
total_loss = loss_ops.get_total_loss()
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
self.assertAllClose(total_loss.eval(), 3.5 / 4.0, atol=1e-3)
@@ -1254,7 +1254,7 @@ class AddLossTest(test.TestCase):
self.assertFalse(loss_ops.get_losses())
loss_ops.add_loss(math_ops.reduce_mean(losses), loss_collection=None)
self.assertFalse(loss_ops.get_losses())
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
def testNoCollectLosses(self):