aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-09-11 20:19:50 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-09-11 20:19:50 +0800
commitb2896c3cc3a0656b838f58975338d7dd309e3e62 (patch)
tree14f25741ab43c15e945e6044833c0ff44f11d83f /tensorflow/contrib/metrics
parent38f811077dd52820eaa3d5c684f41142de01c7eb (diff)
parente18f84a394bcbde62b344a3b32e8d8fd248fea58 (diff)
Merge remote-tracking branch 'upstream/master' into ENH/div_no_nan_treate_negative_as_zero
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py51
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py2
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py508
3 files changed, 324 insertions, 237 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index bfef0816aa..1ddd7e521b 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -2514,7 +2514,8 @@ def sparse_recall_at_top_k(labels,
name=name_scope)
-def _compute_recall_at_precision(tp, fp, fn, precision, name):
+def _compute_recall_at_precision(tp, fp, fn, precision, name,
+ strict_mode=False):
"""Helper function to compute recall at a given `precision`.
Args:
@@ -2523,17 +2524,42 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name):
fn: The number of false negatives.
precision: The precision for which the recall will be calculated.
name: An optional variable_scope name.
+ strict_mode: If true and there exists a threshold where the precision is
+ no smaller than the target precision, return the corresponding recall at
+ the threshold. Otherwise, return 0. If false, find the threshold where the
+ precision is closest to the target precision and return the recall at the
+ threshold.
Returns:
The recall at a given `precision`.
"""
precisions = math_ops.div(tp, tp + fp + _EPSILON)
- tf_index = math_ops.argmin(
- math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)
+ if not strict_mode:
+ tf_index = math_ops.argmin(
+ math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)
+ # Now, we have the implicit threshold, so compute the recall:
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
+ name)
+ else:
+ # We aim to find the threshold where the precision is minimum but no smaller
+ # than the target precision.
+ # The rationale:
+ # 1. Compute the difference between precisions (by different thresholds) and
+ # the target precision.
+ # 2. Take the reciprocal of the values by the above step. The intention is
+ # to make the positive values rank before negative values and also the
+ # smaller positives rank before larger positives.
+ tf_index = math_ops.argmax(
+ math_ops.div(1.0, precisions - precision + _EPSILON),
+ 0,
+ output_type=dtypes.int32)
+
+ def _return_good_recall():
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
+ name)
- # Now, we have the implicit threshold, so compute the recall:
- return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
- name)
+ return control_flow_ops.cond(precisions[tf_index] >= precision,
+ _return_good_recall, lambda: .0)
def recall_at_precision(labels,
@@ -2543,7 +2569,8 @@ def recall_at_precision(labels,
num_thresholds=200,
metrics_collections=None,
updates_collections=None,
- name=None):
+ name=None,
+ strict_mode=False):
"""Computes `recall` at `precision`.
The `recall_at_precision` function creates four local variables,
@@ -2575,6 +2602,11 @@ def recall_at_precision(labels,
updates_collections: An optional list of collections that `update_op` should
be added to.
name: An optional variable_scope name.
+ strict_mode: If true and there exists a threshold where the precision is
+ above the target precision, return the corresponding recall at the
+ threshold. Otherwise, return 0. If false, find the threshold where the
+ precision is closest to the target precision and return the recall at the
+ threshold.
Returns:
recall: A scalar `Tensor` representing the recall at the given
@@ -2603,10 +2635,11 @@ def recall_at_precision(labels,
predictions, labels, thresholds, weights)
recall = _compute_recall_at_precision(values['tp'], values['fp'],
- values['fn'], precision, 'value')
+ values['fn'], precision, 'value',
+ strict_mode)
update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'],
update_ops['fn'], precision,
- 'update_op')
+ 'update_op', strict_mode)
if metrics_collections:
ops.add_to_collections(metrics_collections, recall)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
index 7acfc383eb..5777e64c29 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
@@ -47,7 +47,7 @@ class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase):
# code used float32 for accumulation.
num_updates = 71
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_updates):
sess.run(update_op)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 1c2c17960a..955b83b44d 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -178,7 +178,7 @@ class StreamingMeanTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -195,7 +195,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual(1.65, sess.run(mean), 5)
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -216,7 +216,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual(1.65, sess.run(mean), 5)
def test1dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -243,7 +243,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5)
def test1dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -265,7 +265,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5)
def test2dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -292,7 +292,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 4.2 + 0) / 4.0, mean.eval(), 5)
def test2dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -337,7 +337,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -354,7 +354,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean))
def testMultiDimensional(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
_enqueue_vector(
@@ -375,7 +375,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean))
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -396,7 +396,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -423,7 +423,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
def testWeighted2d_1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -450,7 +450,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5)
def testWeighted2d_2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -526,7 +526,7 @@ class StreamingAccuracyTest(test.TestCase):
(10, 3), maxval=3, dtype=dtypes_lib.int64, seed=2)
accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -539,7 +539,7 @@ class StreamingAccuracyTest(test.TestCase):
self.assertEqual(initial_accuracy, accuracy.eval())
def testMultipleUpdates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -569,7 +569,7 @@ class StreamingAccuracyTest(test.TestCase):
def testEffectivelyEquivalentSizes(self):
predictions = array_ops.ones((40, 1))
labels = array_ops.ones((40,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
sess.run(variables.local_variables_initializer())
@@ -583,7 +583,7 @@ class StreamingAccuracyTest(test.TestCase):
weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]),
1) # shape 3, 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels,
weights)
@@ -604,7 +604,7 @@ class StreamingAccuracyTest(test.TestCase):
dtype=dtypes_lib.int32, name='weights')
feed_dict = {weights_placeholder: weights}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels,
weights_placeholder)
@@ -616,7 +616,7 @@ class StreamingAccuracyTest(test.TestCase):
self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)
def testMultipleUpdatesWithWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -681,7 +681,7 @@ class StreamingTruePositivesTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tp.eval())
self.assertEqual(1, tp_update_op.eval())
@@ -698,7 +698,7 @@ class StreamingTruePositivesTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives(
predictions, labels, weights=37.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tp.eval())
self.assertEqual(37.0, tp_update_op.eval())
@@ -732,7 +732,7 @@ class StreamingFalseNegativesTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fn.eval())
self.assertEqual(2, fn_update_op.eval())
@@ -749,7 +749,7 @@ class StreamingFalseNegativesTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives(
predictions, labels, weights=((3.0,), (5.0,), (7.0,)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fn.eval())
self.assertEqual(8.0, fn_update_op.eval())
@@ -783,7 +783,7 @@ class StreamingFalsePositivesTest(test.TestCase):
fp, fp_update_op = metrics.streaming_false_positives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fp.eval())
self.assertEqual(4, fp_update_op.eval())
@@ -803,7 +803,7 @@ class StreamingFalsePositivesTest(test.TestCase):
weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0,
29.0, 31.0)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fp.eval())
self.assertEqual(42.0, fp_update_op.eval())
@@ -837,7 +837,7 @@ class StreamingTrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tn.eval())
self.assertEqual(5, tn_update_op.eval())
@@ -854,7 +854,7 @@ class StreamingTrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives(
predictions, labels, weights=((0.0, 2.0, 3.0, 5.0),))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tn.eval())
self.assertEqual(15.0, tn_update_op.eval())
@@ -879,7 +879,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tp.eval())
self.assertAllEqual((3, 1, 0), tp_update_op.eval())
@@ -892,7 +892,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, weights=37.0, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tp.eval())
self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval())
@@ -921,7 +921,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fn.eval())
self.assertAllEqual((0, 2, 3), fn_update_op.eval())
@@ -937,7 +937,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
weights=((3.0,), (5.0,), (7.0,)),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fn.eval())
self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval())
@@ -962,7 +962,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
fp, fp_update_op = metrics.streaming_false_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fp.eval())
self.assertAllEqual((7, 4, 2), fp_update_op.eval())
@@ -979,7 +979,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
29.0, 31.0)),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fp.eval())
self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval())
@@ -1004,7 +1004,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tn.eval())
self.assertAllEqual((2, 5, 7), tn_update_op.eval())
@@ -1020,7 +1020,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
weights=((0.0, 2.0, 3.0, 5.0),),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tn.eval())
self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval())
@@ -1062,7 +1062,7 @@ class StreamingPrecisionTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1081,7 +1081,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant(inputs)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op))
self.assertAlmostEqual(1, precision.eval())
@@ -1091,7 +1091,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, precision.eval())
@@ -1102,7 +1102,7 @@ class StreamingPrecisionTest(test.TestCase):
precision, update_op = metrics.streaming_precision(
predictions, labels, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1120,7 +1120,7 @@ class StreamingPrecisionTest(test.TestCase):
precision, update_op = metrics.streaming_precision(
predictions, labels, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1138,7 +1138,7 @@ class StreamingPrecisionTest(test.TestCase):
labels,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -1158,7 +1158,7 @@ class StreamingPrecisionTest(test.TestCase):
labels,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -1175,7 +1175,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant(1 - inputs)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0, precision.eval())
@@ -1185,7 +1185,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant([0, 0, 0, 0])
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0.0, precision.eval())
@@ -1227,7 +1227,7 @@ class StreamingRecallTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1246,7 +1246,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant(np_inputs)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, recall.eval())
@@ -1256,7 +1256,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, recall.eval())
@@ -1268,7 +1268,7 @@ class StreamingRecallTest(test.TestCase):
recall, update_op = metrics.streaming_recall(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 2.0 + 5.0
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1283,7 +1283,7 @@ class StreamingRecallTest(test.TestCase):
recall, update_op = metrics.streaming_recall(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 3.0 + 1.0
weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
@@ -1298,7 +1298,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1308,7 +1308,7 @@ class StreamingRecallTest(test.TestCase):
labels = array_ops.zeros((1, 4))
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1350,7 +1350,7 @@ class StreamingFPRTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1369,7 +1369,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant(np_inputs)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fpr.eval())
@@ -1379,7 +1379,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, fpr.eval())
@@ -1391,7 +1391,7 @@ class StreamingFPRTest(test.TestCase):
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fp = 2.0 + 5.0
weighted_f = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1406,7 +1406,7 @@ class StreamingFPRTest(test.TestCase):
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fp = 1.0 + 3.0
weighted_f = (1.0 + 4.0) + (2.0 + 3.0)
@@ -1421,7 +1421,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, fpr.eval())
@@ -1431,7 +1431,7 @@ class StreamingFPRTest(test.TestCase):
labels = array_ops.ones((1, 4))
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fpr.eval())
@@ -1473,7 +1473,7 @@ class StreamingFNRTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1492,7 +1492,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant(np_inputs)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fnr.eval())
@@ -1502,7 +1502,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, fnr.eval())
@@ -1514,7 +1514,7 @@ class StreamingFNRTest(test.TestCase):
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fn = 2.0 + 5.0
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1529,7 +1529,7 @@ class StreamingFNRTest(test.TestCase):
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fn = 2.0 + 4.0
weighted_t = (2.0 + 3.0) + (1.0 + 4.0)
@@ -1544,7 +1544,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, fnr.eval())
@@ -1554,7 +1554,7 @@ class StreamingFNRTest(test.TestCase):
labels = array_ops.zeros((1, 4))
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fnr.eval())
@@ -1599,7 +1599,7 @@ class StreamingCurvePointsTest(test.TestCase):
points, update_op = metric_ops.streaming_curve_points(
labels, predictions=predictions, curve=curve)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
@@ -1615,7 +1615,7 @@ class StreamingCurvePointsTest(test.TestCase):
self._testValueTensorIsIdempotent(curve='PR')
def _testCase(self, labels, predictions, curve, expected_points):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
@@ -1717,7 +1717,7 @@ class StreamingAUCTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_auc(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1730,7 +1730,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
def testPredictionsOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1744,7 +1744,7 @@ class StreamingAUCTest(test.TestCase):
def allCorrectAsExpected(self, curve):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
auc, update_op = metrics.streaming_auc(predictions, labels, curve=curve)
@@ -1755,7 +1755,7 @@ class StreamingAUCTest(test.TestCase):
self.assertEqual(1, auc.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1767,7 +1767,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval())
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1781,7 +1781,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval(), 5)
def testWeighted2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1795,7 +1795,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.7, auc.eval(), 5)
def testAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1807,7 +1807,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
def testAnotherAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
shape=(1, 7),
@@ -1821,7 +1821,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
def testThirdAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -1837,7 +1837,7 @@ class StreamingAUCTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_auc(predictions, labels)
@@ -1848,7 +1848,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0, auc.eval())
def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
auc, update_op = metrics.streaming_auc(predictions, labels)
@@ -1859,7 +1859,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(1, auc.eval(), 6)
def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
labels = array_ops.ones([4])
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
@@ -1893,7 +1893,7 @@ class StreamingAUCTest(test.TestCase):
np.random.exponential(scale=1.0, size=num_samples)):
expected_auc = _np_auc(predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
tf_labels = _enqueue_as_batches(labels, enqueue_ops)
@@ -1966,7 +1966,7 @@ class StreamingDynamicAUCTest(test.TestCase):
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
for _ in xrange(10):
@@ -1977,7 +1977,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
def testAllLabelsOnes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1., 1., 1.])
labels = constant_op.constant([1, 1, 1])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -1986,7 +1986,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(0, auc.eval())
def testAllLabelsZeros(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1., 1., 1.])
labels = constant_op.constant([0, 0, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -1995,7 +1995,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(0, auc.eval())
def testNonZeroOnePredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2.5, -2.5, 2.5, -2.5], dtype=dtypes_lib.float32)
labels = constant_op.constant([1, 0, 1, 0])
@@ -2006,7 +2006,7 @@ class StreamingDynamicAUCTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs)
labels = constant_op.constant(inputs)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2015,7 +2015,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(1, auc.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0, 1, 0])
labels = constant_op.constant([0, 1, 1, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2025,7 +2025,7 @@ class StreamingDynamicAUCTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2034,7 +2034,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0, auc.eval())
def testExceptionOnIncompatibleShapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.ones([5])
labels = array_ops.zeros([6])
with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
@@ -2043,7 +2043,7 @@ class StreamingDynamicAUCTest(test.TestCase):
sess.run(update_op)
def testExceptionOnGreaterThanOneLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
labels = constant_op.constant([2, 1, 0])
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2054,7 +2054,7 @@ class StreamingDynamicAUCTest(test.TestCase):
sess.run(update_op)
def testExceptionOnNegativeLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
labels = constant_op.constant([1, 0, -1])
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2078,7 +2078,7 @@ class StreamingDynamicAUCTest(test.TestCase):
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2093,7 +2093,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(expected_auc, auc.eval())
def testAUCPRReverseIncreasingPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1])
@@ -2104,7 +2104,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
def testAUCPRJumbledPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
@@ -2115,7 +2115,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
def testAUCPRPredictionsLessThanHalf(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -2148,7 +2148,7 @@ class StreamingDynamicAUCTest(test.TestCase):
auc, update_op = metrics.streaming_dynamic_auc(tf_labels,
tf_predictions,
weights=tf_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2196,7 +2196,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
expected_result: The expected result (dict) that maps to tensors.
weights: Optional weights tensor.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.int64)
@@ -2320,7 +2320,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
dtype=dtypes_lib.float32)
auc, update_op = metrics.auc_with_confidence_intervals(tf_labels,
tf_predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2335,7 +2335,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
self.assertAllClose(expected_auc, auc.auc.eval())
def testExceptionOnFloatLabels(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([0.7, 0, 1, 0, 1])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2343,7 +2343,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
self.assertRaises(TypeError, sess.run(update_op))
def testExceptionOnGreaterThanOneLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([2, 1, 0, 1, 0])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2354,7 +2354,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
sess.run(update_op)
def testExceptionOnNegativeLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([1, 0, -1, 1, 0])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2415,7 +2415,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
result, update_op = metric_ops.precision_recall_at_equal_thresholds(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Run several updates.
sess.run(variables.local_variables_initializer())
for _ in range(3):
@@ -2448,7 +2448,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
default from assertAllClose.
weights: Optional weights tensor.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(predictions, dtype=dtype)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool)
weights_tensor = None
@@ -2621,7 +2621,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2641,7 +2641,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -2656,7 +2656,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op))
self.assertAlmostEqual(1.0, specificity.eval())
@@ -2671,7 +2671,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -2689,7 +2689,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -2707,7 +2707,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op))
@@ -2757,7 +2757,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
sensitivity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2777,7 +2777,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -2792,7 +2792,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, specificity.eval())
@@ -2807,7 +2807,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
self.assertAlmostEqual(0.6, specificity.eval())
@@ -2824,7 +2824,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, weights=weights, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.675, sess.run(update_op))
self.assertAlmostEqual(0.675, specificity.eval())
@@ -2887,7 +2887,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2905,7 +2905,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -2921,7 +2921,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertEqual(1, rec.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -2940,7 +2940,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -2956,7 +2956,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0, rec.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -2982,7 +2982,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3008,7 +3008,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3032,7 +3032,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval())
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3082,7 +3082,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so its easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3162,7 +3162,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3177,7 +3177,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -3190,7 +3190,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertEqual(0, fpr.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -3206,7 +3206,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -3219,7 +3219,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1, fpr.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3239,7 +3239,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3259,7 +3259,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3277,7 +3277,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3317,7 +3317,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so its easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3393,7 +3393,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3413,7 +3413,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=1.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, recall.eval())
@@ -3428,7 +3428,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, recall.eval())
@@ -3443,7 +3443,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
target_recall = 2.0 / 3.0
self.assertAlmostEqual(target_recall, sess.run(update_op))
@@ -3461,12 +3461,66 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, weights=weights, precision=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
target_recall = 2.0 / 3.0
self.assertAlmostEqual(target_recall, sess.run(update_op))
self.assertAlmostEqual(target_recall, recall.eval())
+ def _test_strict_mode(self, strict_mode, target_precision, expected_recall):
+ num_thresholds = 11
+ predictions_values = [.2, .3, .5, .6, .7, .8, .9, .9, .9, .1]
+ labels_values = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1]
+ # Resulting thresholds and the corresponding precision and recall values at
+ # each threshold:
+ # Thresholds [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
+ # precisions: [0.3 0.2 0.1 0 0 0 0 0 0]
+ # recalls: [1.0 0.7 0.3 0 0 0 0 0 0]
+ predictions = constant_op.constant(
+ predictions_values, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(labels_values)
+ recall, update_op = metrics.recall_at_precision(
+ labels,
+ predictions,
+ num_thresholds=num_thresholds,
+ precision=target_precision,
+ strict_mode=strict_mode)
+
+ with self.cached_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expected_recall, sess.run(update_op))
+ self.assertAlmostEqual(expected_recall, recall.eval())
+
+ def testStrictMode_Off(self):
+ # strict_mode is turned off and return the recall at the threshold where the
+ # precision (0.3) is closest to target precision (0.9). The recall
+ # corresponding to the threshold is 1.0.
+ self._test_strict_mode(
+ strict_mode=False, target_precision=0.9, expected_recall=1.0)
+
+ def testStrictMode_OnAndFail(self):
+ # strict_mode is turned on and we fail to reach the target precision at any
+ # threshold.
+ # Target precision: 0.9
+ # Diff: [-0.6 -0.7 -0.8 -0.9 -0.9 -0.9 -0.9 -0.9 -0.9]
+ # Reciprocal: [-1.6 -1.4 -1.3 -1.1 -1.1 -1.1 -1.1 -1.1 -1.1]
+ # Max index: 3 and corresponding precision is: 0 which is smaller than
+ # target precsion 0.9. As a result, the expected recall is 0.
+ self._test_strict_mode(
+ strict_mode=True, target_precision=0.9, expected_recall=.0)
+
+ def testStrictMode_OnAndSucceed(self):
+ # strict_mode is on and we can reach the target precision at certain
+ # threshold.
+ # Target precision: 0.2
+ # Diff: [0.1 0 -0.1 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2]
+ # Reciprocal: [10 infty -10.0 -5.0 -5.0 -5.0 -5.0 -5.0 -5.0]
+ # Max index: 1 and corresponding precision is: 0.2 which is no smaller than
+ # target precsion 0.2. In this case, we return the recall at index 1, which
+ # is 2.0/3 (0.7).
+ self._test_strict_mode(
+ strict_mode=True, target_precision=0.2, expected_recall=2.0 / 3)
+
class PrecisionAtRecallTest(test.TestCase):
@@ -3511,7 +3565,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3531,7 +3585,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, precision.eval())
@@ -3545,7 +3599,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(sess.run(label_prior), sess.run(update_op))
self.assertEqual(sess.run(label_prior), precision.eval())
@@ -3560,7 +3614,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, precision.eval())
@@ -3575,7 +3629,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(2.0/3, sess.run(update_op))
self.assertAlmostEqual(2.0/3, precision.eval())
@@ -3594,7 +3648,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.8, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(34.0/43, sess.run(update_op))
self.assertAlmostEqual(34.0/43, precision.eval())
@@ -3643,7 +3697,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3658,7 +3712,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -3671,7 +3725,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertEqual(0, fnr.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -3687,7 +3741,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -3700,7 +3754,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1, fnr.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3720,7 +3774,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3740,7 +3794,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3758,7 +3812,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval())
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3798,7 +3852,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so its easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3886,7 +3940,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.25, sess.run(update_op))
self.assertEqual(0.25, recall.eval())
@@ -3904,7 +3958,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.5, sess.run(update_op))
self.assertEqual(0.5, recall.eval())
@@ -3922,7 +3976,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, recall.eval())
@@ -3946,7 +4000,7 @@ class StreamingRecallAtKTest(test.TestCase):
k=2,
weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, recall.eval())
@@ -4068,7 +4122,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
self.assertAlmostEqual(expected, metric.eval())
def test_top_k_rank_invalid(self):
- with self.test_session():
+ with self.cached_session():
# top_k_predictions has rank < 2.
top_k_predictions = [9, 4, 6, 2, 0]
sp_labels = sparse_tensor.SparseTensorValue(
@@ -4615,7 +4669,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
labels = [[0, 0, 0, 1], [0, 0, 1, 0]]
expected_precision = 0.5
- with self.test_session():
+ with self.cached_session():
_, precision = metrics.streaming_sparse_precision_at_k(
predictions=constant_op.constant(predictions, dtypes_lib.float32),
labels=_binary_2d_label_to_sparse_value(labels),
@@ -5320,7 +5374,7 @@ class StreamingSparseRecallTest(test.TestCase):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
labels = [[0, 0, 1, 0], [0, 0, 0, 1]]
expected_recall = 0.5
- with self.test_session():
+ with self.cached_session():
_, recall = metrics.streaming_sparse_recall_at_k(
predictions=constant_op.constant(predictions, dtypes_lib.float32),
labels=_binary_2d_label_to_sparse_value(labels),
@@ -5364,7 +5418,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_absolute_error(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5386,7 +5440,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_absolute_error(
predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(3, sess.run(update_op))
self.assertEqual(3, error.eval())
@@ -5430,7 +5484,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5455,7 +5509,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer=labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(expected_error, sess.run(update_op))
self.assertEqual(expected_error, error.eval())
@@ -5471,7 +5525,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer=array_ops.zeros_like(labels))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.0, sess.run(update_op))
self.assertEqual(0.0, error.eval())
@@ -5509,7 +5563,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5527,7 +5581,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -5540,7 +5594,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(6, sess.run(update_op))
self.assertEqual(6, error.eval())
@@ -5555,13 +5609,13 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(
predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(13, sess.run(update_op))
self.assertEqual(13, error.eval())
def testMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5586,7 +5640,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(208.0 / 6, error.eval(), 5)
def testMetricsComputedConcurrently(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates one set of predictions.
preds_queue0 = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5629,7 +5683,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(79.0 / 6, mse1, 5)
def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5691,7 +5745,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_root_mean_squared_error(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5704,7 +5758,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(initial_error, error.eval())
def testSingleUpdateZeroError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
0.0, shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
@@ -5718,7 +5772,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(0, rmse.eval())
def testSingleUpdateWithError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5732,7 +5786,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5)
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5788,7 +5842,7 @@ class StreamingCovarianceTest(test.TestCase):
predictions = labels * 0.5 + random_ops.random_normal((10, 3), seed=1) * 0.5
cov, update_op = metrics.streaming_covariance(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5801,7 +5855,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertEqual(initial_cov, cov.eval())
def testSingleUpdateIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = math_ops.to_float(math_ops.range(10))
labels = math_ops.to_float(math_ops.range(10))
@@ -5813,7 +5867,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval(), 5)
def testSingleUpdateNonIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5827,7 +5881,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval())
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5845,7 +5899,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval())
def testMultiUpdateWithErrorNoWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -5879,7 +5933,7 @@ class StreamingCovarianceTest(test.TestCase):
prev_expected_cov = expected_cov
def testMultiUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -5969,7 +6023,7 @@ class StreamingPearsonRTest(test.TestCase):
pearson_r, update_op = metrics.streaming_pearson_correlation(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5982,7 +6036,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertEqual(initial_r, pearson_r.eval())
def testSingleUpdateIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = math_ops.to_float(math_ops.range(10))
labels = math_ops.to_float(math_ops.range(10))
@@ -5995,7 +6049,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval(), 5)
def testSingleUpdateNonIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -6010,7 +6064,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval())
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = np.array([2, 4, 6, 8])
labels = np.array([1, 3, 2, 7])
weights = np.array([0, 1, 3, 1])
@@ -6031,7 +6085,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval())
def testMultiUpdateWithErrorNoWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6066,7 +6120,7 @@ class StreamingPearsonRTest(test.TestCase):
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6108,7 +6162,7 @@ class StreamingPearsonRTest(test.TestCase):
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndSingletonBatches(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6189,7 +6243,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -6212,7 +6266,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -6229,7 +6283,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op), 5)
self.assertAlmostEqual(1, error.eval(), 5)
@@ -6251,7 +6305,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op), 5)
self.assertAlmostEqual(1.0, error.eval(), 5)
@@ -6270,7 +6324,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -6289,7 +6343,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.5, update_op.eval())
self.assertEqual(1.5, error.eval())
@@ -6324,7 +6378,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
@@ -6344,7 +6398,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertAlmostEqual(0.0, pcnt2, 5)
def testSomePresentOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant(
@@ -6421,7 +6475,7 @@ class StreamingMeanIOUTest(test.TestCase):
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes=num_classes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -6435,7 +6489,7 @@ class StreamingMeanIOUTest(test.TestCase):
def testMultipleUpdates(self):
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -6467,7 +6521,7 @@ class StreamingMeanIOUTest(test.TestCase):
def testMultipleUpdatesWithWeights(self):
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -6515,7 +6569,7 @@ class StreamingMeanIOUTest(test.TestCase):
# one class, and thus there is one row and one column with
# zero entries in the confusion matrix.
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
# There is no prediction for class 2.
preds_queue = data_flow_ops.FIFOQueue(
@@ -6557,7 +6611,7 @@ class StreamingMeanIOUTest(test.TestCase):
constant_op.constant(1, shape=[7])
], 0)
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6570,7 +6624,7 @@ class StreamingMeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.zeros([40])
num_classes = 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6581,7 +6635,7 @@ class StreamingMeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.ones([40])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6603,7 +6657,7 @@ class StreamingMeanIOUTest(test.TestCase):
constant_op.constant(1, shape=[8]),
constant_op.constant(0, shape=[1])
], 0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes, weights=weights)
sess.run(variables.local_variables_initializer())
@@ -6618,7 +6672,7 @@ class StreamingMeanIOUTest(test.TestCase):
[[[0, 0, 2, 1, 1, 0], [0, 1, 2, 2, 0, 1]], [[0, 0, 2, 1, 1, 1],
[1, 1, 2, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6630,7 +6684,7 @@ class StreamingMeanIOUTest(test.TestCase):
labels = constant_op.constant([0])
predictions = constant_op.constant([0])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6644,7 +6698,7 @@ class StreamingMeanIOUTest(test.TestCase):
[[[0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1]], [[0, 0, 0, 1, 1, 1],
[1, 1, 1, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6679,7 +6733,7 @@ class StreamingConcatTest(test.TestCase):
def testNextArraySize(self):
next_array_size = metric_ops._next_array_size # pylint: disable=protected-access
- with self.test_session():
+ with self.cached_session():
self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2)
self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4)
self.assertEqual(next_array_size(4, growth_factor=2).eval(), 4)
@@ -6687,7 +6741,7 @@ class StreamingConcatTest(test.TestCase):
self.assertEqual(next_array_size(6, growth_factor=2).eval(), 8)
def testStreamingConcat(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.int32, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6704,7 +6758,7 @@ class StreamingConcatTest(test.TestCase):
self.assertAllEqual(np.arange(10), concatenated.eval())
def testStreamingConcatStringValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.string, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6723,7 +6777,7 @@ class StreamingConcatTest(test.TestCase):
concatenated.eval())
def testStreamingConcatMaxSize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = math_ops.range(3)
concatenated, update_op = metrics.streaming_concat(values, max_size=5)
sess.run(variables.local_variables_initializer())
@@ -6740,7 +6794,7 @@ class StreamingConcatTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval())
def testStreamingConcat2D(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.reshape(math_ops.range(3), (3, 1))
concatenated, update_op = metrics.streaming_concat(values, axis=-1)
sess.run(variables.local_variables_initializer())
@@ -6763,7 +6817,7 @@ class StreamingConcatTest(test.TestCase):
array_ops.placeholder(dtypes_lib.float32, [None, None]))
def testStreamingConcatReset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.int32, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6791,7 +6845,7 @@ class AggregateMetricsTest(test.TestCase):
metrics.streaming_mean(values))
self.assertEqual(len(value_tensors), 1)
self.assertEqual(len(update_ops), 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, update_ops[0].eval())
self.assertEqual(1, value_tensors[0].eval())
@@ -6804,7 +6858,7 @@ class AggregateMetricsTest(test.TestCase):
metrics.streaming_mean_squared_error(predictions, labels))
self.assertEqual(len(value_tensors), 2)
self.assertEqual(len(update_ops), 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(2, update_ops[0].eval())
self.assertEqual(4, update_ops[1].eval())
@@ -6825,7 +6879,7 @@ class AggregateMetricMapTest(test.TestCase):
self.assertEqual(2, len(names_to_values))
self.assertEqual(2, len(names_to_updates))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(2, names_to_updates['m1'].eval())
self.assertEqual(4, names_to_updates['m2'].eval())
@@ -6860,7 +6914,7 @@ class CountTest(test.TestCase):
self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor))
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -6877,7 +6931,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(8.0, sess.run(result), 5)
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -6898,7 +6952,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(8.0, sess.run(result), 5)
def test1dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -6925,7 +6979,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(3.4, result.eval(), 5)
def test1dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -6947,7 +7001,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(3.4, result.eval(), 5)
def test2dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -6974,7 +7028,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(4.1, result.eval(), 5)
def test2dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -7047,7 +7101,7 @@ class CohenKappaTest(test.TestCase):
(10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -7081,7 +7135,7 @@ class CohenKappaTest(test.TestCase):
for dtype in dtypes:
for shape in shapes:
for weight in weights:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
np.reshape(predictions, shape), dtype=dtype)
labels_tensor = constant_op.constant(
@@ -7102,7 +7156,7 @@ class CohenKappaTest(test.TestCase):
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs)
expect = 1.0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
@@ -7121,7 +7175,7 @@ class CohenKappaTest(test.TestCase):
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
expect = -0.333333333333
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
labels = constant_op.constant(labels)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
@@ -7139,7 +7193,7 @@ class CohenKappaTest(test.TestCase):
# labels, predictions, sample_weight=weights)
expect = 0.453466583385
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
labels = constant_op.constant(labels)
kappa, update_op = metrics.cohen_kappa(
@@ -7164,7 +7218,7 @@ class CohenKappaTest(test.TestCase):
weights_t = array_ops.placeholder(dtypes_lib.float32, shape=(batch_size,))
kappa, update_op = metrics.cohen_kappa(
labels_t, predictions_t, num_classes, weights=weights_t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for idx in range(0, num_samples, batch_size):
@@ -7202,7 +7256,7 @@ class CohenKappaTest(test.TestCase):
def testConditionalPackingOptimization(self):
placeholder = array_ops.placeholder(dtypes_lib.float32, [None])
values, update_op = metric_ops.streaming_concat(placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for feed in range(10):
sess.run(update_op, feed_dict={placeholder: [feed]})