aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py')
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py22
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
index c2eaa78493..80126ac786 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
@@ -96,7 +96,7 @@ class ConstructionTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -114,7 +114,7 @@ class ConstructionTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -144,7 +144,7 @@ class GapTests(test.TestCase):
state=math_utils.replicate_state(
start_state=random_model.get_start_state(),
batch_size=array_ops.shape(times)[0]))
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -250,7 +250,7 @@ class StateSpaceEquivalenceTests(test.TestCase):
self.assertAllClose(combined_value, split_predict[prediction_key])
def _equivalent_to_single_model_test_template(self, model_generator):
- with self.test_session() as session:
+ with self.cached_session() as session:
random_model = RandomStateSpaceModel(
state_dimension=5,
state_noise_dimension=4,
@@ -374,7 +374,7 @@ class PredictionTests(test.TestCase):
math_utils.replicate_state(
start_state=random_model.get_start_state(), batch_size=1)
})
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
predicted_mean = prediction_dict["mean"].eval()
predicted_covariance = prediction_dict["covariance"].eval()
@@ -404,7 +404,7 @@ class PredictionTests(test.TestCase):
feature_keys.PredictionFeatures.TIMES: [[5, 7, 8]],
feature_keys.PredictionFeatures.STATE_TUPLE: model_outputs.end_state
})
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
predicted_mean = predictions["mean"].eval()
predicted_covariance = predictions["covariance"].eval()
@@ -428,7 +428,7 @@ class ExogenousTests(test.TestCase):
state=[
array_ops.ones(shape=[1, 5]), original_covariance[None], [0]
])
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
evaled_new_covariance, evaled_original_covariance = session.run(
[new_covariance[0], original_covariance])
@@ -454,7 +454,7 @@ class ExogenousTests(test.TestCase):
-array_ops.ones(shape=[1, 5], dtype=dtype),
original_covariance[None], [0]
])
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
evaled_new_covariance, evaled_original_covariance = session.run(
[new_covariance[0], original_covariance])
@@ -519,7 +519,7 @@ class PosteriorTests(test.TestCase):
model=stub_model, data=data, true_parameters=true_params)
def test_exact_posterior_recovery_no_transition_noise(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
stub_model, data, true_params = self._get_single_model()
input_fn = input_pipeline.WholeDatasetInputFn(
input_pipeline.NumpyReader(data))
@@ -559,7 +559,7 @@ class PosteriorTests(test.TestCase):
posterior_times)
def test_chained_exact_posterior_recovery_no_transition_noise(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
stub_model, data, true_params = self._get_single_model()
chunk_size = 10
input_fn = test_utils.AllWindowInputFn(
@@ -748,7 +748,7 @@ class MultivariateTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()