diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-29 12:21:31 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-29 13:32:40 -0700 |
commit | b7d5df182b7394ab17c11ccc949ce07812920bd9 (patch) | |
tree | ab244238a4c7fdb099a62ca8c2396fd8e15216c3 /tensorflow/contrib/rnn/python/ops/lstm_ops.py | |
parent | 4323a658b5228fe8d5482941edfacf58506dea34 (diff) |
Make (tf.contrib) BlockLSTMOp take 3D tensors instead of lists of 2D tensors.
This facilitates dealing with dynamic time lengths.
Updated documentation.
Change: 134699973
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/lstm_ops.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/lstm_ops.py | 147 |
1 files changed, 73 insertions, 74 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 7d863d17dc..4979856668 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """LSTM Block Cell ops.""" from __future__ import absolute_import from __future__ import division @@ -50,8 +49,8 @@ def _lstm_block_cell(x, name=None): r"""Computes the LSTM cell forward propagation for 1 time step. - This implementation uses 1 weight matrix and 1 bias vector, there is no - diagonal peephole connection. + This implementation uses 1 weight matrix and 1 bias vector, and there's an + optional peephole connection. This kernel op implements the following mathematical equations: @@ -60,30 +59,41 @@ def _lstm_block_cell(x, [i, f, ci, o] = xh * w + b f = f + forget_bias - i = sigmoid(i) - f = sigmoid(f) + if not use_peephole: + wci = wcf = wco = 0 + + i = sigmoid(cs_prev * wci + i) + f = sigmoid(cs_prev * wcf + f) ci = tanh(ci) - o = sigmoid(o) cs = ci .* i + cs_prev .* f - co = tanh(cs) + cs = clip(cs, cell_clip) + o = sigmoid(cs * wco + f) + co = tanh(cs) h = co .* o ``` Args: - x: A `Tensor`. Must be one of the following types: `float32`, `float64`. - The input to the LSTM cell. + x: A `Tensor`. Must be one of the following types: `float32`. + The input to the LSTM cell, shape (batch_size, num_inputs). cs_prev: A `Tensor`. Must have the same type as `x`. + Value of the cell state at previous time step. h_prev: A `Tensor`. Must have the same type as `x`. + Output of the previous cell at previous time step. w: A `Tensor`. Must have the same type as `x`. The weight matrix. b: A `Tensor`. Must have the same type as `x`. The bias vector. wci: A `Tensor`. Must have the same type as `x`. + The weight matrix for input gate peephole connection. wcf: A `Tensor`. Must have the same type as `x`. + The weight matrix for forget gate peephole connection. wco: A `Tensor`. Must have the same type as `x`. + The weight matrix for output gate peephole connection. forget_bias: An optional `float`. Defaults to `1`. The forget gate bias. cell_clip: An optional `float`. Defaults to `3`. + Value to clip the 'cs' value to. use_peephole: An optional `bool`. Defaults to `False`. + Whether to use peephole weights. name: A name for the operation (optional). Returns: @@ -108,18 +118,19 @@ def _lstm_block_cell(x, wcf = wci # pylint: disable=protected-access - return _lstm_ops_so.lstm_block_cell(x=x, - cs_prev=cs_prev, - h_prev=h_prev, - w=w, - wci=wci, - wco=wco, - wcf=wcf, - b=b, - forget_bias=forget_bias, - cell_clip=cell_clip, - use_peephole=use_peephole, - name=name) + return _lstm_ops_so.lstm_block_cell( + x=x, + cs_prev=cs_prev, + h_prev=h_prev, + w=w, + wci=wci, + wco=wco, + wcf=wcf, + b=b, + forget_bias=forget_bias, + cell_clip=cell_clip, + use_peephole=use_peephole, + name=name) # pylint: enable=protected-access @@ -180,9 +191,8 @@ def _block_lstm(seq_len_max, cell_size = cell_size4 / 4 zero_state = None if cs_prev is None or h_prev is None: - zero_state = array_ops.constant(0, - dtype=dtypes.float32, - shape=[batch_size, cell_size]) + zero_state = array_ops.constant( + 0, dtype=dtypes.float32, shape=[batch_size, cell_size]) if cs_prev is None: cs_prev = zero_state if h_prev is None: @@ -193,26 +203,30 @@ def _block_lstm(seq_len_max, wcf = wci # pylint: disable=protected-access - return _lstm_ops_so.block_lstm(seq_len_max=seq_len_max, - x=x, - cs_prev=cs_prev, - h_prev=h_prev, - w=w, - wci=wci, - wco=wco, - wcf=wcf, - b=b, - forget_bias=forget_bias, - cell_clip=cell_clip, - name=name, - use_peephole=use_peephole) + i, cs, f, o, ci, co, h = _lstm_ops_so.block_lstm( + seq_len_max=seq_len_max, + x=array_ops.pack(x), + cs_prev=cs_prev, + h_prev=h_prev, + w=w, + wci=wci, + wco=wco, + wcf=wcf, + b=b, + forget_bias=forget_bias, + cell_clip=cell_clip, + name=name, + use_peephole=use_peephole) + + return array_ops.unpack(i), array_ops.unpack(cs), array_ops.unpack( + f), array_ops.unpack(o), array_ops.unpack(ci), array_ops.unpack( + co), array_ops.unpack(h) # pylint: enable=protected-access # pylint: enable=invalid-name _lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"] - ops.RegisterShape("LSTMBlockCell")(common_shapes.call_cpp_shape_fn) @@ -283,28 +297,11 @@ ops.RegisterShape("BlockLSTM")(common_shapes.call_cpp_shape_fn) @ops.RegisterGradient("BlockLSTM") def _BlockLSTMGrad(op, *grad): """Gradient for BlockLSTM.""" - max_len = op.get_attr("max_len") - - seq_len_max = op.inputs[0] - x = op.inputs[1:1 + max_len] - cs_prev = op.inputs[-7] - h_prev = op.inputs[-6] - w = op.inputs[-5] - wci = op.inputs[-4] - wco = op.inputs[-3] - wcf = op.inputs[-2] - b = op.inputs[-1] - - i = op.outputs[0 * max_len:1 * max_len] - cs = op.outputs[1 * max_len:2 * max_len] - f = op.outputs[2 * max_len:3 * max_len] - o = op.outputs[3 * max_len:4 * max_len] - ci = op.outputs[4 * max_len:5 * max_len] - co = op.outputs[5 * max_len:6 * max_len] - h = op.outputs[6 * max_len:7 * max_len] - - cs_grad = grad[-max_len * 2:-max_len] - h_grad = grad[-max_len:] + seq_len_max, x, cs_prev, h_prev, w, wci, wco, wcf, b = op.inputs + i, cs, f, o, ci, co, h = op.outputs + + cs_grad = grad[1] + h_grad = grad[6] (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, wcf_grad, b_grad) = _lstm_ops_so.block_lstm_grad( @@ -328,8 +325,8 @@ def _BlockLSTMGrad(op, *grad): h_grad, use_peephole=op.get_attr("use_peephole")) - return [None] + x_grad + [cs_prev_grad, h_prev_grad, w_grad, wci_grad, - wco_grad, wcf_grad, b_grad] + return [None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, + wcf_grad, b_grad] ops.RegisterShape("BlockLSTMGrad")(common_shapes.call_cpp_shape_fn) @@ -379,21 +376,23 @@ class LSTMBlockCell(rnn_cell.RNNCell): input_size = x_shape[1] w = vs.get_variable("W", [input_size + self._num_units, self._num_units * 4]) - b = vs.get_variable("b", [w.get_shape().with_rank(2)[1]], - initializer=init_ops.constant_initializer(0.0)) + b = vs.get_variable( + "b", [w.get_shape().with_rank(2)[1]], + initializer=init_ops.constant_initializer(0.0)) wci = vs.get_variable("wci", [self._num_units]) wco = vs.get_variable("wco", [self._num_units]) wcf = vs.get_variable("wcf", [self._num_units]) (cs_prev, h_prev) = states_prev - (_, cs, _, _, _, _, h) = _lstm_block_cell(x, - cs_prev, - h_prev, - w, - b, - wci=wci, - wco=wco, - wcf=wcf, - forget_bias=self._forget_bias, - use_peephole=self._use_peephole) + (_, cs, _, _, _, _, h) = _lstm_block_cell( + x, + cs_prev, + h_prev, + w, + b, + wci=wci, + wco=wco, + wcf=wcf, + forget_bias=self._forget_bias, + use_peephole=self._use_peephole) return (h, (cs, h)) |