aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py')
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
index a559bbfa11..25d74a8c23 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
@@ -118,7 +118,7 @@ def add_loss_consistency_test(test_class, loss_name_str, loss_args):
def consistency_test(self):
self.assertEqual(arg_loss.__name__, tuple_loss.__name__)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(arg_loss(**loss_args).eval(),
tuple_loss(_tuple_from_dict(loss_args)).eval())
@@ -241,7 +241,7 @@ class StarGANLossWrapperTest(test.TestCase):
self.discriminator_generated_data_source_predication)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
@@ -257,7 +257,7 @@ class StarGANLossWrapperTest(test.TestCase):
self.discriminator_generated_data_source_predication)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
@@ -282,7 +282,7 @@ class StarGANLossWrapperTest(test.TestCase):
discriminator_scope=self.discriminator_scope)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])