aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py')
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py
index ff6092fc26..faff42d243 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py
@@ -35,7 +35,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
test.TestCase):
def testSampleAndLogProbUnivariateShapes(self):
- with self.test_session():
+ with self.cached_session():
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=normal_lib.Normal(
@@ -46,7 +46,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
self.assertEqual([4, 5], log_prob_x.shape)
def testSampleAndLogProbBatch(self):
- with self.test_session():
+ with self.cached_session():
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[[0.3, 0.7]]),
components_distribution=normal_lib.Normal(
@@ -59,7 +59,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
def testSampleAndLogProbShapesBroadcastMix(self):
mix_probs = np.float32([.3, .7])
bern_probs = np.float32([[.4, .6], [.25, .75]])
- with self.test_session():
+ with self.cached_session():
bm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=mix_probs),
components_distribution=bernoulli_lib.Bernoulli(probs=bern_probs))
@@ -72,7 +72,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
np.ones_like(x_, dtype=np.bool), np.logical_or(x_ == 0., x_ == 1.))
def testSampleAndLogProbMultivariateShapes(self):
- with self.test_session():
+ with self.cached_session():
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
@@ -83,7 +83,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
self.assertEqual([4, 5], log_prob_x.shape)
def testSampleAndLogProbBatchMultivariateShapes(self):
- with self.test_session():
+ with self.cached_session():
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
@@ -98,7 +98,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
self.assertEqual([4, 5, 2], log_prob_x.shape)
def testSampleConsistentLogProb(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
@@ -111,7 +111,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
sess.run, gm, radius=1., center=[1., -1], rtol=0.02)
def testLogCdf(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=normal_lib.Normal(
@@ -128,7 +128,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
rtol=1e-6, atol=0.0)
def testSampleConsistentMeanCovariance(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
@@ -136,7 +136,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
self.run_test_sample_consistent_mean_covariance(sess.run, gm)
def testVarianceConsistentCovariance(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(