aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 19:24:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 19:27:54 -0700
commit496023e9dc84a076caeb2e5e8e13b6a3d819ad6d (patch)
tree9776c9865f7b98a15817bc6be4c2b683323a67b1 /tensorflow/contrib/gan
parent361a82d73a50a800510674b3aaa20e4845e56434 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 209701635
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py10
-rw-r--r--tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py6
2 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
index 4fb8d58bc9..d64dfd1576 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
@@ -335,7 +335,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
mofid_op = classifier_metrics.mean_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long
tf_pool_real_a, tf_pool_gen_a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_mofid = sess.run(mofid_op)
expected_mofid = _expected_mean_only_fid(pool_real_a, pool_gen_a)
@@ -355,7 +355,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
dofid_op = classifier_metrics.diagonal_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long
tf_pool_real_a, tf_pool_gen_a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_dofid = sess.run(dofid_op)
expected_dofid = _expected_diagonal_only_fid(pool_real_a, pool_gen_a)
@@ -377,7 +377,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
test_pool_gen_a,
classifier_fn=lambda x: x)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_fid = sess.run(fid_op)
expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a)
@@ -404,7 +404,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
classifier_fn=lambda x: x))
fids = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for fid_op in fid_ops:
fids.append(sess.run(fid_op))
@@ -426,7 +426,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
trace_sqrt_prod_op = _run_with_mock(classifier_metrics.trace_sqrt_product,
cov_real, cov_gen)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# trace_sqrt_product: tsp
actual_tsp = sess.run(trace_sqrt_prod_op)
diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py
index 871f1ad54e..ab909feae3 100644
--- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py
@@ -65,7 +65,7 @@ class ClassifierMetricsTest(test.TestCase):
pyramid = np_laplacian_pyramid(data, 3)
data_tf = array_ops.placeholder(dtypes.float32, [256, 32, 32, 3])
pyramid_tf = swd._laplacian_pyramid(data_tf, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pyramid_tf = sess.run(
pyramid_tf, feed_dict={
data_tf: data.transpose(0, 2, 3, 1)
@@ -79,7 +79,7 @@ class ClassifierMetricsTest(test.TestCase):
d1 = random_ops.random_uniform([256, 32, 32, 3])
d2 = random_ops.random_normal([256, 32, 32, 3])
wfunc = swd.sliced_wasserstein_distance(d1, d2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wscores = [sess.run(x) for x in wfunc]
self.assertAllClose(
np.array([0.014, 0.014], 'f'),
@@ -95,7 +95,7 @@ class ClassifierMetricsTest(test.TestCase):
d1 = random_ops.random_uniform([256, 32, 32, 3])
d2 = random_ops.random_normal([256, 32, 32, 3])
wfunc = swd.sliced_wasserstein_distance(d1, d2, use_svd=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wscores = [sess.run(x) for x in wfunc]
self.assertAllClose(
np.array([0.013, 0.013], 'f'),