aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 14:37:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 15:04:14 -0700
commitb828f89263e054bfa7c7a808cab1506834ab906d (patch)
treee31816a6850d177306f19ee8670e0836060fcfc9 /tensorflow/contrib/gan
parentacf0ee82092727afc2067316982407cf5e496f75 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 212336464
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl_test.py52
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py8
2 files changed, 30 insertions, 30 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index 9f5fee4542..e3c780ac1a 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -51,7 +51,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(self._discriminator_gen_outputs)
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
self.assertEqual(self._generator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_all_correct(self):
@@ -59,7 +59,7 @@ class _LossesTest(object):
self._discriminator_real_outputs, self._discriminator_gen_outputs)
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
self.assertEqual(self._discriminator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_collection(self):
@@ -90,7 +90,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_patch(self):
@@ -98,7 +98,7 @@ class _LossesTest(object):
array_ops.reshape(self._discriminator_real_outputs, [2, 2]),
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_with_placeholder_for_logits(self):
@@ -108,7 +108,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(logits, weights=weights)
self.assertEqual(logits.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: [[10.0, 4.4, -5.5, 3.6]],
@@ -125,7 +125,7 @@ class _LossesTest(object):
logits, logits2, real_weights=real_weights,
generated_weights=generated_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: [self._discriminator_real_outputs_np],
@@ -136,7 +136,7 @@ class _LossesTest(object):
def test_generator_with_python_scalar_weight(self):
loss = self._g_loss_fn(
self._discriminator_gen_outputs, weights=self._weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -144,14 +144,14 @@ class _LossesTest(object):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
real_weights=self._weights, generated_weights=self._weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_with_scalar_tensor_weight(self):
loss = self._g_loss_fn(self._discriminator_gen_outputs,
weights=constant_op.constant(self._weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -160,7 +160,7 @@ class _LossesTest(object):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
real_weights=weights, generated_weights=weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
@@ -284,7 +284,7 @@ class ACGANLossTest(test.TestCase):
self.assertEqual(
self._discriminator_gen_classification_logits.dtype, loss.dtype)
self.assertEqual(self._generator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_all_correct(self):
@@ -292,7 +292,7 @@ class ACGANLossTest(test.TestCase):
self.assertEqual(
self._discriminator_gen_classification_logits.dtype, loss.dtype)
self.assertEqual(self._discriminator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_collection(self):
@@ -319,14 +319,14 @@ class ACGANLossTest(test.TestCase):
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
self._generator_kwargs.items()}
loss = self._g_loss_fn(**patch_args)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_patch(self):
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
self._discriminator_kwargs.items()}
loss = self._d_loss_fn(**patch_args)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_with_placeholder_for_logits(self):
@@ -334,7 +334,7 @@ class ACGANLossTest(test.TestCase):
one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4))
loss = self._g_loss_fn(gen_logits, one_hot_labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(
loss, feed_dict={
gen_logits: self._discriminator_gen_classification_logits_np,
@@ -349,7 +349,7 @@ class ACGANLossTest(test.TestCase):
loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(
loss, feed_dict={
gen_logits: self._discriminator_gen_classification_logits_np,
@@ -360,7 +360,7 @@ class ACGANLossTest(test.TestCase):
def test_generator_with_python_scalar_weight(self):
loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -368,14 +368,14 @@ class ACGANLossTest(test.TestCase):
loss = self._d_loss_fn(
real_weights=self._weights, generated_weights=self._weights,
**self._discriminator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_with_scalar_tensor_weight(self):
loss = self._g_loss_fn(
weights=constant_op.constant(self._weights), **self._generator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -383,7 +383,7 @@ class ACGANLossTest(test.TestCase):
weights = constant_op.constant(self._weights)
loss = self._d_loss_fn(real_weights=weights, generated_weights=weights,
**self._discriminator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
@@ -404,7 +404,7 @@ class _PenaltyTest(object):
loss = self._penalty_fn(**self._kwargs)
self.assertEqual(self._expected_dtype, loss.dtype)
self.assertEqual(self._expected_op_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss, loss.eval(), 6)
@@ -419,13 +419,13 @@ class _PenaltyTest(object):
def test_python_scalar_weight(self):
loss = self._penalty_fn(weights=2.3, **self._kwargs)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
def test_scalar_tensor_weight(self):
loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
@@ -472,7 +472,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
self._kwargs['discriminator_scope'])
self.assertEqual(generated_data.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(loss,
feed_dict={
@@ -494,7 +494,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
one_sided=True)
self.assertEqual(generated_data.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(loss,
feed_dict={
@@ -516,7 +516,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
self._kwargs['discriminator_scope'],
target=2.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(
loss,
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
index a559bbfa11..25d74a8c23 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
@@ -118,7 +118,7 @@ def add_loss_consistency_test(test_class, loss_name_str, loss_args):
def consistency_test(self):
self.assertEqual(arg_loss.__name__, tuple_loss.__name__)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(arg_loss(**loss_args).eval(),
tuple_loss(_tuple_from_dict(loss_args)).eval())
@@ -241,7 +241,7 @@ class StarGANLossWrapperTest(test.TestCase):
self.discriminator_generated_data_source_predication)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
@@ -257,7 +257,7 @@ class StarGANLossWrapperTest(test.TestCase):
self.discriminator_generated_data_source_predication)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
@@ -282,7 +282,7 @@ class StarGANLossWrapperTest(test.TestCase):
discriminator_scope=self.discriminator_scope)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])