path: root/tensorflow/contrib/rnn/python/ops/lstm_ops.py
author A. Unique TensorFlower <gardener@tensorflow.org> 2016-09-29 12:21:31 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-09-29 13:32:40 -0700
commit b7d5df182b7394ab17c11ccc949ce07812920bd9 (patch)
tree ab244238a4c7fdb099a62ca8c2396fd8e15216c3 /tensorflow/contrib/rnn/python/ops/lstm_ops.py
parent 4323a658b5228fe8d5482941edfacf58506dea34 (diff)
Make (tf.contrib) BlockLSTMOp take 3D tensors instead of lists of 2D tensors.
This facilitates dealing with dynamic time lengths. Updated documentation. Change: 134699973
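In practice, the change means callers hand the op a single 3-D tensor of shape (timelen, batch_size, num_inputs) instead of a Python list of timelen 2-D (batch_size, num_inputs) tensors. A minimal sketch of the shape difference, assuming the TF-0.x `tf.pack`/`tf.unpack` helpers that the wrapper in this diff uses (names and shapes are illustrative):

```python
import tensorflow as tf

timelen, batch_size, num_inputs = 8, 32, 16

# Old interface: a Python list of `timelen` 2-D tensors, each (batch_size, num_inputs).
x_list = [tf.placeholder(tf.float32, [batch_size, num_inputs])
          for _ in range(timelen)]

# New interface: one 3-D tensor (timelen, batch_size, num_inputs).
# tf.pack / tf.unpack (later renamed tf.stack / tf.unstack) convert between the two.
x_3d = tf.pack(x_list)     # shape (timelen, batch_size, num_inputs)
x_back = tf.unpack(x_3d)   # back to a list of timelen 2-D tensors
```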
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/lstm_ops.py')
-rw-r--r-- tensorflow/contrib/rnn/python/ops/lstm_ops.py | 147
1 file changed, 73 insertions, 74 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 7d863d17dc..4979856668 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""LSTM Block Cell ops."""
from __future__ import absolute_import
from __future__ import division
@@ -50,8 +49,8 @@ def _lstm_block_cell(x,
name=None):
r"""Computes the LSTM cell forward propagation for 1 time step.
- This implementation uses 1 weight matrix and 1 bias vector, there is no
- diagonal peephole connection.
+ This implementation uses 1 weight matrix and 1 bias vector, and there's an
+ optional peephole connection.
This kernel op implements the following mathematical equations:
@@ -60,30 +59,41 @@ def _lstm_block_cell(x,
[i, f, ci, o] = xh * w + b
f = f + forget_bias
- i = sigmoid(i)
- f = sigmoid(f)
+ if not use_peephole:
+ wci = wcf = wco = 0
+
+ i = sigmoid(cs_prev * wci + i)
+ f = sigmoid(cs_prev * wcf + f)
ci = tanh(ci)
- o = sigmoid(o)
cs = ci .* i + cs_prev .* f
- co = tanh(cs)
+ cs = clip(cs, cell_clip)
+ o = sigmoid(cs * wco + f)
+ co = tanh(cs)
h = co .* o
```
Args:
- x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
- The input to the LSTM cell.
+ x: A `Tensor`. Must be one of the following types: `float32`.
+ The input to the LSTM cell, shape (batch_size, num_inputs).
cs_prev: A `Tensor`. Must have the same type as `x`.
+ Value of the cell state at previous time step.
h_prev: A `Tensor`. Must have the same type as `x`.
+ Output of the previous cell at previous time step.
w: A `Tensor`. Must have the same type as `x`. The weight matrix.
b: A `Tensor`. Must have the same type as `x`. The bias vector.
wci: A `Tensor`. Must have the same type as `x`.
+ The weight matrix for input gate peephole connection.
wcf: A `Tensor`. Must have the same type as `x`.
+ The weight matrix for forget gate peephole connection.
wco: A `Tensor`. Must have the same type as `x`.
+ The weight matrix for output gate peephole connection.
forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
cell_clip: An optional `float`. Defaults to `3`.
+ Value to clip the 'cs' value to.
use_peephole: An optional `bool`. Defaults to `False`.
+ Whether to use peephole weights.
name: A name for the operation (optional).
Returns:
@@ -108,18 +118,19 @@ def _lstm_block_cell(x,
wcf = wci
# pylint: disable=protected-access
- return _lstm_ops_so.lstm_block_cell(x=x,
- cs_prev=cs_prev,
- h_prev=h_prev,
- w=w,
- wci=wci,
- wco=wco,
- wcf=wcf,
- b=b,
- forget_bias=forget_bias,
- cell_clip=cell_clip,
- use_peephole=use_peephole,
- name=name)
+ return _lstm_ops_so.lstm_block_cell(
+ x=x,
+ cs_prev=cs_prev,
+ h_prev=h_prev,
+ w=w,
+ wci=wci,
+ wco=wco,
+ wcf=wcf,
+ b=b,
+ forget_bias=forget_bias,
+ cell_clip=cell_clip,
+ use_peephole=use_peephole,
+ name=name)
# pylint: enable=protected-access
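
To make the forward pass concrete, here is a minimal NumPy sketch of one block-cell step in the spirit of the docstring equations above, shown for the non-peephole case (wci = wcf = wco = 0) and with the output gate using its own pre-activation, per the standard formulation. Shapes and names are illustrative; this is not the kernel implementation:

```python
import numpy as np

def sigmoid(z):
  return 1.0 / (1.0 + np.exp(-z))

batch_size, num_inputs, cell_size = 2, 3, 4
rng = np.random.RandomState(0)

x = rng.randn(batch_size, num_inputs)
h_prev = rng.randn(batch_size, cell_size)
cs_prev = rng.randn(batch_size, cell_size)
w = rng.randn(num_inputs + cell_size, 4 * cell_size)   # single weight matrix
b = np.zeros(4 * cell_size)                            # single bias vector
forget_bias, cell_clip = 1.0, 3.0
wci = wcf = wco = np.zeros(cell_size)                  # use_peephole=False

xh = np.concatenate([x, h_prev], axis=1)
i, f, ci, o = np.split(xh.dot(w) + b, 4, axis=1)
f = f + forget_bias
i = sigmoid(cs_prev * wci + i)
f = sigmoid(cs_prev * wcf + f)
ci = np.tanh(ci)
cs = np.clip(ci * i + cs_prev * f, -cell_clip, cell_clip)
o = sigmoid(cs * wco + o)        # output gate, peeping at the clipped cell state
co = np.tanh(cs)
h = co * o                       # cell output, shape (batch_size, cell_size)
```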
@@ -180,9 +191,8 @@ def _block_lstm(seq_len_max,
cell_size = cell_size4 / 4
zero_state = None
if cs_prev is None or h_prev is None:
- zero_state = array_ops.constant(0,
- dtype=dtypes.float32,
- shape=[batch_size, cell_size])
+ zero_state = array_ops.constant(
+ 0, dtype=dtypes.float32, shape=[batch_size, cell_size])
if cs_prev is None:
cs_prev = zero_state
if h_prev is None:
@@ -193,26 +203,30 @@ def _block_lstm(seq_len_max,
wcf = wci
# pylint: disable=protected-access
- return _lstm_ops_so.block_lstm(seq_len_max=seq_len_max,
- x=x,
- cs_prev=cs_prev,
- h_prev=h_prev,
- w=w,
- wci=wci,
- wco=wco,
- wcf=wcf,
- b=b,
- forget_bias=forget_bias,
- cell_clip=cell_clip,
- name=name,
- use_peephole=use_peephole)
+ i, cs, f, o, ci, co, h = _lstm_ops_so.block_lstm(
+ seq_len_max=seq_len_max,
+ x=array_ops.pack(x),
+ cs_prev=cs_prev,
+ h_prev=h_prev,
+ w=w,
+ wci=wci,
+ wco=wco,
+ wcf=wcf,
+ b=b,
+ forget_bias=forget_bias,
+ cell_clip=cell_clip,
+ name=name,
+ use_peephole=use_peephole)
+
+ return array_ops.unpack(i), array_ops.unpack(cs), array_ops.unpack(
+ f), array_ops.unpack(o), array_ops.unpack(ci), array_ops.unpack(
+ co), array_ops.unpack(h)
# pylint: enable=protected-access
# pylint: enable=invalid-name
_lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"]
-
ops.RegisterShape("LSTMBlockCell")(common_shapes.call_cpp_shape_fn)
@@ -283,28 +297,11 @@ ops.RegisterShape("BlockLSTM")(common_shapes.call_cpp_shape_fn)
@ops.RegisterGradient("BlockLSTM")
def _BlockLSTMGrad(op, *grad):
"""Gradient for BlockLSTM."""
- max_len = op.get_attr("max_len")
-
- seq_len_max = op.inputs[0]
- x = op.inputs[1:1 + max_len]
- cs_prev = op.inputs[-7]
- h_prev = op.inputs[-6]
- w = op.inputs[-5]
- wci = op.inputs[-4]
- wco = op.inputs[-3]
- wcf = op.inputs[-2]
- b = op.inputs[-1]
-
- i = op.outputs[0 * max_len:1 * max_len]
- cs = op.outputs[1 * max_len:2 * max_len]
- f = op.outputs[2 * max_len:3 * max_len]
- o = op.outputs[3 * max_len:4 * max_len]
- ci = op.outputs[4 * max_len:5 * max_len]
- co = op.outputs[5 * max_len:6 * max_len]
- h = op.outputs[6 * max_len:7 * max_len]
-
- cs_grad = grad[-max_len * 2:-max_len]
- h_grad = grad[-max_len:]
+ seq_len_max, x, cs_prev, h_prev, w, wci, wco, wcf, b = op.inputs
+ i, cs, f, o, ci, co, h = op.outputs
+
+ cs_grad = grad[1]
+ h_grad = grad[6]
(x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, wcf_grad,
b_grad) = _lstm_ops_so.block_lstm_grad(
@@ -328,8 +325,8 @@ def _BlockLSTMGrad(op, *grad):
h_grad,
use_peephole=op.get_attr("use_peephole"))
- return [None] + x_grad + [cs_prev_grad, h_prev_grad, w_grad, wci_grad,
- wco_grad, wcf_grad, b_grad]
+ return [None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad,
+ wcf_grad, b_grad]
ops.RegisterShape("BlockLSTMGrad")(common_shapes.call_cpp_shape_fn)
@@ -379,21 +376,23 @@ class LSTMBlockCell(rnn_cell.RNNCell):
input_size = x_shape[1]
w = vs.get_variable("W", [input_size + self._num_units,
self._num_units * 4])
- b = vs.get_variable("b", [w.get_shape().with_rank(2)[1]],
- initializer=init_ops.constant_initializer(0.0))
+ b = vs.get_variable(
+ "b", [w.get_shape().with_rank(2)[1]],
+ initializer=init_ops.constant_initializer(0.0))
wci = vs.get_variable("wci", [self._num_units])
wco = vs.get_variable("wco", [self._num_units])
wcf = vs.get_variable("wcf", [self._num_units])
(cs_prev, h_prev) = states_prev
- (_, cs, _, _, _, _, h) = _lstm_block_cell(x,
- cs_prev,
- h_prev,
- w,
- b,
- wci=wci,
- wco=wco,
- wcf=wcf,
- forget_bias=self._forget_bias,
- use_peephole=self._use_peephole)
+ (_, cs, _, _, _, _, h) = _lstm_block_cell(
+ x,
+ cs_prev,
+ h_prev,
+ w,
+ b,
+ wci=wci,
+ wco=wco,
+ wcf=wcf,
+ forget_bias=self._forget_bias,
+ use_peephole=self._use_peephole)
return (h, (cs, h))
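
For context, `LSTMBlockCell` is meant to be used like any other `rnn_cell.RNNCell`. A rough usage sketch against the TF-0.x contrib API of this era (placeholder shapes and the scope name are illustrative, not part of this change):

```python
import tensorflow as tf
from tensorflow.contrib.rnn.python.ops import lstm_ops

batch_size, input_size, num_units = 4, 8, 16

x_t = tf.placeholder(tf.float32, [batch_size, input_size])
c_prev = tf.zeros([batch_size, num_units])
h_prev = tf.zeros([batch_size, num_units])

with tf.variable_scope("lstm_block"):
  cell = lstm_ops.LSTMBlockCell(num_units, use_peephole=False)
  # One time step: returns the cell output and the new (cs, h) state tuple.
  output, new_state = cell(x_t, (c_prev, h_prev))
```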