diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/rnn_cell.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index ad23e532b1..7a0f894404 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -1923,3 +1923,178 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) return new_h, new_state + + +class GLSTMCell(core_rnn_cell.RNNCell): + """Group LSTM cell (G-LSTM). + + The implementation is based on: + + https://arxiv.org/abs/1703.10722 + + O. Kuchaiev and B. Ginsburg + "Factorization Tricks for LSTM Networks", ICLR 2017 workshop. + """ + + def __init__(self, num_units, initializer=None, num_proj=None, + number_of_groups=1, forget_bias=1.0, activation=math_ops.tanh, + reuse=None): + """Initialize the parameters of G-LSTM cell. + + Args: + num_units: int, The number of units in the G-LSTM cell + initializer: (optional) The initializer to use for the weight and + projection matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + number_of_groups: (optional) int, number of groups to use. + If `number_of_groups` is 1, then it should be equivalent to LSTM cell + forget_bias: Biases of the forget gate are initialized by default to 1 + in order to reduce the scale of forgetting at the beginning of + the training. + activation: Activation function of the inner states. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already + has the given variables, an error is raised. + + Raises: + ValueError: If `num_units` or `num_proj` is not divisible by + `number_of_groups`. + """ + super(GLSTMCell, self).__init__(_reuse=reuse) + self._num_units = num_units + self._initializer = initializer + self._num_proj = num_proj + self._forget_bias = forget_bias + self._activation = activation + self._number_of_groups = number_of_groups + + if self._num_units % self._number_of_groups != 0: + raise ValueError("num_units must be divisible by number_of_groups") + if self._num_proj: + if self._num_proj % self._number_of_groups != 0: + raise ValueError("num_proj must be divisible by number_of_groups") + self._group_shape = [int(self._num_proj / self._number_of_groups), + int(self._num_units / self._number_of_groups)] + else: + self._group_shape = [int(self._num_units / self._number_of_groups), + int(self._num_units / self._number_of_groups)] + + if num_proj: + self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj) + self._output_size = num_proj + else: + self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units) + self._output_size = num_units + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + def _get_input_for_group(self, inputs, group_id, group_size): + """Slices inputs into groups to prepare for processing by cell's groups + + Args: + inputs: cell input or it's previous state, + a Tensor, 2D, [batch x num_units] + group_id: group id, a Scalar, for which to prepare input + group_size: size of the group + + Returns: + subset of inputs corresponding to group "group_id", + a Tensor, 2D, [batch x num_units/number_of_groups] + """ + return array_ops.slice(input_=inputs, + begin=[0, group_id * group_size], + size=[self._batch_size, group_size], + name=("GLSTM_group%d_input_generation" % group_id)) + + def call(self, inputs, state): + """Run one step of G-LSTM. + + Args: + inputs: input Tensor, 2D, [batch x num_units]. + state: this must be a tuple of state Tensors, both `2-D`, + with column sizes `c_state` and `m_state`. + + Returns: + A tuple containing: + + - A `2-D, [batch x output_dim]`, Tensor representing the output of the + G-LSTM after reading `inputs` when previous state was `state`. + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - LSTMStateTuple representing the new state of G-LSTM cell + after reading `inputs` when the previous state was `state`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. + """ + (c_prev, m_prev) = state + + self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] + dtype = inputs.dtype + scope = vs.get_variable_scope() + with vs.variable_scope(scope, initializer=self._initializer): + i_parts = [] + j_parts = [] + f_parts = [] + o_parts = [] + + for group_id in range(self._number_of_groups): + with vs.variable_scope("group%d" % group_id): + x_g_id = array_ops.concat( + [self._get_input_for_group(inputs, group_id, + self._group_shape[0]), + self._get_input_for_group(m_prev, group_id, + self._group_shape[0])], axis=1) + R_k = _linear(x_g_id, 4 * self._group_shape[1], bias=False) + i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1) + + i_parts.append(i_k) + j_parts.append(j_k) + f_parts.append(f_k) + o_parts.append(o_k) + + bi = vs.get_variable(name="bias_i", + shape=[self._num_units], + dtype=dtype, + initializer= + init_ops.constant_initializer(0.0, dtype=dtype)) + bj = vs.get_variable(name="bias_j", + shape=[self._num_units], + dtype=dtype, + initializer= + init_ops.constant_initializer(0.0, dtype=dtype)) + bf = vs.get_variable(name="bias_f", + shape=[self._num_units], + dtype=dtype, + initializer= + init_ops.constant_initializer(0.0, dtype=dtype)) + bo = vs.get_variable(name="bias_o", + shape=[self._num_units], + dtype=dtype, + initializer= + init_ops.constant_initializer(0.0, dtype=dtype)) + + i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi) + j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj) + f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf) + o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo) + + c = (math_ops.sigmoid(f + self._forget_bias) * c_prev + + math_ops.sigmoid(i) * math_ops.tanh(j)) + m = math_ops.sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + with vs.variable_scope("projection"): + m = _linear(m, self._num_proj, bias=False) + + new_state = core_rnn_cell.LSTMStateTuple(c, m) + return m, new_state |