aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/ops/rnn_cell.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/rnn_cell.py')
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index dce71c393a..a6c2d9cdbb 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -424,8 +424,9 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
"W_O_diag", shape=[self._num_units], dtype=dtype)
# initialize the first freq state to be zero
- m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units],
- dtype)
+ m_prev_freq = array_ops.zeros(
+ [inputs.shape[0].value or inputs.get_shape()[0], self._num_units],
+ dtype)
for fq in range(len(freq_inputs)):
c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
[-1, self._num_units])