diff options
Diffstat (limited to 'tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py')
-rw-r--r-- | tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py index c2eaa78493..80126ac786 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py @@ -96,7 +96,7 @@ class ConstructionTests(test.TestCase): }, mode=estimator_lib.ModeKeys.TRAIN) initializer = variables.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([initializer]) outputs.loss.eval() @@ -114,7 +114,7 @@ class ConstructionTests(test.TestCase): }, mode=estimator_lib.ModeKeys.TRAIN) initializer = variables.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([initializer]) outputs.loss.eval() @@ -144,7 +144,7 @@ class GapTests(test.TestCase): state=math_utils.replicate_state( start_state=random_model.get_start_state(), batch_size=array_ops.shape(times)[0])) - with self.test_session() as session: + with self.cached_session() as session: variables.global_variables_initializer().run() coordinator = coordinator_lib.Coordinator() queue_runner_impl.start_queue_runners(session, coord=coordinator) @@ -250,7 +250,7 @@ class StateSpaceEquivalenceTests(test.TestCase): self.assertAllClose(combined_value, split_predict[prediction_key]) def _equivalent_to_single_model_test_template(self, model_generator): - with self.test_session() as session: + with self.cached_session() as session: random_model = RandomStateSpaceModel( state_dimension=5, state_noise_dimension=4, @@ -374,7 +374,7 @@ class PredictionTests(test.TestCase): math_utils.replicate_state( start_state=random_model.get_start_state(), batch_size=1) }) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() predicted_mean = prediction_dict["mean"].eval() predicted_covariance = prediction_dict["covariance"].eval() @@ -404,7 +404,7 @@ class PredictionTests(test.TestCase): feature_keys.PredictionFeatures.TIMES: [[5, 7, 8]], feature_keys.PredictionFeatures.STATE_TUPLE: model_outputs.end_state }) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() predicted_mean = predictions["mean"].eval() predicted_covariance = predictions["covariance"].eval() @@ -428,7 +428,7 @@ class ExogenousTests(test.TestCase): state=[ array_ops.ones(shape=[1, 5]), original_covariance[None], [0] ]) - with self.test_session() as session: + with self.cached_session() as session: variables.global_variables_initializer().run() evaled_new_covariance, evaled_original_covariance = session.run( [new_covariance[0], original_covariance]) @@ -454,7 +454,7 @@ class ExogenousTests(test.TestCase): -array_ops.ones(shape=[1, 5], dtype=dtype), original_covariance[None], [0] ]) - with self.test_session() as session: + with self.cached_session() as session: variables.global_variables_initializer().run() evaled_new_covariance, evaled_original_covariance = session.run( [new_covariance[0], original_covariance]) @@ -519,7 +519,7 @@ class PosteriorTests(test.TestCase): model=stub_model, data=data, true_parameters=true_params) def test_exact_posterior_recovery_no_transition_noise(self): - with self.test_session() as session: + with self.cached_session() as session: stub_model, data, true_params = self._get_single_model() input_fn = input_pipeline.WholeDatasetInputFn( input_pipeline.NumpyReader(data)) @@ -559,7 +559,7 @@ class PosteriorTests(test.TestCase): posterior_times) def test_chained_exact_posterior_recovery_no_transition_noise(self): - with self.test_session() as session: + with self.cached_session() as session: stub_model, data, true_params = self._get_single_model() chunk_size = 10 input_fn = test_utils.AllWindowInputFn( @@ -748,7 +748,7 @@ class MultivariateTests(test.TestCase): }, mode=estimator_lib.ModeKeys.TRAIN) initializer = variables.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([initializer]) outputs.loss.eval() |