author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-07-10 05:52:20 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-07-10 05:55:28 -0700
commit    4657eb350f3a50b0c8cff6e62a1927f25b1f38bf (patch)
tree      f0054934eb4a376a9c42ac0504e613faca1ad752 /tensorflow/contrib/rnn
parent    8955c28d591983d47fb08ff9049efdf4830b9aed (diff)
Add IndRNN, IndyGRU and IndyLSTM cells.
These are similar to regular RNN, GRU and LSTM nodes, except that, following the ideas in https://arxiv.org/abs/1803.04831, each node only sees its own state, and not all the states in the same layer.

The number of weights for a layer of IndRNN, IndyGRU and IndyLSTM with input width "m" and output width "n" is n*(m+2), 3*n*(m+2) and 4*n*(m+2), as opposed to n*(m+n+1), 3*n*(m+n+1) and 4*n*(m+n+1) for regular RNN, GRU and LSTM layers, respectively. The computational costs are similarly reduced by replacing an O(n^2) matrix-vector multiplication with an O(n) element-wise vector multiplication.

PiperOrigin-RevId: 203932335
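As a quick illustration of the parameter counts above, the following sketch (not part of this change; it assumes the new cells are exposed as tf.contrib.rnn.IndRNNCell etc., as the __init__.py listing below suggests) counts the trainable variables of an IndRNNCell:

import tensorflow as tf

m, n = 3, 2  # input width and output width
x = tf.zeros([1, m])
h = tf.zeros([1, n])

cell = tf.contrib.rnn.IndRNNCell(n)
cell(x, h)  # builds kernel_w [m, n], kernel_u [1, n] and bias [n]
num_params = sum(v.shape.num_elements() for v in cell.trainable_variables)
print(num_params)  # n * (m + 2) = 10, vs. n * (m + n + 1) = 12 for a plain RNN cell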
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--  tensorflow/contrib/rnn/__init__.py                                 |   3
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py   | 115
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py                      | 332
3 files changed, 450 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index 07227bcb77..cb437f2a2f 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -59,6 +59,9 @@ See @{$python/contrib.rnn} guide.
@@HighwayWrapper
@@GLSTMCell
@@SRUCell
+@@IndRNNCell
+@@IndyGRUCell
+@@IndyLSTMCell
<!--RNNCell wrappers-->
@@AttentionCellWrapper
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 578aa752a3..85f0f8ced9 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -115,6 +115,27 @@ class RNNCellTest(test.TestCase):
})
self.assertEqual(res[0].shape, (1, 2))
+ def testIndRNNCell(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.IndRNNCell(2)
+ g, _ = cell(x, m)
+ self.assertEqual([
+ "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
+ ], [v.name for v in cell.trainable_variables])
+ self.assertFalse(cell.non_trainable_variables)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ self.assertEqual(res[0].shape, (1, 2))
+
def testGRUCell(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
@@ -143,6 +164,34 @@ class RNNCellTest(test.TestCase):
# Smoke test
self.assertAllClose(res[0], [[0.156736, 0.156736]])
+ def testIndyGRUCell(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.185265, 0.17704]])
+ with variable_scope.variable_scope(
+ "other", initializer=init_ops.constant_initializer(0.5)):
+ # Test IndyGRUCell with input_size != num_units.
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.155127, 0.157328]])
+
def testSRUCell(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
@@ -343,6 +392,72 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1], expected_mem0)
self.assertAllClose(res[2], expected_mem1)
+ def testIndyLSTMCell(self):
+ for dtype in [dtypes.float16, dtypes.float32]:
+ np_dtype = dtype.as_numpy_dtype
+ with self.test_session(graph=ops.Graph()) as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2], dtype=dtype)
+ state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ cell = rnn_cell_impl.MultiRNNCell(
+ [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)])
+ self.assertEqual(cell.dtype, None)
+ self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name)
+ self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name)
+ cell.get_config() # Should not throw an error
+ g, (out_state_0, out_state_1) = cell(x, (state_0, state_1))
+ # Layer infers the input type.
+ self.assertEqual(cell.dtype, dtype.name)
+ expected_variable_names = [
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" %
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" %
+ rnn_cell_impl._BIAS_VARIABLE_NAME
+ ]
+ self.assertEqual(expected_variable_names,
+ [v.name for v in cell.trainable_variables])
+ self.assertFalse(cell.non_trainable_variables)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g, out_state_0, out_state_1], {
+ x.name: np.array([[1., 1.]]),
+ state_0[0].name: 0.1 * np.ones([1, 2]),
+ state_0[1].name: 0.1 * np.ones([1, 2]),
+ state_1[0].name: 0.1 * np.ones([1, 2]),
+ state_1[1].name: 0.1 * np.ones([1, 2]),
+ })
+ self.assertEqual(len(res), 3)
+ variables = variables_lib.global_variables()
+ self.assertEqual(expected_variable_names, [v.name for v in variables])
+ # Only check the range of outputs as this is just a smoke test.
+ self.assertAllInRange(res[0], -1.0, 1.0)
+ self.assertAllInRange(res[1], -1.0, 1.0)
+ self.assertAllInRange(res[2], -1.0, 1.0)
+ with variable_scope.variable_scope(
+ "other", initializer=init_ops.constant_initializer(0.5)):
+ # Test IndyLSTMCell with input_size != num_units.
+ x = array_ops.zeros([1, 3], dtype=dtype)
+ state = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g, out_state], {
+ x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
+ state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype),
+ state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype),
+ })
+ self.assertEqual(len(res), 2)
+
def testLSTMCell(self):
with self.test_session() as sess:
num_units = 8
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index b12e2cd5ed..bcfabf19f3 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -23,6 +23,7 @@ import math
from tensorflow.contrib.compiler import jit
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
@@ -30,6 +31,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl # pylint: disable=unused-import
@@ -3050,3 +3052,333 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
return new_h, new_state
+
+
+class IndRNNCell(rnn_cell_impl.LayerRNNCell):
+ """Independently Recurrent Neural Network (IndRNN) cell
+ (cf. https://arxiv.org/abs/1803.04831).
+
+ Args:
+ num_units: int, The number of units in the RNN cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
+ """
+
+ def __init__(self,
+ num_units,
+ activation=None,
+ reuse=None,
+ name=None,
+ dtype=None):
+ super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self._num_units = num_units
+ self._activation = activation or math_ops.tanh
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def build(self, inputs_shape):
+ if inputs_shape[1].value is None:
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ # pylint: disable=protected-access
+ self._kernel_w = self.add_variable(
+ "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, self._num_units])
+ self._kernel_u = self.add_variable(
+ "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._bias = self.add_variable(
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[self._num_units],
+ initializer=init_ops.zeros_initializer(dtype=self.dtype))
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """IndRNN: output = new_state = act(W * input + u * state + B)."""
+
+ gate_inputs = math_ops.matmul(inputs, self._kernel_w) + (
+ state * self._kernel_u)
+ gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
+ output = self._activation(gate_inputs)
+ return output, output
+
+
+class IndyGRUCell(rnn_cell_impl.LayerRNNCell):
+ r"""Independently Gated Recurrent Unit cell.
+
+ Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell,
+ yet with the \(U_r\), \(U_z\), and \(U\) matrices in equations 5, 6, and
+ 8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal
+ matrices, i.e. a Hadamard product with a single vector:
+
+ $$r_j = \sigma\left([\mathbf W_r\mathbf x]_j +
+ [\mathbf u_r\circ \mathbf h_{(t-1)}]_j\right)$$
+ $$z_j = \sigma\left([\mathbf W_z\mathbf x]_j +
+ [\mathbf u_z\circ \mathbf h_{(t-1)}]_j\right)$$
+ $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j +
+ [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$
+
+ where \(\circ\) denotes the Hadamard operator. This means that each IndyGRU
+ node sees only its own state, as opposed to seeing all states in the same
+ layer.
+
+ TODO(gonnet): Write a paper describing this and add a reference here.
+
+ Args:
+ num_units: int, The number of units in the GRU cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ kernel_initializer: (optional) The initializer to use for the weight and
+ projection matrices.
+ bias_initializer: (optional) The initializer to use for the bias.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
+ """
+
+ def __init__(self,
+ num_units,
+ activation=None,
+ reuse=None,
+ kernel_initializer=None,
+ bias_initializer=None,
+ name=None,
+ dtype=None):
+ super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self._num_units = num_units
+ self._activation = activation or math_ops.tanh
+ self._kernel_initializer = kernel_initializer
+ self._bias_initializer = bias_initializer
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def build(self, inputs_shape):
+ if inputs_shape[1].value is None:
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ # pylint: disable=protected-access
+ self._gate_kernel_w = self.add_variable(
+ "gates/%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, 2 * self._num_units],
+ initializer=self._kernel_initializer)
+ self._gate_kernel_u = self.add_variable(
+ "gates/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, 2 * self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._gate_bias = self.add_variable(
+ "gates/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[2 * self._num_units],
+ initializer=(self._bias_initializer
+ if self._bias_initializer is not None else
+ init_ops.constant_initializer(1.0, dtype=self.dtype)))
+ self._candidate_kernel_w = self.add_variable(
+ "candidate/%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, self._num_units],
+ initializer=self._kernel_initializer)
+ self._candidate_kernel_u = self.add_variable(
+ "candidate/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._candidate_bias = self.add_variable(
+ "candidate/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[self._num_units],
+ initializer=(self._bias_initializer
+ if self._bias_initializer is not None else
+ init_ops.zeros_initializer(dtype=self.dtype)))
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Gated recurrent unit (GRU) with nunits cells."""
+
+ gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + (
+ gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u)
+ gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
+
+ value = math_ops.sigmoid(gate_inputs)
+ r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
+
+ r_state = r * state
+
+ candidate = math_ops.matmul(inputs, self._candidate_kernel_w) + (
+ r_state * self._candidate_kernel_u)
+ candidate = nn_ops.bias_add(candidate, self._candidate_bias)
+
+ c = self._activation(candidate)
+ new_h = u * state + (1 - u) * c
+ return new_h, new_h
+
+
+class IndyLSTMCell(rnn_cell_impl.LayerRNNCell):
+ r"""Basic IndyLSTM recurrent network cell.
+
+ Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to
+ BasicLSTMCell, yet with the \(U_f\), \(U_i\), \(U_o\) and \(U_c\)
+ matrices in
+ https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
+ replaced by diagonal matrices, i.e. a Hadamard product with a single vector:
+
+ $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$
+ $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$
+ $$o_t = \sigma_g\left(W_o x_t + u_o \circ h_{t-1} + b_o\right)$$
+ $$c_t = f_t \circ c_{t-1} +
+ i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$
+
+ where \(\circ\) denotes the Hadamard operator. This means that each IndyLSTM
+ node sees only its own state \(h\) and \(c\), as opposed to seeing all
+ states in the same layer.
+
+ We add forget_bias (default: 1) to the biases of the forget gate in order to
+ reduce the scale of forgetting at the beginning of training.
+
+ It does not allow cell clipping or a projection layer, and it does not use
+ peep-hole connections: it is the basic baseline.
+
+ For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}.
+
+ TODO(gonnet): Write a paper describing this and add a reference here.
+ """
+
+ def __init__(self,
+ num_units,
+ forget_bias=1.0,
+ activation=None,
+ reuse=None,
+ name=None,
+ dtype=None):
+ """Initialize the IndyLSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell.
+ forget_bias: float, The bias added to forget gates (see above).
+ Must be set to `0.0` manually when restoring from CudnnLSTM-trained
+ checkpoints.
+ activation: Activation function of the inner states. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
+ """
+ super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self._num_units = num_units
+ self._forget_bias = forget_bias
+ self._activation = activation or math_ops.tanh
+
+ @property
+ def state_size(self):
+ return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def build(self, inputs_shape):
+ if inputs_shape[1].value is None:
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ # pylint: disable=protected-access
+ self._kernel_w = self.add_variable(
+ "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, 4 * self._num_units])
+ self._kernel_u = self.add_variable(
+ "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, 4 * self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._bias = self.add_variable(
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[4 * self._num_units],
+ initializer=init_ops.zeros_initializer(dtype=self.dtype))
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Independent Long short-term memory cell (IndyLSTM).
+
+ Args:
+ inputs: `2-D` tensor with shape `[batch_size, input_size]`.
+ state: An `LSTMStateTuple` of state tensors, each shaped
+ `[batch_size, num_units]`.
+
+ Returns:
+ A pair containing the new hidden state, and the new state (a
+ `LSTMStateTuple`).
+ """
+ sigmoid = math_ops.sigmoid
+ one = constant_op.constant(1, dtype=dtypes.int32)
+ c, h = state
+
+ gate_inputs = math_ops.matmul(inputs, self._kernel_w)
+ gate_inputs += gen_array_ops.tile(h, [1, 4]) * self._kernel_u
+ gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ i, j, f, o = array_ops.split(
+ value=gate_inputs, num_or_size_splits=4, axis=one)
+
+ forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
+ # Note that using `add` and `multiply` instead of `+` and `*` gives a
+ # performance improvement. So using those at the cost of readability.
+ add = math_ops.add
+ multiply = math_ops.multiply
+ new_c = add(
+ multiply(c, sigmoid(add(f, forget_bias_tensor))),
+ multiply(sigmoid(i), self._activation(j)))
+ new_h = multiply(self._activation(new_c), sigmoid(o))
+
+ new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
+ return new_h, new_state
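
For reference, the per-step arithmetic implemented by IndyLSTMCell.call above can be written out in plain NumPy. This is only a sketch of the same equations (sigmoid/tanh activations, default forget_bias of 1.0); the function and argument names are illustrative, not part of the commit:

import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def indy_lstm_step(x, c, h, kernel_w, kernel_u, bias, forget_bias=1.0):
  # x: [batch, input_depth]; c and h: [batch, num_units].
  # kernel_w: [input_depth, 4 * num_units]; kernel_u: [1, 4 * num_units];
  # bias: [4 * num_units].
  gates = np.matmul(x, kernel_w) + np.tile(h, [1, 4]) * kernel_u + bias
  # i = input_gate, j = new_input, f = forget_gate, o = output_gate.
  i, j, f, o = np.split(gates, 4, axis=1)
  new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
  new_h = np.tanh(new_c) * sigmoid(o)
  return new_h, (new_c, new_h)  # output and (c, h) state, as in LSTMStateTuple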