diff options
author | 2017-12-01 10:15:11 -0800 | |
---|---|---|
committer | 2017-12-01 10:20:35 -0800 | |
commit | cccddc83d527caaeefc86577c234e5dfd13b4979 (patch) | |
tree | 3b89ff5071e8b30ac71249f81b149112b01fb558 | |
parent | 2c78d7bfaf3158df22401c03fc4de2cb99526d4f (diff) |
Seed the time series LSTM example unit test.
PiperOrigin-RevId: 177606245
-rw-r--r-- | tensorflow/contrib/timeseries/examples/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/timeseries/examples/lstm.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/timeseries/examples/lstm_test.py | 11 |
3 files changed, 14 insertions, 3 deletions
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 755b0657e9..bb86ecb220 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -103,6 +103,7 @@ py_test( deps = [ ":lstm", "//tensorflow/python:client_testlib", + "//tensorflow/python/estimator:estimator_py", ], ) diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py index 3ba823f638..c7193cef69 100644 --- a/tensorflow/contrib/timeseries/examples/lstm.py +++ b/tensorflow/contrib/timeseries/examples/lstm.py @@ -165,12 +165,13 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): "Exogenous inputs are not implemented for this example.") -def train_and_predict(csv_file_name=_DATA_FILE, training_steps=200): +def train_and_predict( + csv_file_name=_DATA_FILE, training_steps=200, estimator_config=None): """Train and predict using a custom time series model.""" # Construct an Estimator from our LSTM model. estimator = ts_estimators.TimeSeriesRegressor( model=_LSTMModel(num_features=5, num_units=128), - optimizer=tf.train.AdamOptimizer(0.001)) + optimizer=tf.train.AdamOptimizer(0.001), config=estimator_config) reader = tf.contrib.timeseries.CSVReader( csv_file_name, column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,) diff --git a/tensorflow/contrib/timeseries/examples/lstm_test.py b/tensorflow/contrib/timeseries/examples/lstm_test.py index 56daa1e10d..3cace56726 100644 --- a/tensorflow/contrib/timeseries/examples/lstm_test.py +++ b/tensorflow/contrib/timeseries/examples/lstm_test.py @@ -20,14 +20,23 @@ from __future__ import print_function from tensorflow.contrib.timeseries.examples import lstm +from tensorflow.python.estimator import estimator_lib from tensorflow.python.platform import test +class _SeedRunConfig(estimator_lib.RunConfig): + + @property + def tf_random_seed(self): + return 3 + + class LSTMExampleTest(test.TestCase): def test_periodicity_learned(self): (observed_times, observed_values, - all_times, predicted_values) = lstm.train_and_predict(training_steps=100) + all_times, predicted_values) = lstm.train_and_predict( + training_steps=100, estimator_config=_SeedRunConfig()) self.assertAllEqual([100], observed_times.shape) self.assertAllEqual([100, 5], observed_values.shape) self.assertAllEqual([200], all_times.shape) |