author | Priya Gupta <priyag@google.com> | 2018-10-03 13:39:44 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 13:50:49 -0700
commit | ce9a5d143f89a37ab029a29c62433883323987e8 (patch)
tree | 0953913adeb36a4173bbb5b4e582c925bcab875d /tensorflow/contrib
parent | 43073e9d4dc957367d8e2b73c37733ff1dc376c1 (diff)
Tests for metrics correctness with TPU strategy
PiperOrigin-RevId: 215618809
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/distribute/python/BUILD | 17
-rw-r--r-- | tensorflow/contrib/distribute/python/combinations.py | 4
-rw-r--r-- | tensorflow/contrib/distribute/python/metrics_v1_test.py | 121
3 files changed, 86 insertions, 56 deletions
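
A recurring change in the diff below is replacing .batch(n) with .batch(n, drop_remainder=True): TPU compilation requires statically known batch shapes, and dropping the partial final batch is what provides them. A minimal sketch of the effect, using the public tf.data API rather than the internal dataset_ops module the test imports (assumes a TF 1.x environment matching this commit):

    import tensorflow as tf

    # With 10 elements and batch size 4, the trailing partial batch of 2 is
    # dropped, so every emitted batch has the static shape (4,) that the TPU
    # compiler needs; without drop_remainder the batch dimension is unknown.
    dataset = tf.data.Dataset.range(10).batch(4, drop_remainder=True)
    print(dataset.output_shapes)  # (4,) rather than (?,)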
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index defa82f98a..8267612236 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -737,18 +737,27 @@ cuda_py_test(
     ],
 )
 
-cuda_py_test(
-    name = "metrics_v1_test",
+py_library(
+    name = "metrics_v1_test_lib",
+    testonly = 1,
     srcs = ["metrics_v1_test.py"],
-    additional_deps = [
+    deps = [
         ":combinations",
-        "@absl_py//absl/testing:parameterized",
         "//tensorflow/contrib/data/python/ops:batching",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:metrics",
         "//tensorflow/python:variables",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/eager:test",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+cuda_py_test(
+    name = "metrics_v1_test",
+    srcs = ["metrics_v1_test.py"],
+    additional_deps = [
+        ":metrics_v1_test_lib",
     ],
     tags = [
         "multi_and_single_gpu",
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 82ca041cc2..cff4b0a463 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -329,10 +329,10 @@ one_device_strategy = NamedDistribution(
     required_gpus=None)
 tpu_strategy = NamedDistribution(
     "TPU", lambda: tpu_lib.TPUStrategy(
-        TPUClusterResolver(""), steps_per_run=5),
+        TPUClusterResolver(""), steps_per_run=2),
     required_tpu=True)
 tpu_strategy_one_step = NamedDistribution(
-    "TPU", lambda: tpu_lib.TPUStrategy(
+    "TPUOneStep", lambda: tpu_lib.TPUStrategy(
         TPUClusterResolver(""), steps_per_run=1),
     required_tpu=True)
 # Note that we disable prefetching for testing since prefetching makes
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 8163494c8e..ae4189eb1c 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 from absl.testing import parameterized
 
 from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import tpu_strategy
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import test
 from tensorflow.python.framework import ops
@@ -35,7 +36,8 @@ def _labeled_dataset_fn():
   # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
   # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
   return dataset_ops.Dataset.range(1000).map(
-      lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4)
+      lambda x: {"labels": x % 5, "predictions": x % 3}).batch(
+          4, drop_remainder=True)
 
 
 def _boolean_dataset_fn():
@@ -47,7 +49,8 @@ def _boolean_dataset_fn():
   # F, T -> FP; T, F -> FN; F, F -> TN
   return dataset_ops.Dataset.from_tensor_slices({
       "labels": [True, False, True, False],
-      "predictions": [True, True, False, False]}).repeat().batch(3)
+      "predictions": [True, True, False, False]}).repeat().batch(
+          3, drop_remainder=True)
 
 
 def _threshold_dataset_fn():
@@ -59,7 +62,8 @@ def _threshold_dataset_fn():
   # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN
   return dataset_ops.Dataset.from_tensor_slices({
       "labels": [True, False, True, False],
-      "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3)
+      "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(
+          3, drop_remainder=True)
 
 
 def _regression_dataset_fn():
@@ -79,6 +83,12 @@ def all_combinations():
       mode=["graph"])
 
 
+def tpu_combinations():
+  return combinations.combine(distribution=[combinations.tpu_strategy_one_step,
+                                            combinations.tpu_strategy],
+                              mode=["graph"])
+
+
 # TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
 # metrics.precision_at_k
 class MetricsV1Test(test.TestCase, parameterized.TestCase):
@@ -87,42 +97,50 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     with ops.Graph().as_default(), distribution.scope():
       iterator = distribution.distribute_dataset(
           dataset_fn).make_one_shot_iterator()
-      value, update = distribution.call_for_each_tower(
-          metric_fn, iterator.get_next())
-      update = distribution.group(update)
+      if isinstance(distribution, tpu_strategy.TPUStrategy):
+        def step_fn(ctx, inputs):
+          value, update = distribution.call_for_each_tower(
+              metric_fn, inputs)
+          ctx.set_non_tensor_output(name="value", output=value)
+          return distribution.group(update)
+
+        ctx = distribution.run_steps_on_dataset(
+            step_fn, iterator, iterations=distribution.steps_per_run)
+        update = ctx.run_op
+        value = ctx.non_tensor_outputs["value"]
+        # In each run, we run multiple steps, and each steps consumes as many
+        # batches as number of towers.
+        batches_per_update = (
+            distribution.num_towers * distribution.steps_per_run)
+      else:
+        value, update = distribution.call_for_each_tower(
+            metric_fn, iterator.get_next())
+        update = distribution.group(update)
+        # TODO(josh11b): Once we switch to using a global batch size for input,
+        # replace "distribution.num_towers" with "1".
+        batches_per_update = distribution.num_towers
+
+      self.evaluate(distribution.initialize())
       self.evaluate(variables.local_variables_initializer())
-      # TODO(josh11b): Once we switch to using a global batch size for input,
-      # replace "distribution.num_towers" with "1".
-      batches_per_update = distribution.num_towers
-
-      # Update variables using the first `num_towers` batches.
-      self.evaluate(update)
-      self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value),
-                          0.001, msg="After first update")
-
-      # Update variables using the second `num_towers` batches.
-      self.evaluate(update)
-      self.assertAllClose(expected_fn(2 * batches_per_update),
-                          self.evaluate(value),
-                          0.001,
-                          msg="After second update")
-
-      if batches_per_update == 1:  # Consume 4 input batches
-        self.evaluate(update)
-        self.assertAllClose(expected_fn(3 * batches_per_update),
-                            self.evaluate(value),
-                            0.001,
-                            msg="After third update")
+
+      batches_consumed = 0
+      for i in range(4):
         self.evaluate(update)
-        self.assertAllClose(expected_fn(4 * batches_per_update),
+        batches_consumed += batches_per_update
+        self.assertAllClose(expected_fn(batches_consumed),
                             self.evaluate(value),
                             0.001,
-                            msg="After fourth update")
+                            msg="After update #" + str(i+1))
+        if batches_consumed >= 4:  # Consume 4 input batches in total.
+          break
 
-  @combinations.generate(all_combinations())
+      self.evaluate(distribution.finalize())
+
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testMean(self, distribution):
     def _dataset_fn():
-      return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4)
+      return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(
+          4, drop_remainder=True)
 
     def _expected_fn(num_batches):
       # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
@@ -130,7 +148,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _dataset_fn, metrics.mean, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testAccuracy(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -143,6 +161,8 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
 
+  # TODO(priyag, jhseu): Enable TPU for this test once scatter_add is added
+  # for TPUMirroredVariable.
   @combinations.generate(all_combinations())
   def testMeanPerClassAccuracy(self, distribution):
     def _metric_fn(x):
@@ -161,6 +181,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
 
+  # NOTE(priyag): This metric doesn't work on TPUs yet.
   @combinations.generate(all_combinations())
   def testMeanIOU(self, distribution):
     def _metric_fn(x):
@@ -179,7 +200,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testMeanTensor(self, distribution):
     def _dataset_fn():
       dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
@@ -198,7 +219,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testAUCROC(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -212,7 +233,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testAUCPR(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -226,7 +247,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testFalseNegatives(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -239,7 +260,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testFalseNegativesAtThresholds(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -252,7 +273,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testTrueNegatives(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -265,7 +286,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testTrueNegativesAtThresholds(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -278,7 +299,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testFalsePositives(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -291,7 +312,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testFalsePositivesAtThresholds(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -304,7 +325,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testTruePositives(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -317,7 +338,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testTruePositivesAtThresholds(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -330,7 +351,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testPrecision(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -343,7 +364,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testPrecisionAtThreshold(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -356,7 +377,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testRecall(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -369,7 +390,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testRecallAtThreshold(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -382,7 +403,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testMeanSquaredError(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
@@ -395,7 +416,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     self._test_metric(
         distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
 
-  @combinations.generate(all_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testRootMeanSquaredError(self, distribution):
     def _metric_fn(x):
       labels = x["labels"]
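
For reference, the TPU branch added to _test_metric above follows the run_steps_on_dataset pattern. A condensed sketch of just that pattern, lifted from the diff; it assumes distribution (a TPUStrategy), iterator, and metric_fn are already set up as in the test, so it is not runnable on its own:

    def step_fn(ctx, inputs):
      # Run the metric on each tower and stash the value tensor so it can be
      # read back outside the stepped loop.
      value, update = distribution.call_for_each_tower(metric_fn, inputs)
      ctx.set_non_tensor_output(name="value", output=value)
      return distribution.group(update)

    ctx = distribution.run_steps_on_dataset(
        step_fn, iterator, iterations=distribution.steps_per_run)
    update = ctx.run_op                      # one evaluation runs steps_per_run steps
    value = ctx.non_tensor_outputs["value"]

Each evaluation of update therefore consumes num_towers * steps_per_run input batches, which is why the rewritten test tracks batches_consumed rather than assuming one batch per update.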