aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-01-29 09:44:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 09:48:39 -0800
commit219e22879ba981aa33fbe8f54a550cce56bc5d90 (patch)
treee360567add3e037d85ead98f41d22dff0a18d074
parentb9dae47061d4d4c9b8f8a79e73519525413ab84c (diff)
TFTS: Remove a race condition in lstm_test (switch to resource variables)
PiperOrigin-RevId: 183679060
-rw-r--r--tensorflow/contrib/timeseries/examples/lstm.py26
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py
index c7193cef69..c834430b95 100644
--- a/tensorflow/contrib/timeseries/examples/lstm.py
+++ b/tensorflow/contrib/timeseries/examples/lstm.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
from os import path
import numpy
@@ -80,18 +81,19 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
input_statistics: A math_utils.InputStatistics object.
"""
super(_LSTMModel, self).initialize_graph(input_statistics=input_statistics)
- self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units)
- # Create templates so we don't have to worry about variable reuse.
- self._lstm_cell_run = tf.make_template(
- name_="lstm_cell",
- func_=self._lstm_cell,
- create_scope_now_=True)
- # Transforms LSTM output into mean predictions.
- self._predict_from_lstm_output = tf.make_template(
- name_="predict_from_lstm_output",
- func_=
- lambda inputs: tf.layers.dense(inputs=inputs, units=self.num_features),
- create_scope_now_=True)
+ with tf.variable_scope("", use_resource=True):
+ # Use ResourceVariables to avoid race conditions.
+ self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units)
+ # Create templates so we don't have to worry about variable reuse.
+ self._lstm_cell_run = tf.make_template(
+ name_="lstm_cell",
+ func_=self._lstm_cell,
+ create_scope_now_=True)
+ # Transforms LSTM output into mean predictions.
+ self._predict_from_lstm_output = tf.make_template(
+ name_="predict_from_lstm_output",
+ func_=functools.partial(tf.layers.dense, units=self.num_features),
+ create_scope_now_=True)
def get_start_state(self):
"""Return initial state for the time series model."""