aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/rnn_cell_impl.py
diff options
context:
space:
mode:
authorGravatar Patrick Nguyen <drpng@google.com>2018-05-01 14:28:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-01 14:33:20 -0700
commit325d0ef21a48bea1cc618a2bd24a9776de417ce5 (patch)
treed41cf6304071e95bebd5747ca87dfca571e98634 /tensorflow/python/ops/rnn_cell_impl.py
parent46bf1e8934b3bc8edeff3f218a50b0ee5806e96b (diff)
Merge changes from github.
PiperOrigin-RevId: 194997009
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 86dc053c0f..67f753485b 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -785,10 +785,14 @@ class LSTMCell(LayerRNNCell):
shape=[input_depth + h_depth, 4 * self._num_units],
initializer=self._initializer,
partitioner=maybe_partitioner)
+ if self.dtype is None:
+ initializer = init_ops.zeros_initializer
+ else:
+ initializer = init_ops.zeros_initializer(dtype=self.dtype)
self._bias = self.add_variable(
_BIAS_VARIABLE_NAME,
shape=[4 * self._num_units],
- initializer=init_ops.zeros_initializer(dtype=self.dtype))
+ initializer=initializer)
if self._use_peepholes:
self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
initializer=self._initializer)