author    Priya Gupta <priyag@google.com> 2018-10-03 13:39:44 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-03 13:50:49 -0700
commit    ce9a5d143f89a37ab029a29c62433883323987e8 (patch)
tree      0953913adeb36a4173bbb5b4e582c925bcab875d /tensorflow/contrib
parent    43073e9d4dc957367d8e2b73c37733ff1dc376c1 (diff)
Tests for metrics correctness with TPU strategy
PiperOrigin-RevId: 215618809
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD               17
-rw-r--r--  tensorflow/contrib/distribute/python/combinations.py      4
-rw-r--r--  tensorflow/contrib/distribute/python/metrics_v1_test.py  121
3 files changed, 86 insertions, 56 deletions
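
The metric tests below opt into TPU by decorating each test with @combinations.generate(all_combinations() + tpu_combinations()). The plain "+" works because combinations.combine() produces a list of keyword-argument combinations, so the CPU/GPU and TPU cases can simply be concatenated. A minimal self-contained sketch of that idea, using a toy combine() stand-in rather than the real combinations module:

import itertools

def combine(**kwargs):
  # Toy stand-in for combinations.combine(): each kwarg is a list of options,
  # and the result is one dict per element of the Cartesian product.
  keys = sorted(kwargs)
  return [dict(zip(keys, values))
          for values in itertools.product(*(kwargs[k] for k in keys))]

# Because each call returns an ordinary list, test decorators can add them.
# The distribution names here are placeholders, not the real NamedDistributions.
cpu_gpu_cases = combine(distribution=["default", "one_device", "mirrored"],
                        mode=["graph"])
tpu_cases = combine(distribution=["tpu", "tpu_one_step"], mode=["graph"])
print(len(cpu_gpu_cases + tpu_cases))  # -> 5 parameterized test cases
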
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index defa82f98a..8267612236 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -737,18 +737,27 @@ cuda_py_test(
],
)
-cuda_py_test(
- name = "metrics_v1_test",
+py_library(
+ name = "metrics_v1_test_lib",
+ testonly = 1,
srcs = ["metrics_v1_test.py"],
- additional_deps = [
+ deps = [
":combinations",
- "@absl_py//absl/testing:parameterized",
"//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:math_ops",
"//tensorflow/python:metrics",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:test",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+cuda_py_test(
+ name = "metrics_v1_test",
+ srcs = ["metrics_v1_test.py"],
+ additional_deps = [
+ ":metrics_v1_test_lib",
],
tags = [
"multi_and_single_gpu",
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 82ca041cc2..cff4b0a463 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -329,10 +329,10 @@ one_device_strategy = NamedDistribution(
required_gpus=None)
tpu_strategy = NamedDistribution(
"TPU", lambda: tpu_lib.TPUStrategy(
- TPUClusterResolver(""), steps_per_run=5),
+ TPUClusterResolver(""), steps_per_run=2),
required_tpu=True)
tpu_strategy_one_step = NamedDistribution(
- "TPU", lambda: tpu_lib.TPUStrategy(
+ "TPUOneStep", lambda: tpu_lib.TPUStrategy(
TPUClusterResolver(""), steps_per_run=1),
required_tpu=True)
# Note that we disable prefetching for testing since prefetching makes
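
The steps_per_run change above (and the new "TPUOneStep" name) feeds directly into the metric tests: on TPU, each execution of the update op runs steps_per_run steps and each step consumes one batch per tower, so the tests compute batches_per_update = num_towers * steps_per_run. A small illustrative calculation, assuming a hypothetical two-tower setup purely for the example:

def batches_per_update(num_towers, steps_per_run):
  # Mirrors the bookkeeping in metrics_v1_test.py below: one batch per tower
  # per step, and steps_per_run steps per run of the update op.
  return num_towers * steps_per_run

print(batches_per_update(num_towers=2, steps_per_run=2))  # "TPU"        -> 4
print(batches_per_update(num_towers=2, steps_per_run=1))  # "TPUOneStep" -> 2
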
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 8163494c8e..ae4189eb1c 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
@@ -35,7 +36,8 @@ def _labeled_dataset_fn():
# 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
# 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
return dataset_ops.Dataset.range(1000).map(
- lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4)
+ lambda x: {"labels": x % 5, "predictions": x % 3}).batch(
+ 4, drop_remainder=True)
def _boolean_dataset_fn():
@@ -47,7 +49,8 @@ def _boolean_dataset_fn():
# F, T -> FP; T, F -> FN; F, F -> TN
return dataset_ops.Dataset.from_tensor_slices({
"labels": [True, False, True, False],
- "predictions": [True, True, False, False]}).repeat().batch(3)
+ "predictions": [True, True, False, False]}).repeat().batch(
+ 3, drop_remainder=True)
def _threshold_dataset_fn():
@@ -59,7 +62,8 @@ def _threshold_dataset_fn():
# False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN
return dataset_ops.Dataset.from_tensor_slices({
"labels": [True, False, True, False],
- "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3)
+ "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(
+ 3, drop_remainder=True)
def _regression_dataset_fn():
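
The dataset helpers above all switch from .batch(n) to .batch(n, drop_remainder=True), so every batch has a static size of n (a trailing partial batch is dropped), which is what the TPU path needs and what keeps expected_fn(num_batches) exact. A quick standalone illustration of that tf.data behavior, written in TF 2.x eager style for brevity rather than the graph-mode style used by these tests:

import tensorflow as tf

# Ten elements in batches of three: the final partial batch (a single element)
# is dropped, so every emitted batch has exactly three elements.
ds = tf.data.Dataset.range(10).batch(3, drop_remainder=True)
for batch in ds:
  print(batch.numpy())
# [0 1 2]
# [3 4 5]
# [6 7 8]
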
@@ -79,6 +83,12 @@ def all_combinations():
mode=["graph"])
+def tpu_combinations():
+ return combinations.combine(distribution=[combinations.tpu_strategy_one_step,
+ combinations.tpu_strategy],
+ mode=["graph"])
+
+
# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
# metrics.precision_at_k
class MetricsV1Test(test.TestCase, parameterized.TestCase):
@@ -87,42 +97,50 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
with ops.Graph().as_default(), distribution.scope():
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
- value, update = distribution.call_for_each_tower(
- metric_fn, iterator.get_next())
- update = distribution.group(update)
+ if isinstance(distribution, tpu_strategy.TPUStrategy):
+ def step_fn(ctx, inputs):
+ value, update = distribution.call_for_each_tower(
+ metric_fn, inputs)
+ ctx.set_non_tensor_output(name="value", output=value)
+ return distribution.group(update)
+
+ ctx = distribution.run_steps_on_dataset(
+ step_fn, iterator, iterations=distribution.steps_per_run)
+ update = ctx.run_op
+ value = ctx.non_tensor_outputs["value"]
+ # In each run, we run multiple steps, and each step consumes as many
+ # batches as the number of towers.
+ batches_per_update = (
+ distribution.num_towers * distribution.steps_per_run)
+ else:
+ value, update = distribution.call_for_each_tower(
+ metric_fn, iterator.get_next())
+ update = distribution.group(update)
+ # TODO(josh11b): Once we switch to using a global batch size for input,
+ # replace "distribution.num_towers" with "1".
+ batches_per_update = distribution.num_towers
+
+ self.evaluate(distribution.initialize())
self.evaluate(variables.local_variables_initializer())
- # TODO(josh11b): Once we switch to using a global batch size for input,
- # replace "distribution.num_towers" with "1".
- batches_per_update = distribution.num_towers
-
- # Update variables using the first `num_towers` batches.
- self.evaluate(update)
- self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value),
- 0.001, msg="After first update")
-
- # Update variables using the second `num_towers` batches.
- self.evaluate(update)
- self.assertAllClose(expected_fn(2 * batches_per_update),
- self.evaluate(value),
- 0.001,
- msg="After second update")
-
- if batches_per_update == 1: # Consume 4 input batches
- self.evaluate(update)
- self.assertAllClose(expected_fn(3 * batches_per_update),
- self.evaluate(value),
- 0.001,
- msg="After third update")
+
+ batches_consumed = 0
+ for i in range(4):
self.evaluate(update)
- self.assertAllClose(expected_fn(4 * batches_per_update),
+ batches_consumed += batches_per_update
+ self.assertAllClose(expected_fn(batches_consumed),
self.evaluate(value),
0.001,
- msg="After fourth update")
+ msg="After update #" + str(i+1))
+ if batches_consumed >= 4: # Consume 4 input batches in total.
+ break
- @combinations.generate(all_combinations())
+ self.evaluate(distribution.finalize())
+
+ @combinations.generate(all_combinations() + tpu_combinations())
def testMean(self, distribution):
def _dataset_fn():
- return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4)
+ return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(
+ 4, drop_remainder=True)
def _expected_fn(num_batches):
# Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
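
The comment above pins down the expected values for testMean: with drop_remainder batches of 4 consecutive integers, the running mean after num_batches batches is the mean of 0..(4*num_batches - 1). A short sketch of the arithmetic that _expected_fn encodes (the real function may be written differently):

def expected_mean(num_batches, batch_size=4):
  # Mean of the first num_batches * batch_size consecutive integers 0, 1, 2, ...
  n = num_batches * batch_size
  return (n - 1) / 2.0

assert expected_mean(1) == 1.5   # Mean(0..3)
assert expected_mean(2) == 3.5   # Mean(0..7)
assert expected_mean(3) == 5.5   # Mean(0..11)
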
@@ -130,7 +148,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testAccuracy(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -143,6 +161,8 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+ # TODO(priyag, jhseu): Enable TPU for this test once scatter_add is added
+ # for TPUMirroredVariable.
@combinations.generate(all_combinations())
def testMeanPerClassAccuracy(self, distribution):
def _metric_fn(x):
@@ -161,6 +181,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+ # NOTE(priyag): This metric doesn't work on TPUs yet.
@combinations.generate(all_combinations())
def testMeanIOU(self, distribution):
def _metric_fn(x):
@@ -179,7 +200,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testMeanTensor(self, distribution):
def _dataset_fn():
dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
@@ -198,7 +219,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testAUCROC(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -212,7 +233,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testAUCPR(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -226,7 +247,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testFalseNegatives(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -239,7 +260,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testFalseNegativesAtThresholds(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -252,7 +273,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testTrueNegatives(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -265,7 +286,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testTrueNegativesAtThresholds(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -278,7 +299,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testFalsePositives(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -291,7 +312,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testFalsePositivesAtThresholds(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -304,7 +325,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testTruePositives(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -317,7 +338,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testTruePositivesAtThresholds(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -330,7 +351,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testPrecision(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -343,7 +364,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testPrecisionAtThreshold(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -356,7 +377,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testRecall(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -369,7 +390,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testRecallAtThreshold(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -382,7 +403,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testMeanSquaredError(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -395,7 +416,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testRootMeanSquaredError(self, distribution):
def _metric_fn(x):
labels = x["labels"]