diff options
Diffstat (limited to 'tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py')
-rw-r--r-- | tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py index a559bbfa11..25d74a8c23 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -118,7 +118,7 @@ def add_loss_consistency_test(test_class, loss_name_str, loss_args): def consistency_test(self): self.assertEqual(arg_loss.__name__, tuple_loss.__name__) - with self.test_session(): + with self.cached_session(): self.assertEqual(arg_loss(**loss_args).eval(), tuple_loss(_tuple_from_dict(loss_args)).eval()) @@ -241,7 +241,7 @@ class StarGANLossWrapperTest(test.TestCase): self.discriminator_generated_data_source_predication) wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loss_result, wrapped_loss_result = sess.run( [loss_result_tensor, wrapped_loss_result_tensor]) @@ -257,7 +257,7 @@ class StarGANLossWrapperTest(test.TestCase): self.discriminator_generated_data_source_predication) wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loss_result, wrapped_loss_result = sess.run( [loss_result_tensor, wrapped_loss_result_tensor]) @@ -282,7 +282,7 @@ class StarGANLossWrapperTest(test.TestCase): discriminator_scope=self.discriminator_scope) wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loss_result, wrapped_loss_result = sess.run( [loss_result_tensor, wrapped_loss_result_tensor]) |