diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/rnn_cell.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index dce71c393a..a6c2d9cdbb 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -424,8 +424,9 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell): "W_O_diag", shape=[self._num_units], dtype=dtype) # initialize the first freq state to be zero - m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units], - dtype) + m_prev_freq = array_ops.zeros( + [inputs.shape[0].value or inputs.get_shape()[0], self._num_units], + dtype) for fq in range(len(freq_inputs)): c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units], [-1, self._num_units]) |