diff options
author | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-09-11 20:19:50 +0800 |
---|---|---|
committer | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-09-11 20:19:50 +0800 |
commit | b2896c3cc3a0656b838f58975338d7dd309e3e62 (patch) | |
tree | 14f25741ab43c15e945e6044833c0ff44f11d83f /tensorflow/contrib/metrics | |
parent | 38f811077dd52820eaa3d5c684f41142de01c7eb (diff) | |
parent | e18f84a394bcbde62b344a3b32e8d8fd248fea58 (diff) |
Merge remote-tracking branch 'upstream/master' into ENH/div_no_nan_treate_negative_as_zero
Diffstat (limited to 'tensorflow/contrib/metrics')
3 files changed, 324 insertions, 237 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index bfef0816aa..1ddd7e521b 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -2514,7 +2514,8 @@ def sparse_recall_at_top_k(labels, name=name_scope) -def _compute_recall_at_precision(tp, fp, fn, precision, name): +def _compute_recall_at_precision(tp, fp, fn, precision, name, + strict_mode=False): """Helper function to compute recall at a given `precision`. Args: @@ -2523,17 +2524,42 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name): fn: The number of false negatives. precision: The precision for which the recall will be calculated. name: An optional variable_scope name. + strict_mode: If true and there exists a threshold where the precision is + no smaller than the target precision, return the corresponding recall at + the threshold. Otherwise, return 0. If false, find the threshold where the + precision is closest to the target precision and return the recall at the + threshold. Returns: The recall at a given `precision`. """ precisions = math_ops.div(tp, tp + fp + _EPSILON) - tf_index = math_ops.argmin( - math_ops.abs(precisions - precision), 0, output_type=dtypes.int32) + if not strict_mode: + tf_index = math_ops.argmin( + math_ops.abs(precisions - precision), 0, output_type=dtypes.int32) + # Now, we have the implicit threshold, so compute the recall: + return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, + name) + else: + # We aim to find the threshold where the precision is minimum but no smaller + # than the target precision. + # The rationale: + # 1. Compute the difference between precisions (by different thresholds) and + # the target precision. + # 2. Take the reciprocal of the values by the above step. The intention is + # to make the positive values rank before negative values and also the + # smaller positives rank before larger positives. + tf_index = math_ops.argmax( + math_ops.div(1.0, precisions - precision + _EPSILON), + 0, + output_type=dtypes.int32) + + def _return_good_recall(): + return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, + name) - # Now, we have the implicit threshold, so compute the recall: - return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, - name) + return control_flow_ops.cond(precisions[tf_index] >= precision, + _return_good_recall, lambda: .0) def recall_at_precision(labels, @@ -2543,7 +2569,8 @@ def recall_at_precision(labels, num_thresholds=200, metrics_collections=None, updates_collections=None, - name=None): + name=None, + strict_mode=False): """Computes `recall` at `precision`. The `recall_at_precision` function creates four local variables, @@ -2575,6 +2602,11 @@ def recall_at_precision(labels, updates_collections: An optional list of collections that `update_op` should be added to. name: An optional variable_scope name. + strict_mode: If true and there exists a threshold where the precision is + above the target precision, return the corresponding recall at the + threshold. Otherwise, return 0. If false, find the threshold where the + precision is closest to the target precision and return the recall at the + threshold. Returns: recall: A scalar `Tensor` representing the recall at the given @@ -2603,10 +2635,11 @@ def recall_at_precision(labels, predictions, labels, thresholds, weights) recall = _compute_recall_at_precision(values['tp'], values['fp'], - values['fn'], precision, 'value') + values['fn'], precision, 'value', + strict_mode) update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'], update_ops['fn'], precision, - 'update_op') + 'update_op', strict_mode) if metrics_collections: ops.add_to_collections(metrics_collections, recall) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py index 7acfc383eb..5777e64c29 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py @@ -47,7 +47,7 @@ class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase): # code used float32 for accumulation. num_updates = 71 - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for _ in xrange(num_updates): sess.run(update_op) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 1c2c17960a..955b83b44d 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -178,7 +178,7 @@ class StreamingMeanTest(test.TestCase): self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -195,7 +195,7 @@ class StreamingMeanTest(test.TestCase): self.assertAlmostEqual(1.65, sess.run(mean), 5) def testUpdateOpsReturnsCurrentValue(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -216,7 +216,7 @@ class StreamingMeanTest(test.TestCase): self.assertAlmostEqual(1.65, sess.run(mean), 5) def test1dWeightedValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -243,7 +243,7 @@ class StreamingMeanTest(test.TestCase): self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5) def test1dWeightedValues_placeholders(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) values = array_ops.placeholder(dtype=dtypes_lib.float32) @@ -265,7 +265,7 @@ class StreamingMeanTest(test.TestCase): self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5) def test2dWeightedValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -292,7 +292,7 @@ class StreamingMeanTest(test.TestCase): self.assertAlmostEqual((0 + 1 - 4.2 + 0) / 4.0, mean.eval(), 5) def test2dWeightedValues_placeholders(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) values = array_ops.placeholder(dtype=dtypes_lib.float32) @@ -337,7 +337,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -354,7 +354,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean)) def testMultiDimensional(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2)) _enqueue_vector( @@ -375,7 +375,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean)) def testUpdateOpsReturnsCurrentValue(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -396,7 +396,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5) def testWeighted1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -423,7 +423,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5) def testWeighted2d_1(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -450,7 +450,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5) def testWeighted2d_2(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -526,7 +526,7 @@ class StreamingAccuracyTest(test.TestCase): (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=2) accuracy, update_op = metrics.streaming_accuracy(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -539,7 +539,7 @@ class StreamingAccuracyTest(test.TestCase): self.assertEqual(initial_accuracy, accuracy.eval()) def testMultipleUpdates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. preds_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) @@ -569,7 +569,7 @@ class StreamingAccuracyTest(test.TestCase): def testEffectivelyEquivalentSizes(self): predictions = array_ops.ones((40, 1)) labels = array_ops.ones((40,)) - with self.test_session() as sess: + with self.cached_session() as sess: accuracy, update_op = metrics.streaming_accuracy(predictions, labels) sess.run(variables.local_variables_initializer()) @@ -583,7 +583,7 @@ class StreamingAccuracyTest(test.TestCase): weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]), 1) # shape 3, 1 - with self.test_session() as sess: + with self.cached_session() as sess: accuracy, update_op = metrics.streaming_accuracy(predictions, labels, weights) @@ -604,7 +604,7 @@ class StreamingAccuracyTest(test.TestCase): dtype=dtypes_lib.int32, name='weights') feed_dict = {weights_placeholder: weights} - with self.test_session() as sess: + with self.cached_session() as sess: accuracy, update_op = metrics.streaming_accuracy(predictions, labels, weights_placeholder) @@ -616,7 +616,7 @@ class StreamingAccuracyTest(test.TestCase): self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95) def testMultipleUpdatesWithWeightedValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. preds_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) @@ -681,7 +681,7 @@ class StreamingTruePositivesTest(test.TestCase): tp, tp_update_op = metrics.streaming_true_positives( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, tp.eval()) self.assertEqual(1, tp_update_op.eval()) @@ -698,7 +698,7 @@ class StreamingTruePositivesTest(test.TestCase): tp, tp_update_op = metrics.streaming_true_positives( predictions, labels, weights=37.0) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, tp.eval()) self.assertEqual(37.0, tp_update_op.eval()) @@ -732,7 +732,7 @@ class StreamingFalseNegativesTest(test.TestCase): fn, fn_update_op = metrics.streaming_false_negatives( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, fn.eval()) self.assertEqual(2, fn_update_op.eval()) @@ -749,7 +749,7 @@ class StreamingFalseNegativesTest(test.TestCase): fn, fn_update_op = metrics.streaming_false_negatives( predictions, labels, weights=((3.0,), (5.0,), (7.0,))) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, fn.eval()) self.assertEqual(8.0, fn_update_op.eval()) @@ -783,7 +783,7 @@ class StreamingFalsePositivesTest(test.TestCase): fp, fp_update_op = metrics.streaming_false_positives( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, fp.eval()) self.assertEqual(4, fp_update_op.eval()) @@ -803,7 +803,7 @@ class StreamingFalsePositivesTest(test.TestCase): weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0, 29.0, 31.0))) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, fp.eval()) self.assertEqual(42.0, fp_update_op.eval()) @@ -837,7 +837,7 @@ class StreamingTrueNegativesTest(test.TestCase): tn, tn_update_op = metrics.streaming_true_negatives( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, tn.eval()) self.assertEqual(5, tn_update_op.eval()) @@ -854,7 +854,7 @@ class StreamingTrueNegativesTest(test.TestCase): tn, tn_update_op = metrics.streaming_true_negatives( predictions, labels, weights=((0.0, 2.0, 3.0, 5.0),)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, tn.eval()) self.assertEqual(15.0, tn_update_op.eval()) @@ -879,7 +879,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase): tp, tp_update_op = metrics.streaming_true_positives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0, 0, 0), tp.eval()) self.assertAllEqual((3, 1, 0), tp_update_op.eval()) @@ -892,7 +892,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase): tp, tp_update_op = metrics.streaming_true_positives_at_thresholds( predictions, labels, weights=37.0, thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0.0, 0.0, 0.0), tp.eval()) self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval()) @@ -921,7 +921,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase): fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0, 0, 0), fn.eval()) self.assertAllEqual((0, 2, 3), fn_update_op.eval()) @@ -937,7 +937,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase): weights=((3.0,), (5.0,), (7.0,)), thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0.0, 0.0, 0.0), fn.eval()) self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval()) @@ -962,7 +962,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase): fp, fp_update_op = metrics.streaming_false_positives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0, 0, 0), fp.eval()) self.assertAllEqual((7, 4, 2), fp_update_op.eval()) @@ -979,7 +979,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase): 29.0, 31.0)), thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0.0, 0.0, 0.0), fp.eval()) self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval()) @@ -1004,7 +1004,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase): tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0, 0, 0), tn.eval()) self.assertAllEqual((2, 5, 7), tn_update_op.eval()) @@ -1020,7 +1020,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase): weights=((0.0, 2.0, 3.0, 5.0),), thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0.0, 0.0, 0.0), tn.eval()) self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval()) @@ -1062,7 +1062,7 @@ class StreamingPrecisionTest(test.TestCase): (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) precision, update_op = metrics.streaming_precision(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -1081,7 +1081,7 @@ class StreamingPrecisionTest(test.TestCase): labels = constant_op.constant(inputs) precision, update_op = metrics.streaming_precision(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(1, sess.run(update_op)) self.assertAlmostEqual(1, precision.eval()) @@ -1091,7 +1091,7 @@ class StreamingPrecisionTest(test.TestCase): labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) precision, update_op = metrics.streaming_precision(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.5, update_op.eval()) self.assertAlmostEqual(0.5, precision.eval()) @@ -1102,7 +1102,7 @@ class StreamingPrecisionTest(test.TestCase): precision, update_op = metrics.streaming_precision( predictions, labels, weights=constant_op.constant([[2], [5]])) - with self.test_session(): + with self.cached_session(): variables.local_variables_initializer().run() weighted_tp = 2.0 + 5.0 weighted_positives = (2.0 + 2.0) + (5.0 + 5.0) @@ -1120,7 +1120,7 @@ class StreamingPrecisionTest(test.TestCase): precision, update_op = metrics.streaming_precision( predictions, labels, weights=constant_op.constant([[2], [5]])) - with self.test_session(): + with self.cached_session(): variables.local_variables_initializer().run() weighted_tp = 2.0 + 5.0 weighted_positives = (2.0 + 2.0) + (5.0 + 5.0) @@ -1138,7 +1138,7 @@ class StreamingPrecisionTest(test.TestCase): labels, weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])) - with self.test_session(): + with self.cached_session(): variables.local_variables_initializer().run() weighted_tp = 3.0 + 4.0 weighted_positives = (1.0 + 3.0) + (4.0 + 2.0) @@ -1158,7 +1158,7 @@ class StreamingPrecisionTest(test.TestCase): labels, weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])) - with self.test_session(): + with self.cached_session(): variables.local_variables_initializer().run() weighted_tp = 3.0 + 4.0 weighted_positives = (1.0 + 3.0) + (4.0 + 2.0) @@ -1175,7 +1175,7 @@ class StreamingPrecisionTest(test.TestCase): labels = constant_op.constant(1 - inputs) precision, update_op = metrics.streaming_precision(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertAlmostEqual(0, precision.eval()) @@ -1185,7 +1185,7 @@ class StreamingPrecisionTest(test.TestCase): labels = constant_op.constant([0, 0, 0, 0]) precision, update_op = metrics.streaming_precision(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0.0, precision.eval()) @@ -1227,7 +1227,7 @@ class StreamingRecallTest(test.TestCase): (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) recall, update_op = metrics.streaming_recall(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -1246,7 +1246,7 @@ class StreamingRecallTest(test.TestCase): labels = constant_op.constant(np_inputs) recall, update_op = metrics.streaming_recall(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(1, recall.eval()) @@ -1256,7 +1256,7 @@ class StreamingRecallTest(test.TestCase): labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) recall, update_op = metrics.streaming_recall(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.5, update_op.eval()) self.assertAlmostEqual(0.5, recall.eval()) @@ -1268,7 +1268,7 @@ class StreamingRecallTest(test.TestCase): recall, update_op = metrics.streaming_recall( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_tp = 2.0 + 5.0 weighted_t = (2.0 + 2.0) + (5.0 + 5.0) @@ -1283,7 +1283,7 @@ class StreamingRecallTest(test.TestCase): recall, update_op = metrics.streaming_recall( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_tp = 3.0 + 1.0 weighted_t = (2.0 + 3.0) + (4.0 + 1.0) @@ -1298,7 +1298,7 @@ class StreamingRecallTest(test.TestCase): labels = constant_op.constant(1 - np_inputs) recall, update_op = metrics.streaming_recall(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, recall.eval()) @@ -1308,7 +1308,7 @@ class StreamingRecallTest(test.TestCase): labels = array_ops.zeros((1, 4)) recall, update_op = metrics.streaming_recall(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, recall.eval()) @@ -1350,7 +1350,7 @@ class StreamingFPRTest(test.TestCase): (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -1369,7 +1369,7 @@ class StreamingFPRTest(test.TestCase): labels = constant_op.constant(np_inputs) fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, fpr.eval()) @@ -1379,7 +1379,7 @@ class StreamingFPRTest(test.TestCase): labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.5, update_op.eval()) self.assertAlmostEqual(0.5, fpr.eval()) @@ -1391,7 +1391,7 @@ class StreamingFPRTest(test.TestCase): fpr, update_op = metrics.streaming_false_positive_rate( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_fp = 2.0 + 5.0 weighted_f = (2.0 + 2.0) + (5.0 + 5.0) @@ -1406,7 +1406,7 @@ class StreamingFPRTest(test.TestCase): fpr, update_op = metrics.streaming_false_positive_rate( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_fp = 1.0 + 3.0 weighted_f = (1.0 + 4.0) + (2.0 + 3.0) @@ -1421,7 +1421,7 @@ class StreamingFPRTest(test.TestCase): labels = constant_op.constant(1 - np_inputs) fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(1, fpr.eval()) @@ -1431,7 +1431,7 @@ class StreamingFPRTest(test.TestCase): labels = array_ops.ones((1, 4)) fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, fpr.eval()) @@ -1473,7 +1473,7 @@ class StreamingFNRTest(test.TestCase): (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -1492,7 +1492,7 @@ class StreamingFNRTest(test.TestCase): labels = constant_op.constant(np_inputs) fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, fnr.eval()) @@ -1502,7 +1502,7 @@ class StreamingFNRTest(test.TestCase): labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.5, update_op.eval()) self.assertAlmostEqual(0.5, fnr.eval()) @@ -1514,7 +1514,7 @@ class StreamingFNRTest(test.TestCase): fnr, update_op = metrics.streaming_false_negative_rate( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_fn = 2.0 + 5.0 weighted_t = (2.0 + 2.0) + (5.0 + 5.0) @@ -1529,7 +1529,7 @@ class StreamingFNRTest(test.TestCase): fnr, update_op = metrics.streaming_false_negative_rate( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_fn = 2.0 + 4.0 weighted_t = (2.0 + 3.0) + (1.0 + 4.0) @@ -1544,7 +1544,7 @@ class StreamingFNRTest(test.TestCase): labels = constant_op.constant(1 - np_inputs) fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(1, fnr.eval()) @@ -1554,7 +1554,7 @@ class StreamingFNRTest(test.TestCase): labels = array_ops.zeros((1, 4)) fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, fnr.eval()) @@ -1599,7 +1599,7 @@ class StreamingCurvePointsTest(test.TestCase): points, update_op = metric_ops.streaming_curve_points( labels, predictions=predictions, curve=curve) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) @@ -1615,7 +1615,7 @@ class StreamingCurvePointsTest(test.TestCase): self._testValueTensorIsIdempotent(curve='PR') def _testCase(self, labels, predictions, curve, expected_points): - with self.test_session() as sess: + with self.cached_session() as sess: predictions_tensor = constant_op.constant( predictions, dtype=dtypes_lib.float32) labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32) @@ -1717,7 +1717,7 @@ class StreamingAUCTest(test.TestCase): (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) auc, update_op = metrics.streaming_auc(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -1730,7 +1730,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(initial_auc, auc.eval(), 5) def testPredictionsOutOfRange(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -1744,7 +1744,7 @@ class StreamingAUCTest(test.TestCase): def allCorrectAsExpected(self, curve): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) auc, update_op = metrics.streaming_auc(predictions, labels, curve=curve) @@ -1755,7 +1755,7 @@ class StreamingAUCTest(test.TestCase): self.assertEqual(1, auc.eval()) def testSomeCorrect(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -1767,7 +1767,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0.5, auc.eval()) def testWeighted1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -1781,7 +1781,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0.5, auc.eval(), 5) def testWeighted2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -1795,7 +1795,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0.7, auc.eval(), 5) def testAUCPRSpecialCase(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4)) @@ -1807,7 +1807,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3) def testAnotherAUCPRSpecialCase(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], shape=(1, 7), @@ -1821,7 +1821,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3) def testThirdAUCPRSpecialCase(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5], shape=(1, 7), @@ -1837,7 +1837,7 @@ class StreamingAUCTest(test.TestCase): def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) auc, update_op = metrics.streaming_auc(predictions, labels) @@ -1848,7 +1848,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0, auc.eval()) def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) auc, update_op = metrics.streaming_auc(predictions, labels) @@ -1859,7 +1859,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(1, auc.eval(), 6) def testRecallOneAndPrecisionOneGivesOnePRAUC(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.ones([4], dtype=dtypes_lib.float32) labels = array_ops.ones([4]) auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') @@ -1893,7 +1893,7 @@ class StreamingAUCTest(test.TestCase): np.random.exponential(scale=1.0, size=num_samples)): expected_auc = _np_auc(predictions, labels, weights) - with self.test_session() as sess: + with self.cached_session() as sess: enqueue_ops = [[] for i in range(num_batches)] tf_predictions = _enqueue_as_batches(predictions, enqueue_ops) tf_labels = _enqueue_as_batches(labels, enqueue_ops) @@ -1966,7 +1966,7 @@ class StreamingDynamicAUCTest(test.TestCase): labels = random_ops.random_uniform( (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. for _ in xrange(10): @@ -1977,7 +1977,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(initial_auc, auc.eval(), 5) def testAllLabelsOnes(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1., 1., 1.]) labels = constant_op.constant([1, 1, 1]) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -1986,7 +1986,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertEqual(0, auc.eval()) def testAllLabelsZeros(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1., 1., 1.]) labels = constant_op.constant([0, 0, 0]) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -1995,7 +1995,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertEqual(0, auc.eval()) def testNonZeroOnePredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2.5, -2.5, 2.5, -2.5], dtype=dtypes_lib.float32) labels = constant_op.constant([1, 0, 1, 0]) @@ -2006,7 +2006,7 @@ class StreamingDynamicAUCTest(test.TestCase): def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs) labels = constant_op.constant(inputs) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -2015,7 +2015,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertEqual(1, auc.eval()) def testSomeCorrect(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0, 1, 0]) labels = constant_op.constant([0, 1, 1, 0]) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -2025,7 +2025,7 @@ class StreamingDynamicAUCTest(test.TestCase): def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -2034,7 +2034,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(0, auc.eval()) def testExceptionOnIncompatibleShapes(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.ones([5]) labels = array_ops.zeros([6]) with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'): @@ -2043,7 +2043,7 @@ class StreamingDynamicAUCTest(test.TestCase): sess.run(update_op) def testExceptionOnGreaterThanOneLabel(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32) labels = constant_op.constant([2, 1, 0]) _, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -2054,7 +2054,7 @@ class StreamingDynamicAUCTest(test.TestCase): sess.run(update_op) def testExceptionOnNegativeLabel(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32) labels = constant_op.constant([1, 0, -1]) _, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -2078,7 +2078,7 @@ class StreamingDynamicAUCTest(test.TestCase): collections=[ops.GraphKeys.LOCAL_VARIABLES], dtype=dtypes_lib.float32) auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for _ in xrange(num_batches): new_labels = np.random.randint(0, 2, size=batch_size) @@ -2093,7 +2093,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(expected_auc, auc.eval()) def testAUCPRReverseIncreasingPredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32) labels = constant_op.constant([0, 0, 1, 1]) @@ -2104,7 +2104,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5) def testAUCPRJumbledPredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32) labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1]) @@ -2115,7 +2115,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6) def testAUCPRPredictionsLessThanHalf(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5], shape=(1, 7), @@ -2148,7 +2148,7 @@ class StreamingDynamicAUCTest(test.TestCase): auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions, weights=tf_weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for _ in xrange(num_batches): new_labels = np.random.randint(0, 2, size=batch_size) @@ -2196,7 +2196,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase): expected_result: The expected result (dict) that maps to tensors. weights: Optional weights tensor. """ - with self.test_session() as sess: + with self.cached_session() as sess: predictions_tensor = constant_op.constant( predictions, dtype=dtypes_lib.float32) labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.int64) @@ -2320,7 +2320,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase): dtype=dtypes_lib.float32) auc, update_op = metrics.auc_with_confidence_intervals(tf_labels, tf_predictions) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for _ in xrange(num_batches): new_labels = np.random.randint(0, 2, size=batch_size) @@ -2335,7 +2335,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase): self.assertAllClose(expected_auc, auc.auc.eval()) def testExceptionOnFloatLabels(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) labels = constant_op.constant([0.7, 0, 1, 0, 1]) _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) @@ -2343,7 +2343,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase): self.assertRaises(TypeError, sess.run(update_op)) def testExceptionOnGreaterThanOneLabel(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) labels = constant_op.constant([2, 1, 0, 1, 0]) _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) @@ -2354,7 +2354,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase): sess.run(update_op) def testExceptionOnNegativeLabel(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) labels = constant_op.constant([1, 0, -1, 1, 0]) _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) @@ -2415,7 +2415,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): result, update_op = metric_ops.precision_recall_at_equal_thresholds( labels=labels, predictions=predictions) - with self.test_session() as sess: + with self.cached_session() as sess: # Run several updates. sess.run(variables.local_variables_initializer()) for _ in range(3): @@ -2448,7 +2448,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): default from assertAllClose. weights: Optional weights tensor. """ - with self.test_session() as sess: + with self.cached_session() as sess: predictions_tensor = constant_op.constant(predictions, dtype=dtype) labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) weights_tensor = None @@ -2621,7 +2621,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -2641,7 +2641,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1, sess.run(update_op)) self.assertEqual(1, specificity.eval()) @@ -2656,7 +2656,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.8) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(1.0, sess.run(update_op)) self.assertAlmostEqual(1.0, specificity.eval()) @@ -2671,7 +2671,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.6, sess.run(update_op)) @@ -2689,7 +2689,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, weights=weights, sensitivity=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.6, sess.run(update_op)) @@ -2707,7 +2707,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, weights=weights, sensitivity=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op)) @@ -2757,7 +2757,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): sensitivity, update_op = metrics.streaming_sensitivity_at_specificity( predictions, labels, specificity=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -2777,7 +2777,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): specificity, update_op = metrics.streaming_sensitivity_at_specificity( predictions, labels, specificity=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1, sess.run(update_op)) self.assertEqual(1, specificity.eval()) @@ -2792,7 +2792,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): specificity, update_op = metrics.streaming_sensitivity_at_specificity( predictions, labels, specificity=0.8) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.8, sess.run(update_op)) self.assertAlmostEqual(0.8, specificity.eval()) @@ -2807,7 +2807,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): specificity, update_op = metrics.streaming_sensitivity_at_specificity( predictions, labels, specificity=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.6, sess.run(update_op)) self.assertAlmostEqual(0.6, specificity.eval()) @@ -2824,7 +2824,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): specificity, update_op = metrics.streaming_sensitivity_at_specificity( predictions, labels, weights=weights, specificity=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.675, sess.run(update_op)) self.assertAlmostEqual(0.675, specificity.eval()) @@ -2887,7 +2887,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds( predictions, labels, thresholds) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -2905,7 +2905,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) thresholds = [0.5] @@ -2921,7 +2921,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): self.assertEqual(1, rec.eval()) def testSomeCorrect(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -2940,7 +2940,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) thresholds = [0.5] @@ -2956,7 +2956,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): self.assertAlmostEqual(0, rec.eval()) def testWeights1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -2982,7 +2982,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, rec_high.eval(), places=5) def testWeights2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -3008,7 +3008,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, rec_high.eval(), places=5) def testExtremeThresholds(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) @@ -3032,7 +3032,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, rec_high.eval()) def testZeroLabelsPredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) thresholds = [0.5] @@ -3082,7 +3082,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: # Reshape the data so its easy to queue up: predictions_batches = predictions.reshape((batch_size, num_batches)) labels_batches = labels.reshape((batch_size, num_batches)) @@ -3162,7 +3162,7 @@ class StreamingFPRThresholdsTest(test.TestCase): fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( predictions, labels, thresholds) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -3177,7 +3177,7 @@ class StreamingFPRThresholdsTest(test.TestCase): def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) thresholds = [0.5] @@ -3190,7 +3190,7 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertEqual(0, fpr.eval()) def testSomeCorrect(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -3206,7 +3206,7 @@ class StreamingFPRThresholdsTest(test.TestCase): def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) thresholds = [0.5] @@ -3219,7 +3219,7 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertAlmostEqual(1, fpr.eval()) def testWeights1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -3239,7 +3239,7 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) def testWeights2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -3259,7 +3259,7 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) def testExtremeThresholds(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) @@ -3277,7 +3277,7 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) def testZeroLabelsPredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) thresholds = [0.5] @@ -3317,7 +3317,7 @@ class StreamingFPRThresholdsTest(test.TestCase): labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: # Reshape the data so its easy to queue up: predictions_batches = predictions.reshape((batch_size, num_batches)) labels_batches = labels.reshape((batch_size, num_batches)) @@ -3393,7 +3393,7 @@ class RecallAtPrecisionTest(test.TestCase): recall, update_op = metrics.recall_at_precision( labels, predictions, precision=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -3413,7 +3413,7 @@ class RecallAtPrecisionTest(test.TestCase): recall, update_op = metrics.recall_at_precision( labels, predictions, precision=1.0) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1, sess.run(update_op)) self.assertEqual(1, recall.eval()) @@ -3428,7 +3428,7 @@ class RecallAtPrecisionTest(test.TestCase): recall, update_op = metrics.recall_at_precision( labels, predictions, precision=0.8) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.8, sess.run(update_op)) self.assertAlmostEqual(0.8, recall.eval()) @@ -3443,7 +3443,7 @@ class RecallAtPrecisionTest(test.TestCase): recall, update_op = metrics.recall_at_precision( labels, predictions, precision=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) target_recall = 2.0 / 3.0 self.assertAlmostEqual(target_recall, sess.run(update_op)) @@ -3461,12 +3461,66 @@ class RecallAtPrecisionTest(test.TestCase): recall, update_op = metrics.recall_at_precision( labels, predictions, weights=weights, precision=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) target_recall = 2.0 / 3.0 self.assertAlmostEqual(target_recall, sess.run(update_op)) self.assertAlmostEqual(target_recall, recall.eval()) + def _test_strict_mode(self, strict_mode, target_precision, expected_recall): + num_thresholds = 11 + predictions_values = [.2, .3, .5, .6, .7, .8, .9, .9, .9, .1] + labels_values = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1] + # Resulting thresholds and the corresponding precision and recall values at + # each threshold: + # Thresholds [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9] + # precisions: [0.3 0.2 0.1 0 0 0 0 0 0] + # recalls: [1.0 0.7 0.3 0 0 0 0 0 0] + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + recall, update_op = metrics.recall_at_precision( + labels, + predictions, + num_thresholds=num_thresholds, + precision=target_precision, + strict_mode=strict_mode) + + with self.cached_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(expected_recall, sess.run(update_op)) + self.assertAlmostEqual(expected_recall, recall.eval()) + + def testStrictMode_Off(self): + # strict_mode is turned off and return the recall at the threshold where the + # precision (0.3) is closest to target precision (0.9). The recall + # corresponding to the threshold is 1.0. + self._test_strict_mode( + strict_mode=False, target_precision=0.9, expected_recall=1.0) + + def testStrictMode_OnAndFail(self): + # strict_mode is turned on and we fail to reach the target precision at any + # threshold. + # Target precision: 0.9 + # Diff: [-0.6 -0.7 -0.8 -0.9 -0.9 -0.9 -0.9 -0.9 -0.9] + # Reciprocal: [-1.6 -1.4 -1.3 -1.1 -1.1 -1.1 -1.1 -1.1 -1.1] + # Max index: 3 and corresponding precision is: 0 which is smaller than + # target precsion 0.9. As a result, the expected recall is 0. + self._test_strict_mode( + strict_mode=True, target_precision=0.9, expected_recall=.0) + + def testStrictMode_OnAndSucceed(self): + # strict_mode is on and we can reach the target precision at certain + # threshold. + # Target precision: 0.2 + # Diff: [0.1 0 -0.1 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2] + # Reciprocal: [10 infty -10.0 -5.0 -5.0 -5.0 -5.0 -5.0 -5.0] + # Max index: 1 and corresponding precision is: 0.2 which is no smaller than + # target precsion 0.2. In this case, we return the recall at index 1, which + # is 2.0/3 (0.7). + self._test_strict_mode( + strict_mode=True, target_precision=0.2, expected_recall=2.0 / 3) + class PrecisionAtRecallTest(test.TestCase): @@ -3511,7 +3565,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -3531,7 +3585,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1, sess.run(update_op)) self.assertEqual(1, precision.eval()) @@ -3545,7 +3599,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(sess.run(label_prior), sess.run(update_op)) self.assertEqual(sess.run(label_prior), precision.eval()) @@ -3560,7 +3614,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.8) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.8, sess.run(update_op)) self.assertAlmostEqual(0.8, precision.eval()) @@ -3575,7 +3629,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(2.0/3, sess.run(update_op)) self.assertAlmostEqual(2.0/3, precision.eval()) @@ -3594,7 +3648,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.8, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(34.0/43, sess.run(update_op)) self.assertAlmostEqual(34.0/43, precision.eval()) @@ -3643,7 +3697,7 @@ class StreamingFNRThresholdsTest(test.TestCase): fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -3658,7 +3712,7 @@ class StreamingFNRThresholdsTest(test.TestCase): def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) thresholds = [0.5] @@ -3671,7 +3725,7 @@ class StreamingFNRThresholdsTest(test.TestCase): self.assertEqual(0, fnr.eval()) def testSomeCorrect(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -3687,7 +3741,7 @@ class StreamingFNRThresholdsTest(test.TestCase): def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) thresholds = [0.5] @@ -3700,7 +3754,7 @@ class StreamingFNRThresholdsTest(test.TestCase): self.assertAlmostEqual(1, fnr.eval()) def testWeights1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -3720,7 +3774,7 @@ class StreamingFNRThresholdsTest(test.TestCase): self.assertAlmostEqual(1.0, fnr_high.eval(), places=5) def testWeights2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -3740,7 +3794,7 @@ class StreamingFNRThresholdsTest(test.TestCase): self.assertAlmostEqual(1.0, fnr_high.eval(), places=5) def testExtremeThresholds(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) @@ -3758,7 +3812,7 @@ class StreamingFNRThresholdsTest(test.TestCase): self.assertAlmostEqual(1.0, fnr_high.eval()) def testZeroLabelsPredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) thresholds = [0.5] @@ -3798,7 +3852,7 @@ class StreamingFNRThresholdsTest(test.TestCase): labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: # Reshape the data so its easy to queue up: predictions_batches = predictions.reshape((batch_size, num_batches)) labels_batches = labels.reshape((batch_size, num_batches)) @@ -3886,7 +3940,7 @@ class StreamingRecallAtKTest(test.TestCase): sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k( predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=1) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0.25, sess.run(update_op)) self.assertEqual(0.25, recall.eval()) @@ -3904,7 +3958,7 @@ class StreamingRecallAtKTest(test.TestCase): sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k( predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0.5, sess.run(update_op)) self.assertEqual(0.5, recall.eval()) @@ -3922,7 +3976,7 @@ class StreamingRecallAtKTest(test.TestCase): sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k( predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=3) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1.0, sess.run(update_op)) self.assertEqual(1.0, recall.eval()) @@ -3946,7 +4000,7 @@ class StreamingRecallAtKTest(test.TestCase): k=2, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1.0, sess.run(update_op)) self.assertEqual(1.0, recall.eval()) @@ -4068,7 +4122,7 @@ class StreamingSparsePrecisionTest(test.TestCase): self.assertAlmostEqual(expected, metric.eval()) def test_top_k_rank_invalid(self): - with self.test_session(): + with self.cached_session(): # top_k_predictions has rank < 2. top_k_predictions = [9, 4, 6, 2, 0] sp_labels = sparse_tensor.SparseTensorValue( @@ -4615,7 +4669,7 @@ class StreamingSparsePrecisionTest(test.TestCase): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] labels = [[0, 0, 0, 1], [0, 0, 1, 0]] expected_precision = 0.5 - with self.test_session(): + with self.cached_session(): _, precision = metrics.streaming_sparse_precision_at_k( predictions=constant_op.constant(predictions, dtypes_lib.float32), labels=_binary_2d_label_to_sparse_value(labels), @@ -5320,7 +5374,7 @@ class StreamingSparseRecallTest(test.TestCase): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] labels = [[0, 0, 1, 0], [0, 0, 0, 1]] expected_recall = 0.5 - with self.test_session(): + with self.cached_session(): _, recall = metrics.streaming_sparse_recall_at_k( predictions=constant_op.constant(predictions, dtypes_lib.float32), labels=_binary_2d_label_to_sparse_value(labels), @@ -5364,7 +5418,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_absolute_error( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -5386,7 +5440,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_absolute_error( predictions, labels, weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(3, sess.run(update_op)) self.assertEqual(3, error.eval()) @@ -5430,7 +5484,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_relative_error( predictions, labels, normalizer) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -5455,7 +5509,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_relative_error( predictions, labels, normalizer=labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(expected_error, sess.run(update_op)) self.assertEqual(expected_error, error.eval()) @@ -5471,7 +5525,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_relative_error( predictions, labels, normalizer=array_ops.zeros_like(labels)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0.0, sess.run(update_op)) self.assertEqual(0.0, error.eval()) @@ -5509,7 +5563,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase): labels = random_ops.random_normal((10, 3), seed=2) error, update_op = metrics.streaming_mean_squared_error(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -5527,7 +5581,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_squared_error(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, sess.run(update_op)) self.assertEqual(0, error.eval()) @@ -5540,7 +5594,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_squared_error(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(6, sess.run(update_op)) self.assertEqual(6, error.eval()) @@ -5555,13 +5609,13 @@ class StreamingMeanSquaredErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_squared_error( predictions, labels, weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(13, sess.run(update_op)) self.assertEqual(13, error.eval()) def testMultipleBatchesOfSizeOne(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. preds_queue = data_flow_ops.FIFOQueue( 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) @@ -5586,7 +5640,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase): self.assertAlmostEqual(208.0 / 6, error.eval(), 5) def testMetricsComputedConcurrently(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates one set of predictions. preds_queue0 = data_flow_ops.FIFOQueue( 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) @@ -5629,7 +5683,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase): self.assertAlmostEqual(79.0 / 6, mse1, 5) def testMultipleMetricsOnMultipleBatchesOfSizeOne(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. preds_queue = data_flow_ops.FIFOQueue( 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) @@ -5691,7 +5745,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): error, update_op = metrics.streaming_root_mean_squared_error( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -5704,7 +5758,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): self.assertEqual(initial_error, error.eval()) def testSingleUpdateZeroError(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( 0.0, shape=(1, 3), dtype=dtypes_lib.float32) labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32) @@ -5718,7 +5772,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): self.assertEqual(0, rmse.eval()) def testSingleUpdateWithError(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32) labels = constant_op.constant( @@ -5732,7 +5786,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5) def testSingleUpdateWithErrorAndWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant( @@ -5788,7 +5842,7 @@ class StreamingCovarianceTest(test.TestCase): predictions = labels * 0.5 + random_ops.random_normal((10, 3), seed=1) * 0.5 cov, update_op = metrics.streaming_covariance(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -5801,7 +5855,7 @@ class StreamingCovarianceTest(test.TestCase): self.assertEqual(initial_cov, cov.eval()) def testSingleUpdateIdentical(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = math_ops.to_float(math_ops.range(10)) labels = math_ops.to_float(math_ops.range(10)) @@ -5813,7 +5867,7 @@ class StreamingCovarianceTest(test.TestCase): self.assertAlmostEqual(expected_cov, cov.eval(), 5) def testSingleUpdateNonIdentical(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32) labels = constant_op.constant( @@ -5827,7 +5881,7 @@ class StreamingCovarianceTest(test.TestCase): self.assertAlmostEqual(expected_cov, cov.eval()) def testSingleUpdateWithErrorAndWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant( @@ -5845,7 +5899,7 @@ class StreamingCovarianceTest(test.TestCase): self.assertAlmostEqual(expected_cov, cov.eval()) def testMultiUpdateWithErrorNoWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: np.random.seed(123) n = 100 predictions = np.random.randn(n) @@ -5879,7 +5933,7 @@ class StreamingCovarianceTest(test.TestCase): prev_expected_cov = expected_cov def testMultiUpdateWithErrorAndWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: np.random.seed(123) n = 100 predictions = np.random.randn(n) @@ -5969,7 +6023,7 @@ class StreamingPearsonRTest(test.TestCase): pearson_r, update_op = metrics.streaming_pearson_correlation( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -5982,7 +6036,7 @@ class StreamingPearsonRTest(test.TestCase): self.assertEqual(initial_r, pearson_r.eval()) def testSingleUpdateIdentical(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = math_ops.to_float(math_ops.range(10)) labels = math_ops.to_float(math_ops.range(10)) @@ -5995,7 +6049,7 @@ class StreamingPearsonRTest(test.TestCase): self.assertAlmostEqual(expected_r, pearson_r.eval(), 5) def testSingleUpdateNonIdentical(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32) labels = constant_op.constant( @@ -6010,7 +6064,7 @@ class StreamingPearsonRTest(test.TestCase): self.assertAlmostEqual(expected_r, pearson_r.eval()) def testSingleUpdateWithErrorAndWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = np.array([2, 4, 6, 8]) labels = np.array([1, 3, 2, 7]) weights = np.array([0, 1, 3, 1]) @@ -6031,7 +6085,7 @@ class StreamingPearsonRTest(test.TestCase): self.assertAlmostEqual(expected_r, pearson_r.eval()) def testMultiUpdateWithErrorNoWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: np.random.seed(123) n = 100 predictions = np.random.randn(n) @@ -6066,7 +6120,7 @@ class StreamingPearsonRTest(test.TestCase): prev_expected_r = expected_r def testMultiUpdateWithErrorAndWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: np.random.seed(123) n = 100 predictions = np.random.randn(n) @@ -6108,7 +6162,7 @@ class StreamingPearsonRTest(test.TestCase): prev_expected_r = expected_r def testMultiUpdateWithErrorAndSingletonBatches(self): - with self.test_session() as sess: + with self.cached_session() as sess: np.random.seed(123) n = 100 predictions = np.random.randn(n) @@ -6189,7 +6243,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=1) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -6212,7 +6266,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, sess.run(update_op)) self.assertEqual(0, error.eval()) @@ -6229,7 +6283,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(1, sess.run(update_op), 5) self.assertAlmostEqual(1, error.eval(), 5) @@ -6251,7 +6305,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(1.0, sess.run(update_op), 5) self.assertAlmostEqual(1.0, error.eval(), 5) @@ -6270,7 +6324,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=2, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, sess.run(update_op)) self.assertEqual(0, error.eval()) @@ -6289,7 +6343,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=2, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1.5, update_op.eval()) self.assertEqual(1.5, error.eval()) @@ -6324,7 +6378,7 @@ class PcntBelowThreshTest(test.TestCase): self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) def testOneUpdate(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = constant_op.constant( [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) @@ -6344,7 +6398,7 @@ class PcntBelowThreshTest(test.TestCase): self.assertAlmostEqual(0.0, pcnt2, 5) def testSomePresentOneUpdate(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = constant_op.constant( [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) weights = constant_op.constant( @@ -6421,7 +6475,7 @@ class StreamingMeanIOUTest(test.TestCase): miou, update_op = metrics.streaming_mean_iou( predictions, labels, num_classes=num_classes) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -6435,7 +6489,7 @@ class StreamingMeanIOUTest(test.TestCase): def testMultipleUpdates(self): num_classes = 3 - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. preds_queue = data_flow_ops.FIFOQueue( 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) @@ -6467,7 +6521,7 @@ class StreamingMeanIOUTest(test.TestCase): def testMultipleUpdatesWithWeights(self): num_classes = 2 - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. preds_queue = data_flow_ops.FIFOQueue( 6, dtypes=dtypes_lib.int32, shapes=(1, 1)) @@ -6515,7 +6569,7 @@ class StreamingMeanIOUTest(test.TestCase): # one class, and thus there is one row and one column with # zero entries in the confusion matrix. num_classes = 3 - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. # There is no prediction for class 2. preds_queue = data_flow_ops.FIFOQueue( @@ -6557,7 +6611,7 @@ class StreamingMeanIOUTest(test.TestCase): constant_op.constant(1, shape=[7]) ], 0) num_classes = 2 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6570,7 +6624,7 @@ class StreamingMeanIOUTest(test.TestCase): predictions = array_ops.zeros([40]) labels = array_ops.zeros([40]) num_classes = 1 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6581,7 +6635,7 @@ class StreamingMeanIOUTest(test.TestCase): predictions = array_ops.zeros([40]) labels = array_ops.ones([40]) num_classes = 2 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6603,7 +6657,7 @@ class StreamingMeanIOUTest(test.TestCase): constant_op.constant(1, shape=[8]), constant_op.constant(0, shape=[1]) ], 0) - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou( predictions, labels, num_classes, weights=weights) sess.run(variables.local_variables_initializer()) @@ -6618,7 +6672,7 @@ class StreamingMeanIOUTest(test.TestCase): [[[0, 0, 2, 1, 1, 0], [0, 1, 2, 2, 0, 1]], [[0, 0, 2, 1, 1, 1], [1, 1, 2, 0, 0, 0]]]) num_classes = 3 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6630,7 +6684,7 @@ class StreamingMeanIOUTest(test.TestCase): labels = constant_op.constant([0]) predictions = constant_op.constant([0]) num_classes = 2 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6644,7 +6698,7 @@ class StreamingMeanIOUTest(test.TestCase): [[[0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1]], [[0, 0, 0, 1, 1, 1], [1, 1, 1, 0, 0, 0]]]) num_classes = 3 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6679,7 +6733,7 @@ class StreamingConcatTest(test.TestCase): def testNextArraySize(self): next_array_size = metric_ops._next_array_size # pylint: disable=protected-access - with self.test_session(): + with self.cached_session(): self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2) self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4) self.assertEqual(next_array_size(4, growth_factor=2).eval(), 4) @@ -6687,7 +6741,7 @@ class StreamingConcatTest(test.TestCase): self.assertEqual(next_array_size(6, growth_factor=2).eval(), 8) def testStreamingConcat(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = array_ops.placeholder(dtypes_lib.int32, [None]) concatenated, update_op = metrics.streaming_concat(values) sess.run(variables.local_variables_initializer()) @@ -6704,7 +6758,7 @@ class StreamingConcatTest(test.TestCase): self.assertAllEqual(np.arange(10), concatenated.eval()) def testStreamingConcatStringValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = array_ops.placeholder(dtypes_lib.string, [None]) concatenated, update_op = metrics.streaming_concat(values) sess.run(variables.local_variables_initializer()) @@ -6723,7 +6777,7 @@ class StreamingConcatTest(test.TestCase): concatenated.eval()) def testStreamingConcatMaxSize(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = math_ops.range(3) concatenated, update_op = metrics.streaming_concat(values, max_size=5) sess.run(variables.local_variables_initializer()) @@ -6740,7 +6794,7 @@ class StreamingConcatTest(test.TestCase): self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval()) def testStreamingConcat2D(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = array_ops.reshape(math_ops.range(3), (3, 1)) concatenated, update_op = metrics.streaming_concat(values, axis=-1) sess.run(variables.local_variables_initializer()) @@ -6763,7 +6817,7 @@ class StreamingConcatTest(test.TestCase): array_ops.placeholder(dtypes_lib.float32, [None, None])) def testStreamingConcatReset(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = array_ops.placeholder(dtypes_lib.int32, [None]) concatenated, update_op = metrics.streaming_concat(values) sess.run(variables.local_variables_initializer()) @@ -6791,7 +6845,7 @@ class AggregateMetricsTest(test.TestCase): metrics.streaming_mean(values)) self.assertEqual(len(value_tensors), 1) self.assertEqual(len(update_ops), 1) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1, update_ops[0].eval()) self.assertEqual(1, value_tensors[0].eval()) @@ -6804,7 +6858,7 @@ class AggregateMetricsTest(test.TestCase): metrics.streaming_mean_squared_error(predictions, labels)) self.assertEqual(len(value_tensors), 2) self.assertEqual(len(update_ops), 2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(2, update_ops[0].eval()) self.assertEqual(4, update_ops[1].eval()) @@ -6825,7 +6879,7 @@ class AggregateMetricMapTest(test.TestCase): self.assertEqual(2, len(names_to_values)) self.assertEqual(2, len(names_to_updates)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(2, names_to_updates['m1'].eval()) self.assertEqual(4, names_to_updates['m2'].eval()) @@ -6860,7 +6914,7 @@ class CountTest(test.TestCase): self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor)) def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -6877,7 +6931,7 @@ class CountTest(test.TestCase): self.assertAlmostEqual(8.0, sess.run(result), 5) def testUpdateOpsReturnsCurrentValue(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -6898,7 +6952,7 @@ class CountTest(test.TestCase): self.assertAlmostEqual(8.0, sess.run(result), 5) def test1dWeightedValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -6925,7 +6979,7 @@ class CountTest(test.TestCase): self.assertAlmostEqual(3.4, result.eval(), 5) def test1dWeightedValues_placeholders(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) values = array_ops.placeholder(dtype=dtypes_lib.float32) @@ -6947,7 +7001,7 @@ class CountTest(test.TestCase): self.assertAlmostEqual(3.4, result.eval(), 5) def test2dWeightedValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -6974,7 +7028,7 @@ class CountTest(test.TestCase): self.assertAlmostEqual(4.1, result.eval(), 5) def test2dWeightedValues_placeholders(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) values = array_ops.placeholder(dtype=dtypes_lib.float32) @@ -7047,7 +7101,7 @@ class CohenKappaTest(test.TestCase): (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2) kappa, update_op = metrics.cohen_kappa(labels, predictions, 3) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -7081,7 +7135,7 @@ class CohenKappaTest(test.TestCase): for dtype in dtypes: for shape in shapes: for weight in weights: - with self.test_session() as sess: + with self.cached_session() as sess: predictions_tensor = constant_op.constant( np.reshape(predictions, shape), dtype=dtype) labels_tensor = constant_op.constant( @@ -7102,7 +7156,7 @@ class CohenKappaTest(test.TestCase): # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs) expect = 1.0 - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) kappa, update_op = metrics.cohen_kappa(labels, predictions, 4) @@ -7121,7 +7175,7 @@ class CohenKappaTest(test.TestCase): # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions) expect = -0.333333333333 - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32) labels = constant_op.constant(labels) kappa, update_op = metrics.cohen_kappa(labels, predictions, 4) @@ -7139,7 +7193,7 @@ class CohenKappaTest(test.TestCase): # labels, predictions, sample_weight=weights) expect = 0.453466583385 - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32) labels = constant_op.constant(labels) kappa, update_op = metrics.cohen_kappa( @@ -7164,7 +7218,7 @@ class CohenKappaTest(test.TestCase): weights_t = array_ops.placeholder(dtypes_lib.float32, shape=(batch_size,)) kappa, update_op = metrics.cohen_kappa( labels_t, predictions_t, num_classes, weights=weights_t) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for idx in range(0, num_samples, batch_size): @@ -7202,7 +7256,7 @@ class CohenKappaTest(test.TestCase): def testConditionalPackingOptimization(self): placeholder = array_ops.placeholder(dtypes_lib.float32, [None]) values, update_op = metric_ops.streaming_concat(placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for feed in range(10): sess.run(update_op, feed_dict={placeholder: [feed]}) |