diff options
author | Allen Lavoie <allenl@google.com> | 2018-01-29 09:44:35 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-29 09:48:39 -0800 |
commit | 219e22879ba981aa33fbe8f54a550cce56bc5d90 (patch) | |
tree | e360567add3e037d85ead98f41d22dff0a18d074 | |
parent | b9dae47061d4d4c9b8f8a79e73519525413ab84c (diff) |
TFTS: Remove a race condition in lstm_test (switch to resource variables)
PiperOrigin-RevId: 183679060
-rw-r--r-- | tensorflow/contrib/timeseries/examples/lstm.py | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py index c7193cef69..c834430b95 100644 --- a/tensorflow/contrib/timeseries/examples/lstm.py +++ b/tensorflow/contrib/timeseries/examples/lstm.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools from os import path import numpy @@ -80,18 +81,19 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): input_statistics: A math_utils.InputStatistics object. """ super(_LSTMModel, self).initialize_graph(input_statistics=input_statistics) - self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units) - # Create templates so we don't have to worry about variable reuse. - self._lstm_cell_run = tf.make_template( - name_="lstm_cell", - func_=self._lstm_cell, - create_scope_now_=True) - # Transforms LSTM output into mean predictions. - self._predict_from_lstm_output = tf.make_template( - name_="predict_from_lstm_output", - func_= - lambda inputs: tf.layers.dense(inputs=inputs, units=self.num_features), - create_scope_now_=True) + with tf.variable_scope("", use_resource=True): + # Use ResourceVariables to avoid race conditions. + self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units) + # Create templates so we don't have to worry about variable reuse. + self._lstm_cell_run = tf.make_template( + name_="lstm_cell", + func_=self._lstm_cell, + create_scope_now_=True) + # Transforms LSTM output into mean predictions. + self._predict_from_lstm_output = tf.make_template( + name_="predict_from_lstm_output", + func_=functools.partial(tf.layers.dense, units=self.num_features), + create_scope_now_=True) def get_start_state(self): """Return initial state for the time series model.""" |