diff options
Diffstat (limited to 'tensorflow/contrib/timeseries/python/timeseries/head_test.py')
-rw-r--r-- | tensorflow/contrib/timeseries/python/timeseries/head_test.py | 56 |
1 files changed, 41 insertions, 15 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index ed8f29c321..78c2cec21c 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + +from absl.testing import parameterized import numpy import six @@ -317,10 +320,38 @@ class PredictFeatureCheckingTests(test.TestCase): mode=estimator_lib.ModeKeys.PREDICT) -class OneShotTests(test.TestCase): - - def test_one_shot_prediction_head_export(self): - model_dir = self.get_temp_dir() +def _custom_time_series_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.TimeSeriesRegressor( + model=lstm_example._LSTMModel( + num_features=5, num_units=128, + exogenous_feature_columns=exogenous_feature_columns), + optimizer=adam.AdamOptimizer(0.001), + config=estimator_lib.RunConfig(tf_random_seed=4), + state_manager=state_management.ChainingStateManager(), + head_type=head_type, + model_dir=model_dir) + + +def _structural_ensemble_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.StructuralEnsembleRegressor( + periodicities=None, + num_features=5, + exogenous_feature_columns=exogenous_feature_columns, + head_type=head_type, + model_dir=model_dir) + + +class OneShotTests(parameterized.TestCase): + + @parameterized.named_parameters( + {"testcase_name": "custom_time_series_regressor", + "estimator_factory": _custom_time_series_regressor}, + {"testcase_name": "structural_ensemble_regressor", + "estimator_factory": _structural_ensemble_regressor}) + def test_one_shot_prediction_head_export(self, estimator_factory): + model_dir = os.path.join(test.get_temp_dir(), str(ops.uid())) categorical_column = feature_column.categorical_column_with_hash_bucket( key="categorical_exogenous_feature", hash_bucket_size=16) exogenous_feature_columns = [ @@ -328,15 +359,10 @@ class OneShotTests(test.TestCase): "2d_exogenous_feature", shape=(2,)), feature_column.embedding_column( categorical_column=categorical_column, dimension=10)] - estimator = ts_estimators.TimeSeriesRegressor( - model=lstm_example._LSTMModel( - num_features=5, num_units=128, - exogenous_feature_columns=exogenous_feature_columns), - optimizer=adam.AdamOptimizer(0.001), - config=estimator_lib.RunConfig(tf_random_seed=4), - state_manager=state_management.ChainingStateManager(), - head_type=ts_head_lib.OneShotPredictionHead, - model_dir=model_dir) + estimator = estimator_factory( + model_dir=model_dir, + exogenous_feature_columns=exogenous_feature_columns, + head_type=ts_head_lib.OneShotPredictionHead) train_features = { feature_keys.TrainEvalFeatures.TIMES: numpy.arange( 20, dtype=numpy.int64), @@ -351,7 +377,7 @@ class OneShotTests(test.TestCase): num_threads=1, batch_size=16, window_size=16) estimator.train(input_fn=train_input_fn, steps=5) input_receiver_fn = estimator.build_raw_serving_input_receiver_fn() - export_location = estimator.export_savedmodel(self.get_temp_dir(), + export_location = estimator.export_savedmodel(test.get_temp_dir(), input_receiver_fn) graph = ops.Graph() with graph.as_default(): @@ -385,7 +411,7 @@ class OneShotTests(test.TestCase): for output_key, output_value in predict_signature.outputs.items()} output = session.run(fetches, feed_dict=feeds) - self.assertAllEqual((2, 15, 5), output["mean"].shape) + self.assertEqual((2, 15, 5), output["mean"].shape) if __name__ == "__main__": |