aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2017-12-01 10:15:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-01 10:20:35 -0800
commitcccddc83d527caaeefc86577c234e5dfd13b4979 (patch)
tree3b89ff5071e8b30ac71249f81b149112b01fb558
parent2c78d7bfaf3158df22401c03fc4de2cb99526d4f (diff)
Seed the time series LSTM example unit test.
PiperOrigin-RevId: 177606245
-rw-r--r--tensorflow/contrib/timeseries/examples/BUILD1
-rw-r--r--tensorflow/contrib/timeseries/examples/lstm.py5
-rw-r--r--tensorflow/contrib/timeseries/examples/lstm_test.py11
3 files changed, 14 insertions, 3 deletions
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD
index 755b0657e9..bb86ecb220 100644
--- a/tensorflow/contrib/timeseries/examples/BUILD
+++ b/tensorflow/contrib/timeseries/examples/BUILD
@@ -103,6 +103,7 @@ py_test(
deps = [
":lstm",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/estimator:estimator_py",
],
)
diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py
index 3ba823f638..c7193cef69 100644
--- a/tensorflow/contrib/timeseries/examples/lstm.py
+++ b/tensorflow/contrib/timeseries/examples/lstm.py
@@ -165,12 +165,13 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
"Exogenous inputs are not implemented for this example.")
-def train_and_predict(csv_file_name=_DATA_FILE, training_steps=200):
+def train_and_predict(
+ csv_file_name=_DATA_FILE, training_steps=200, estimator_config=None):
"""Train and predict using a custom time series model."""
# Construct an Estimator from our LSTM model.
estimator = ts_estimators.TimeSeriesRegressor(
model=_LSTMModel(num_features=5, num_units=128),
- optimizer=tf.train.AdamOptimizer(0.001))
+ optimizer=tf.train.AdamOptimizer(0.001), config=estimator_config)
reader = tf.contrib.timeseries.CSVReader(
csv_file_name,
column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,)
diff --git a/tensorflow/contrib/timeseries/examples/lstm_test.py b/tensorflow/contrib/timeseries/examples/lstm_test.py
index 56daa1e10d..3cace56726 100644
--- a/tensorflow/contrib/timeseries/examples/lstm_test.py
+++ b/tensorflow/contrib/timeseries/examples/lstm_test.py
@@ -20,14 +20,23 @@ from __future__ import print_function
from tensorflow.contrib.timeseries.examples import lstm
+from tensorflow.python.estimator import estimator_lib
from tensorflow.python.platform import test
+class _SeedRunConfig(estimator_lib.RunConfig):
+
+ @property
+ def tf_random_seed(self):
+ return 3
+
+
class LSTMExampleTest(test.TestCase):
def test_periodicity_learned(self):
(observed_times, observed_values,
- all_times, predicted_values) = lstm.train_and_predict(training_steps=100)
+ all_times, predicted_values) = lstm.train_and_predict(
+ training_steps=100, estimator_config=_SeedRunConfig())
self.assertAllEqual([100], observed_times.shape)
self.assertAllEqual([100, 5], observed_values.shape)
self.assertAllEqual([200], all_times.shape)