diff options
-rw-r--r-- | tensorflow/contrib/timeseries/examples/lstm.py | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py index c7193cef69..c834430b95 100644 --- a/tensorflow/contrib/timeseries/examples/lstm.py +++ b/tensorflow/contrib/timeseries/examples/lstm.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools from os import path import numpy @@ -80,18 +81,19 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): input_statistics: A math_utils.InputStatistics object. """ super(_LSTMModel, self).initialize_graph(input_statistics=input_statistics) - self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units) - # Create templates so we don't have to worry about variable reuse. - self._lstm_cell_run = tf.make_template( - name_="lstm_cell", - func_=self._lstm_cell, - create_scope_now_=True) - # Transforms LSTM output into mean predictions. - self._predict_from_lstm_output = tf.make_template( - name_="predict_from_lstm_output", - func_= - lambda inputs: tf.layers.dense(inputs=inputs, units=self.num_features), - create_scope_now_=True) + with tf.variable_scope("", use_resource=True): + # Use ResourceVariables to avoid race conditions. + self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units) + # Create templates so we don't have to worry about variable reuse. + self._lstm_cell_run = tf.make_template( + name_="lstm_cell", + func_=self._lstm_cell, + create_scope_now_=True) + # Transforms LSTM output into mean predictions. + self._predict_from_lstm_output = tf.make_template( + name_="predict_from_lstm_output", + func_=functools.partial(tf.layers.dense, units=self.num_features), + create_scope_now_=True) def get_start_state(self): """Return initial state for the time series model.""" |