aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-10 13:59:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 14:02:32 -0700
commit4995231f9e383b4edc222f63f546b9fa8577fb69 (patch)
tree2cad7e04bd8167836584d90cdbb7e42b6f9730fd /tensorflow/contrib/gan
parent0932d4af60cd8c9ce322a8e16c8f51d300eb4402 (diff)
test previously untested eval codepaths.
PiperOrigin-RevId: 192341561
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/BUILD1
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py33
2 files changed, 26 insertions, 8 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 461066bbb4..b305f37791 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -364,6 +364,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:variables",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
index 663e49bdca..4fb8d58bc9 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
@@ -22,6 +22,7 @@ import os
import tarfile
import tempfile
+from absl.testing import parameterized
import numpy as np
from scipy import linalg as scp_linalg
@@ -182,13 +183,20 @@ def _run_with_mock(function, *args, **kwargs):
return function(*args, **kwargs)
-class ClassifierMetricsTest(test.TestCase):
+class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
- def test_run_inception_graph(self):
+ @parameterized.named_parameters(
+ ('GraphDef', False),
+ ('DefaultGraphDefFn', True))
+ def test_run_inception_graph(self, use_default_graph_def):
"""Test `run_inception` graph construction."""
batch_size = 7
img = array_ops.ones([batch_size, 299, 299, 3])
- logits = _run_with_mock(classifier_metrics.run_inception, img)
+
+ if use_default_graph_def:
+ logits = _run_with_mock(classifier_metrics.run_inception, img)
+ else:
+ logits = classifier_metrics.run_inception(img, _get_dummy_graphdef())
self.assertTrue(isinstance(logits, ops.Tensor))
logits.shape.assert_is_compatible_with([batch_size, 1001])
@@ -196,14 +204,23 @@ class ClassifierMetricsTest(test.TestCase):
# Check that none of the model variables are trainable.
self.assertListEqual([], variables.trainable_variables())
- def test_run_inception_graph_pool_output(self):
+ @parameterized.named_parameters(
+ ('GraphDef', False),
+ ('DefaultGraphDefFn', True))
+ def test_run_inception_graph_pool_output(self, use_default_graph_def):
"""Test `run_inception` graph construction with pool output."""
batch_size = 3
img = array_ops.ones([batch_size, 299, 299, 3])
- pool = _run_with_mock(
- classifier_metrics.run_inception,
- img,
- output_tensor=classifier_metrics.INCEPTION_FINAL_POOL)
+
+ if use_default_graph_def:
+ pool = _run_with_mock(
+ classifier_metrics.run_inception,
+ img,
+ output_tensor=classifier_metrics.INCEPTION_FINAL_POOL)
+ else:
+ pool = classifier_metrics.run_inception(
+ img, _get_dummy_graphdef(),
+ output_tensor=classifier_metrics.INCEPTION_FINAL_POOL)
self.assertTrue(isinstance(pool, ops.Tensor))
pool.shape.assert_is_compatible_with([batch_size, 2048])