diff options
author | 2018-09-10 14:37:06 -0700 | |
---|---|---|
committer | 2018-09-10 15:04:14 -0700 | |
commit | b828f89263e054bfa7c7a808cab1506834ab906d (patch) | |
tree | e31816a6850d177306f19ee8670e0836060fcfc9 /tensorflow/contrib/gan | |
parent | acf0ee82092727afc2067316982407cf5e496f75 (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 212336464
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/losses/python/losses_impl_test.py | 52 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py | 8 |
2 files changed, 30 insertions, 30 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 9f5fee4542..e3c780ac1a 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -51,7 +51,7 @@ class _LossesTest(object): loss = self._g_loss_fn(self._discriminator_gen_outputs) self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) self.assertEqual(self._generator_loss_name, loss.op.name) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) def test_discriminator_all_correct(self): @@ -59,7 +59,7 @@ class _LossesTest(object): self._discriminator_real_outputs, self._discriminator_gen_outputs) self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) self.assertEqual(self._discriminator_loss_name, loss.op.name) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) def test_generator_loss_collection(self): @@ -90,7 +90,7 @@ class _LossesTest(object): loss = self._g_loss_fn( array_ops.reshape(self._discriminator_gen_outputs, [2, 2])) self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) def test_discriminator_patch(self): @@ -98,7 +98,7 @@ class _LossesTest(object): array_ops.reshape(self._discriminator_real_outputs, [2, 2]), array_ops.reshape(self._discriminator_gen_outputs, [2, 2])) self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) def test_generator_loss_with_placeholder_for_logits(self): @@ -108,7 +108,7 @@ class _LossesTest(object): loss = self._g_loss_fn(logits, weights=weights) self.assertEqual(logits.dtype, loss.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={ logits: [[10.0, 4.4, -5.5, 3.6]], @@ -125,7 +125,7 @@ class _LossesTest(object): logits, logits2, real_weights=real_weights, generated_weights=generated_weights) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={ logits: [self._discriminator_real_outputs_np], @@ -136,7 +136,7 @@ class _LossesTest(object): def test_generator_with_python_scalar_weight(self): loss = self._g_loss_fn( self._discriminator_gen_outputs, weights=self._weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss * self._weights, loss.eval(), 4) @@ -144,14 +144,14 @@ class _LossesTest(object): loss = self._d_loss_fn( self._discriminator_real_outputs, self._discriminator_gen_outputs, real_weights=self._weights, generated_weights=self._weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss * self._weights, loss.eval(), 4) def test_generator_with_scalar_tensor_weight(self): loss = self._g_loss_fn(self._discriminator_gen_outputs, weights=constant_op.constant(self._weights)) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss * self._weights, loss.eval(), 4) @@ -160,7 +160,7 @@ class _LossesTest(object): loss = self._d_loss_fn( self._discriminator_real_outputs, self._discriminator_gen_outputs, real_weights=weights, generated_weights=weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss * self._weights, loss.eval(), 4) @@ -284,7 +284,7 @@ class ACGANLossTest(test.TestCase): self.assertEqual( self._discriminator_gen_classification_logits.dtype, loss.dtype) self.assertEqual(self._generator_loss_name, loss.op.name) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) def test_discriminator_all_correct(self): @@ -292,7 +292,7 @@ class ACGANLossTest(test.TestCase): self.assertEqual( self._discriminator_gen_classification_logits.dtype, loss.dtype) self.assertEqual(self._discriminator_loss_name, loss.op.name) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) def test_generator_loss_collection(self): @@ -319,14 +319,14 @@ class ACGANLossTest(test.TestCase): patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in self._generator_kwargs.items()} loss = self._g_loss_fn(**patch_args) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) def test_discriminator_patch(self): patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in self._discriminator_kwargs.items()} loss = self._d_loss_fn(**patch_args) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) def test_generator_loss_with_placeholder_for_logits(self): @@ -334,7 +334,7 @@ class ACGANLossTest(test.TestCase): one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4)) loss = self._g_loss_fn(gen_logits, one_hot_labels) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run( loss, feed_dict={ gen_logits: self._discriminator_gen_classification_logits_np, @@ -349,7 +349,7 @@ class ACGANLossTest(test.TestCase): loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run( loss, feed_dict={ gen_logits: self._discriminator_gen_classification_logits_np, @@ -360,7 +360,7 @@ class ACGANLossTest(test.TestCase): def test_generator_with_python_scalar_weight(self): loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss * self._weights, loss.eval(), 4) @@ -368,14 +368,14 @@ class ACGANLossTest(test.TestCase): loss = self._d_loss_fn( real_weights=self._weights, generated_weights=self._weights, **self._discriminator_kwargs) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss * self._weights, loss.eval(), 4) def test_generator_with_scalar_tensor_weight(self): loss = self._g_loss_fn( weights=constant_op.constant(self._weights), **self._generator_kwargs) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss * self._weights, loss.eval(), 4) @@ -383,7 +383,7 @@ class ACGANLossTest(test.TestCase): weights = constant_op.constant(self._weights) loss = self._d_loss_fn(real_weights=weights, generated_weights=weights, **self._discriminator_kwargs) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss * self._weights, loss.eval(), 4) @@ -404,7 +404,7 @@ class _PenaltyTest(object): loss = self._penalty_fn(**self._kwargs) self.assertEqual(self._expected_dtype, loss.dtype) self.assertEqual(self._expected_op_name, loss.op.name) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAlmostEqual(self._expected_loss, loss.eval(), 6) @@ -419,13 +419,13 @@ class _PenaltyTest(object): def test_python_scalar_weight(self): loss = self._penalty_fn(weights=2.3, **self._kwargs) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3) def test_scalar_tensor_weight(self): loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3) @@ -472,7 +472,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): self._kwargs['discriminator_scope']) self.assertEqual(generated_data.dtype, loss.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: variables.global_variables_initializer().run() loss = sess.run(loss, feed_dict={ @@ -494,7 +494,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): one_sided=True) self.assertEqual(generated_data.dtype, loss.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: variables.global_variables_initializer().run() loss = sess.run(loss, feed_dict={ @@ -516,7 +516,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): self._kwargs['discriminator_scope'], target=2.0) - with self.test_session() as sess: + with self.cached_session() as sess: variables.global_variables_initializer().run() loss = sess.run( loss, diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py index a559bbfa11..25d74a8c23 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -118,7 +118,7 @@ def add_loss_consistency_test(test_class, loss_name_str, loss_args): def consistency_test(self): self.assertEqual(arg_loss.__name__, tuple_loss.__name__) - with self.test_session(): + with self.cached_session(): self.assertEqual(arg_loss(**loss_args).eval(), tuple_loss(_tuple_from_dict(loss_args)).eval()) @@ -241,7 +241,7 @@ class StarGANLossWrapperTest(test.TestCase): self.discriminator_generated_data_source_predication) wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loss_result, wrapped_loss_result = sess.run( [loss_result_tensor, wrapped_loss_result_tensor]) @@ -257,7 +257,7 @@ class StarGANLossWrapperTest(test.TestCase): self.discriminator_generated_data_source_predication) wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loss_result, wrapped_loss_result = sess.run( [loss_result_tensor, wrapped_loss_result_tensor]) @@ -282,7 +282,7 @@ class StarGANLossWrapperTest(test.TestCase): discriminator_scope=self.discriminator_scope) wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loss_result, wrapped_loss_result = sess.run( [loss_result_tensor, wrapped_loss_result_tensor]) |