diff options
author | 2018-04-10 13:59:49 -0700 | |
---|---|---|
committer | 2018-04-10 14:02:32 -0700 | |
commit | 4995231f9e383b4edc222f63f546b9fa8577fb69 (patch) | |
tree | 2cad7e04bd8167836584d90cdbb7e42b6f9730fd /tensorflow/contrib/gan | |
parent | 0932d4af60cd8c9ce322a8e16c8f51d300eb4402 (diff) |
test previously untested eval codepaths.
PiperOrigin-RevId: 192341561
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py | 33 |
2 files changed, 26 insertions, 8 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 461066bbb4..b305f37791 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -364,6 +364,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:variables", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 663e49bdca..4fb8d58bc9 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -22,6 +22,7 @@ import os import tarfile import tempfile +from absl.testing import parameterized import numpy as np from scipy import linalg as scp_linalg @@ -182,13 +183,20 @@ def _run_with_mock(function, *args, **kwargs): return function(*args, **kwargs) -class ClassifierMetricsTest(test.TestCase): +class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): - def test_run_inception_graph(self): + @parameterized.named_parameters( + ('GraphDef', False), + ('DefaultGraphDefFn', True)) + def test_run_inception_graph(self, use_default_graph_def): """Test `run_inception` graph construction.""" batch_size = 7 img = array_ops.ones([batch_size, 299, 299, 3]) - logits = _run_with_mock(classifier_metrics.run_inception, img) + + if use_default_graph_def: + logits = _run_with_mock(classifier_metrics.run_inception, img) + else: + logits = classifier_metrics.run_inception(img, _get_dummy_graphdef()) self.assertTrue(isinstance(logits, ops.Tensor)) logits.shape.assert_is_compatible_with([batch_size, 1001]) @@ -196,14 +204,23 @@ class ClassifierMetricsTest(test.TestCase): # Check that none of the model variables are trainable. self.assertListEqual([], variables.trainable_variables()) - def test_run_inception_graph_pool_output(self): + @parameterized.named_parameters( + ('GraphDef', False), + ('DefaultGraphDefFn', True)) + def test_run_inception_graph_pool_output(self, use_default_graph_def): """Test `run_inception` graph construction with pool output.""" batch_size = 3 img = array_ops.ones([batch_size, 299, 299, 3]) - pool = _run_with_mock( - classifier_metrics.run_inception, - img, - output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) + + if use_default_graph_def: + pool = _run_with_mock( + classifier_metrics.run_inception, + img, + output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) + else: + pool = classifier_metrics.run_inception( + img, _get_dummy_graphdef(), + output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) self.assertTrue(isinstance(pool, ops.Tensor)) pool.shape.assert_is_compatible_with([batch_size, 2048]) |