aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/timeseries/examples/lstm.py26
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py
index c7193cef69..c834430b95 100644
--- a/tensorflow/contrib/timeseries/examples/lstm.py
+++ b/tensorflow/contrib/timeseries/examples/lstm.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
from os import path
import numpy
@@ -80,18 +81,19 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
input_statistics: A math_utils.InputStatistics object.
"""
super(_LSTMModel, self).initialize_graph(input_statistics=input_statistics)
- self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units)
- # Create templates so we don't have to worry about variable reuse.
- self._lstm_cell_run = tf.make_template(
- name_="lstm_cell",
- func_=self._lstm_cell,
- create_scope_now_=True)
- # Transforms LSTM output into mean predictions.
- self._predict_from_lstm_output = tf.make_template(
- name_="predict_from_lstm_output",
- func_=
- lambda inputs: tf.layers.dense(inputs=inputs, units=self.num_features),
- create_scope_now_=True)
+ with tf.variable_scope("", use_resource=True):
+ # Use ResourceVariables to avoid race conditions.
+ self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units)
+ # Create templates so we don't have to worry about variable reuse.
+ self._lstm_cell_run = tf.make_template(
+ name_="lstm_cell",
+ func_=self._lstm_cell,
+ create_scope_now_=True)
+ # Transforms LSTM output into mean predictions.
+ self._predict_from_lstm_output = tf.make_template(
+ name_="predict_from_lstm_output",
+ func_=functools.partial(tf.layers.dense, units=self.num_features),
+ create_scope_now_=True)
def get_start_state(self):
"""Return initial state for the time series model."""