author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-13 13:10:41 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-13 13:13:31 -0700
commit     4254b2ca729858d5bff2bbd570b4f7b02d42fd35 (patch)
tree       33f763e99e2e1a00e288d8538f3ad099f9da77b4 /tensorflow/contrib/metrics
parent     642a043de4901ddbf305db105168b8908adfe99e (diff)
Splits testLargeCase in metric_ops_test into a dedicated file for slow-running tests and re-enables it as a 'large' test.
PiperOrigin-RevId: 200440883
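(In Bazel, a test's size sets its default timeout: 'medium' tests, the default, are killed after 300 seconds, while 'large' tests are allowed 900, which is what lets the previously skipped slow case run to completion.)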
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--  tensorflow/contrib/metrics/BUILD                                  24
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py    66
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py          28
3 files changed, 90 insertions, 28 deletions
diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD
index 4f2c82ca23..3f81c9ccea 100644
--- a/tensorflow/contrib/metrics/BUILD
+++ b/tensorflow/contrib/metrics/BUILD
@@ -97,3 +97,27 @@ py_test(
"//third_party/py/numpy",
],
)
+
+py_test(
+ name = "metric_ops_large_test",
+ size = "large",
+ srcs = ["python/ops/metric_ops_large_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["noasan"], # times out b/63678675
+ deps = [
+ ":metrics_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
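With this target in place, the slow case can be run in isolation, e.g. bazel test //tensorflow/contrib/metrics:metric_ops_large_test; the noasan tag keeps it out of ASAN runs, where it is known to time out (b/63678675).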
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
new file mode 100644
index 0000000000..7acfc383eb
--- /dev/null
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Large tests for metric_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.contrib.metrics.python.ops import metric_ops
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase):
+
+  def setUp(self):
+    np.random.seed(1)
+    ops.reset_default_graph()
+
+  def testLargeCase(self):
+    shape = [32, 512, 256, 1]
+    predictions = random_ops.random_uniform(
+        shape, 0.0, 1.0, dtype=dtypes_lib.float32)
+    labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5)
+
+    result, update_op = metric_ops.precision_recall_at_equal_thresholds(
+        labels=labels, predictions=predictions, num_thresholds=201)
+    # Run many updates, enough to cause highly inaccurate values if the
+    # code used float32 for accumulation.
+    num_updates = 71
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      for _ in xrange(num_updates):
+        sess.run(update_op)
+
+      prdata = sess.run(result)
+
+      # Since we use random values, we won't know the tp/fp/tn/fn values, but
+      # tp and fp at threshold 0 should be the total number of positive and
+      # negative labels, hence their sum should be total number of pixels.
+      expected_value = 1.0 * np.product(shape) * num_updates
+      got_value = prdata.tp[0] + prdata.fp[0]
+      # They should be at least within 1.
+      self.assertNear(got_value, expected_value, 1.0)
+
+if __name__ == '__main__':
+  test.main()
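Why 71 updates is enough to expose the problem the in-code comment describes: the test accumulates one count per pixel per update, and the total exceeds the range in which float32 can count exactly. A minimal NumPy sketch of that argument (illustrative only, not part of the commit):

import numpy as np

# Counts accumulated by testLargeCase: one per pixel, per update.
shape = [32, 512, 256, 1]
num_updates = 71
total = np.prod(shape) * num_updates  # 297,795,584

# float32 has a 24-bit significand, so it counts exactly only up to
# 2**24 = 16,777,216 -- far below the accumulated total. Past that
# point an increment of 1 can round away to nothing:
print(np.float32(2**24) + np.float32(1) == np.float32(2**24))  # True

# float64 (53-bit significand) represents counts of this size exactly,
# which is why the test can demand tp[0] + fp[0] land within 1.0 of
# the expected total:
print(np.float64(total) + np.float64(1) == np.float64(total))  # False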
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index b13f08a37d..db4b530ce7 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2391,34 +2391,6 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
     for _ in range(3):
       self._testResultsEqual(initial_result, result)
 
-  def testLargeCase(self):
-    self.skipTest("Test consistently timing out")
-    shape = [32, 512, 256, 1]
-    predictions = random_ops.random_uniform(
-        shape, 0.0, 1.0, dtype=dtypes_lib.float32)
-    labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5)
-
-    result, update_op = metric_ops.precision_recall_at_equal_thresholds(
-        labels=labels, predictions=predictions, num_thresholds=201)
-    # Run many updates, enough to cause highly inaccurate values if the
-    # code used float32 for accumulation.
-    num_updates = 71
-
-    with self.test_session() as sess:
-      sess.run(variables.local_variables_initializer())
-      for _ in xrange(num_updates):
-        sess.run(update_op)
-
-      prdata = sess.run(result)
-
-      # Since we use random values, we won't know the tp/fp/tn/fn values, but
-      # tp and fp at threshold 0 should be the total number of positive and
-      # negative labels, hence their sum should be total number of pixels.
-      expected_value = 1.0 * np.product(shape) * num_updates
-      got_value = prdata.tp[0] + prdata.fp[0]
-      # They should be at least within 1.
-      self.assertNear(got_value, expected_value, 1.0)
-
   def _testCase(self,
                 predictions,
                 labels,